Reihaneh committed on
Commit
9dc2a58
·
verified ·
1 Parent(s): a698d06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -2
app.py CHANGED
@@ -44,7 +44,7 @@ def _cheap_resample(wav, sr, target_sr):
44
  idx = idx[idx < wav.shape[0]]
45
  return wav[idx]
46
 
47
- def transcribe(audio):
48
  if audio is None:
49
  return ""
50
  sr, x = audio
@@ -56,7 +56,47 @@ def transcribe(audio):
56
  logits = model(inputs.input_values.to(device)).logits
57
  ids = torch.argmax(logits, dim=-1)
58
  text = processor.batch_decode(ids)[0]
59
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  with gr.Blocks(title="Frisian ASR") as demo:
62
  gr.Markdown("## 🎙️ Frisian ASR")
 
44
  idx = idx[idx < wav.shape[0]]
45
  return wav[idx]
46
 
47
+ '''def transcribe(audio):
48
  if audio is None:
49
  return ""
50
  sr, x = audio
 
56
  logits = model(inputs.input_values.to(device)).logits
57
  ids = torch.argmax(logits, dim=-1)
58
  text = processor.batch_decode(ids)[0]
59
+ return text'''
60
+
61
+
62
def transcribe(a):
    """Transcribe one audio clip to text using the module-level ASR model.

    Args:
        a: ``None`` or a ``(sample_rate, samples)`` pair. ``samples`` is
            converted with ``np.asarray``; a 2-D array is averaged over
            axis 1 (presumably the channel axis) to obtain one channel.

    Returns:
        The decoded text of the first batch item; ``""`` when there is no
        audio (``a`` is ``None`` or holds zero samples); or a string
        starting with ``"⚠️ Error:"`` when any step raises.
    """
    # Top-level UI boundary: every failure is logged and reported as text
    # instead of propagating to the caller.
    try:
        if a is None:
            return ""
        sr, x = a

        # Local imports keep this block self-contained.
        import math

        import numpy as np

        # 1) Single channel, NaN/inf scrubbed, forced to float32.
        x = np.asarray(x)
        if x.ndim == 2:
            x = x.mean(axis=1)
        # NOTE(review): integer PCM is cast without rescaling to [-1, 1];
        # presumably the processor normalises its input -- confirm.
        x = np.nan_to_num(x).astype(np.float32)

        # An empty recording has nothing to decode; answer like the None case
        # rather than letting the processor fail on a zero-length array.
        if x.size == 0:
            return ""

        # 2) Cheap nearest-sample resample to the processor's sampling rate
        #    (falls back to 16000 when the processor does not expose one).
        target_sr = getattr(
            getattr(processor, "feature_extractor", None), "sampling_rate", 16000
        )
        if sr != target_sr:
            ratio = target_sr / float(sr)
            n = int(math.ceil(len(x) * ratio))
            idx = (np.arange(n) / ratio).astype(np.int64)
            # Guard against an index one past the end caused by rounding.
            idx = np.clip(idx, 0, len(x) - 1)
            x = x[idx]

        # 3) Feature extraction, then move to the model's device AND dtype;
        #    the dtype match prevents "Input type (double) and bias type
        #    should be the same".
        inputs = processor(
            x, sampling_rate=target_sr, return_tensors="pt", padding=True
        )
        input_values = inputs.input_values.to(device=device, dtype=model.dtype)

        with torch.inference_mode():
            logits = model(input_values).logits
        ids = torch.argmax(logits, dim=-1)
        return processor.batch_decode(ids)[0]
    except Exception as e:
        import traceback

        print(traceback.format_exc())
        return f"⚠️ Error: {e}"
99
+
100
 
101
  with gr.Blocks(title="Frisian ASR") as demo:
102
  gr.Markdown("## 🎙️ Frisian ASR")