import os, torch, gradio as gr
from transformers import AutoProcessor, AutoModelForCTC
# --- CONFIG ---
MODEL_ID = os.getenv("MODEL_ID", "Reihaneh/wav2vec2_fy_nl_best_frisian_1")
HF_TOKEN = os.getenv("HF_TOKEN") # only if the repo is private
# Faster hub downloads (optional but nice if you also add the package 'hf-transfer')
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# --- LOAD ONCE (local cache persists across Space restarts) ---
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCTC.from_pretrained(MODEL_ID, token=HF_TOKEN, torch_dtype=dtype)
model.to(device).eval()
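# wav2vec2-style checkpoints are typically trained on 16 kHz mono audio, so fall
# back to 16000 if the feature extractor does not report a sampling rate.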
target_sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)
def transcribe(audio):
    if audio is None:
        return ""
    sr, wav = audio  # gr.Audio(type="numpy") returns (sr, np.ndarray)
    if wav.ndim == 2:  # stereo -> keep only the first channel
        wav = wav[:, 0]
    wav = wav.astype("float32")  # gr.Audio yields int16 PCM; the feature extractor expects float
    # quick resample if needed (basic, avoids extra deps)
    if sr != target_sr:
        # naive nearest-neighbour resample (index picking) to avoid importing librosa
        import math
        import numpy as np
        ratio = target_sr / sr
        idx = (np.arange(int(math.ceil(wav.shape[0] * ratio))) / ratio).astype(int)
        idx = idx[idx < wav.shape[0]]
        wav = wav[idx]
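    # optional sketch (not active, scipy is NOT assumed in requirements.txt):
    # polyphase resampling would give cleaner audio than the index picking above:
    #   from scipy.signal import resample_poly
    #   wav = resample_poly(wav, up=target_sr, down=sr)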
    inputs = processor(wav, sampling_rate=target_sr, return_tensors="pt")
    with torch.inference_mode():
        logits = model(inputs.input_values.to(device=device, dtype=dtype)).logits
    ids = torch.argmax(logits, dim=-1)      # greedy CTC decoding
    text = processor.batch_decode(ids)[0]   # tokenizer collapses repeats and blanks
    return text
with gr.Blocks(title="Frisian ASR") as demo:
gr.Markdown("## πŸŽ™οΈ Frisian ASR\nUpload audio or use your mic.")
audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Audio")
out = gr.Textbox(label="Transcript")
gr.Button("Transcribe").click(transcribe, inputs=audio, outputs=out)
# optional: auto-run when audio changes
# audio.change(transcribe, inputs=audio, outputs=out, every=0)
demo.queue(concurrency_count=1, max_size=16)
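# note: 'concurrency_count' is the Gradio 3.x queue() argument; on Gradio 4+ the
# rough equivalent would be demo.queue(default_concurrency_limit=1, max_size=16).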
if __name__ == "__main__":
    demo.launch()
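# optional local smoke test (commented out; assumes 'soundfile' is installed and a
# 16 kHz sample file exists - both are illustrative, not part of this Space):
#   import soundfile as sf
#   wav, sr = sf.read("sample_fy.wav", dtype="float32")
#   print(transcribe((sr, wav)))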
#import gradio as gr
#gr.Interface.load("models/Reihaneh/wav2vec2_fy_nl_best_frisian_1").launch()
#import gradio as gr
#gr.load("Reihaneh/wav2vec2_fy_nl_best_frisian_1").launch(share=True)
'''
import os, torch, gradio as gr
from transformers import pipeline

# optional: keep memory low on CPU Spaces
if not torch.cuda.is_available():
    torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))
MODEL_ID = os.getenv("MODEL_ID", "Reihaneh/wav2vec2_fy_nl_best_frisian_1") # set in Space Secrets or hardcode
HF_TOKEN = os.getenv("HF_TOKEN") # only needed if model is private
asr = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_ID,
    # common safe defaults; adjust for your model
    chunk_length_s=20,        # chunked inference for long audio
    stride_length_s=(4, 2),   # seconds of left/right context per chunk (trimmed from the output)
    device_map="auto",        # uses GPU if available (needs 'accelerate' installed)
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=HF_TOKEN,           # works for private models
)
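# quick sanity check: the pipeline accepts a file path or a float32 numpy array and
# returns a dict like {"text": "..."}, e.g. asr("example.wav")["text"] (path is illustrative)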
def transcribe(audio, language=None):
    """
    audio arrives as (sr, np.ndarray) from gr.Audio when type='numpy'.
    The ASR pipeline expects a path or a dict {"raw": float32 array, "sampling_rate": sr}.
    """
    if audio is None:
        return ""
    sr, wav = audio
    if wav.ndim == 2:  # stereo -> keep the first channel
        wav = wav[:, 0]
    result = asr(
        {"raw": wav.astype("float32"), "sampling_rate": sr},
        generate_kwargs={"task": "transcribe"} if "whisper" in MODEL_ID.lower() else None,
    )
    # result is either a dict with 'text' or a string depending on the transformers version
    return result["text"] if isinstance(result, dict) and "text" in result else str(result)
with gr.Blocks(title="ASR Demo") as demo:
gr.Markdown("# πŸŽ™οΈ ASR Demo\nUpload a file or speak into your mic.")
with gr.Row():
audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Audio")
lang = gr.Dropdown(choices=[None,"fy-NL","nl-NL","en-US"], value=None, label="(Optional) Language hint")
out = gr.Textbox(label="Transcription")
btn = gr.Button("Transcribe")
btn.click(fn=transcribe, inputs=[audio, lang], outputs=out)
# optional: live/auto mode (non-streaming)
audio.change(fn=transcribe, inputs=[audio, lang], outputs=out, every=0)
demo.queue(max_size=12, concurrency_count=1).launch()'''