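"""Gradio Space demo: Frisian automatic speech recognition with a wav2vec2 CTC model."""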

import math
import os

import gradio as gr
import numpy as np
import torch
from transformers import AutoProcessor, AutoModelForCTC

# Model repo and (optional) auth token; set these as Space secrets or env vars.
MODEL_ID = os.getenv("MODEL_ID", "Reihaneh/wav2vec2_fy_nl_best_frisian_1")
HF_TOKEN = os.getenv("HF_TOKEN")  # only needed if the model repo is private

# Faster Hub downloads when the hf_transfer package is installed.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCTC.from_pretrained(MODEL_ID, token=HF_TOKEN, torch_dtype=dtype)
model.to(device).eval()

# wav2vec2 checkpoints expect a fixed input sampling rate (usually 16 kHz).
target_sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)


def transcribe(audio):
    """Transcribe a (sample_rate, waveform) tuple from gr.Audio(type="numpy")."""
    if audio is None:
        return ""
    sr, wav = audio
    if wav.ndim == 2:  # stereo -> mono: keep the first channel
        wav = wav[:, 0]
    wav = wav.astype(np.float32)  # gr.Audio yields int16 PCM; the processor expects float

    if sr != target_sr:
        # Crude nearest-neighbour resampling; adequate for a demo.
        ratio = target_sr / sr
        idx = (np.arange(int(math.ceil(wav.shape[0] * ratio))) / ratio).astype(int)
        idx = idx[idx < wav.shape[0]]
        wav = wav[idx]
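        # Higher-quality alternative, assuming torchaudio is installed (it is
        # not required by this script):
        # import torchaudio.functional as AF
        # wav = AF.resample(torch.from_numpy(wav), sr, target_sr).numpy()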

    inputs = processor(wav, sampling_rate=target_sr, return_tensors="pt")
    with torch.inference_mode():
        # Cast inputs to the model's dtype (float16 on GPU) to avoid a dtype mismatch.
        logits = model(inputs.input_values.to(device=device, dtype=dtype)).logits
    ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(ids)[0]
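
# Quick local smoke test, assuming a hypothetical clip "sample.wav" and the
# optional soundfile package:
# import soundfile as sf
# data, sr = sf.read("sample.wav")
# print(transcribe((sr, data)))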


with gr.Blocks(title="Frisian ASR") as demo:
    gr.Markdown("## 🎙️ Frisian ASR\nUpload audio or use your mic.")
    audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Audio")
    out = gr.Textbox(label="Transcript")
    gr.Button("Transcribe").click(transcribe, inputs=audio, outputs=out)

# Gradio 4 removed `concurrency_count` from queue(); `default_concurrency_limit`
# now bounds how many requests run at once.
demo.queue(max_size=16, default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch()
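
# Likely requirements.txt for this Space (an assumption, not from the original):
# gradio
# transformers
# torch
# numpy
# hf_transfer  # optional, enables the faster download path set above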


# Alternative implementation, kept for reference: the same app built on the
# transformers `pipeline` API. The block below is quoted out on purpose.
'''
import os

import gradio as gr
import numpy as np
import torch
from transformers import pipeline

# Optional: keep CPU usage modest on shared CPU Spaces.
if not torch.cuda.is_available():
    torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))

MODEL_ID = os.getenv("MODEL_ID", "Reihaneh/wav2vec2_fy_nl_best_frisian_1")  # set in Space secrets or hardcode
HF_TOKEN = os.getenv("HF_TOKEN")  # only needed if the model is private

asr = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_ID,
    # Common safe defaults; adjust for your model.
    chunk_length_s=20,        # chunked inference for long audio
    stride_length_s=(4, 2),
    device_map="auto",        # uses the GPU if available (requires accelerate)
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=HF_TOKEN,           # works for private models
)


def transcribe(audio, language=None):
    """
    `audio` arrives as a (sr, np.ndarray) tuple from gr.Audio(type="numpy");
    the pipeline wants a {"raw": ..., "sampling_rate": ...} dict, so convert it.
    `language` is only a UI hint here and is not forwarded to the model.
    """
    if audio is None:
        return ""
    sr, wav = audio
    if wav.ndim == 2:  # stereo -> mono
        wav = wav[:, 0]
    result = asr(
        {"raw": wav.astype(np.float32), "sampling_rate": sr},
        generate_kwargs={"task": "transcribe"} if "whisper" in MODEL_ID.lower() else None,
        return_timestamps=False,
    )
    # `result` is a dict with "text" or a plain string, depending on the transformers version.
    return result["text"] if isinstance(result, dict) and "text" in result else str(result)


with gr.Blocks(title="ASR Demo") as demo:
    gr.Markdown("# 🎙️ ASR Demo\nUpload a file or speak into your mic.")
    with gr.Row():
        audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Audio")
        lang = gr.Dropdown(choices=[None, "fy-NL", "nl-NL", "en-US"], value=None, label="(Optional) Language hint")
    out = gr.Textbox(label="Transcription")
    btn = gr.Button("Transcribe")
    btn.click(fn=transcribe, inputs=[audio, lang], outputs=out)
    # Optional: auto-transcribe whenever the audio value changes.
    audio.change(fn=transcribe, inputs=[audio, lang], outputs=out)

demo.queue(max_size=12, default_concurrency_limit=1).launch()
'''