"""Gradio demo: Frisian speech recognition with a fine-tuned wav2vec2 CTC model."""

import math
import os

import gradio as gr
import numpy as np
import torch
from transformers import (
    AutoModelForCTC,
    AutoProcessor,
    Wav2Vec2Processor,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2CTCTokenizer,
)

# Model repo and auth token are read from the environment so the app can be
# repointed without code changes; HF_TOKEN is only needed for private/gated repos.
MODEL_ID = os.getenv("MODEL_ID", "Reihaneh/wav2vec2_fy_nl_best_frisian_1")
HF_TOKEN = os.getenv("HF_TOKEN")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Prefer the processor bundled with the checkpoint; if that fails, assemble a
# Wav2Vec2Processor by hand from the repo's tokenizer plus a default feature
# extractor. The fallback must run only inside the except branch, otherwise it
# would overwrite a successfully loaded AutoProcessor.
processor = None
try:
    processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
except Exception as e:
    print("AutoProcessor failed, building Wav2Vec2Processor manually:", e)
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=True,
    )
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
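
# Note: the manual fallback matters for fine-tuned repos that ship tokenizer
# files (vocab.json etc.) but no preprocessor_config.json, in which case
# AutoProcessor has nothing to assemble itself from.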

model = AutoModelForCTC.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval()

# wav2vec2 checkpoints expect 16 kHz input; read the rate off the processor
# instead of hard-coding it.
target_sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)


def _cheap_resample(wav, sr, target_sr):
    """Nearest-neighbour resampling: dependency-free and fast, but it aliases.

    Good enough for a demo; not for production audio.
    """
    if sr == target_sr:
        return wav
    ratio = target_sr / float(sr)
    idx = (np.arange(int(math.ceil(len(wav) * ratio))) / ratio).astype(np.int64)
    idx = np.clip(idx, 0, len(wav) - 1)  # guard against out-of-range indices
    return wav[idx]
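
# A higher-quality alternative, assuming torchaudio is available (it is not a
# dependency of this app), would be band-limited resampling:
#
#   import torchaudio.functional as AF
#   wav = AF.resample(torch.from_numpy(wav), orig_freq=sr, new_freq=target_sr).numpy()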


def transcribe(a):
    """Gradio callback: takes (sample_rate, samples) from gr.Audio, returns text."""
    try:
        if a is None:
            return ""
        sr, x = a

        # Downmix stereo to mono and sanitise the signal.
        if x.ndim == 2:
            x = x.mean(axis=1)
        x = np.nan_to_num(x).astype(np.float32)

        # Resample to the rate the model was trained on.
        x = _cheap_resample(x, sr, target_sr)

        inputs = processor(x, sampling_rate=target_sr, return_tensors="pt", padding=True)
        input_values = inputs.input_values.to(device)
        # Match the model's dtype (e.g. float16 if the checkpoint loads in half precision).
        input_values = input_values.to(model.dtype)

        # Greedy CTC decoding: per-frame argmax, then collapse repeats and blanks.
        with torch.inference_mode():
            logits = model(input_values).logits
        ids = torch.argmax(logits, dim=-1)
        text = processor.batch_decode(ids)[0]
        return text
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return f"⚠️ Error: {e}"


# Minimal UI: record or upload audio, click to transcribe.
with gr.Blocks(title="Frisian ASR") as demo:
    gr.Markdown("## 🎙️ Frisian ASR")
    audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Audio")
    out = gr.Textbox(label="Transcript")
    gr.Button("Transcribe").click(transcribe, inputs=audio, outputs=out)

demo.queue().launch()
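
# Run locally with `python app.py`. On Hugging Face Spaces, a file named
# app.py is picked up and launched automatically (an assumption about how
# this demo is deployed).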