Nguyen Anh Hong committed
Commit d4c10cd · 1 Parent(s): 863c2e6
Files changed (2):
  1. app.py +3 -0
  2. sample.py +112 -0
app.py CHANGED
@@ -1,3 +1,5 @@
+"""
+
 import gradio as gr
 import torch
 import torchaudio
@@ -40,3 +42,4 @@ iface = gr.Interface(
 if __name__ == "__main__":
     iface.launch()
 
+"""
sample.py CHANGED
@@ -0,0 +1,112 @@
+import os
+import time
+import torch
+import torchaudio
+import gradio as gr
+from transformers import (
+    Wav2Vec2Processor, HubertForCTC,
+    WhisperProcessor, WhisperForConditionalGeneration
+)
+from phonemizer import phonemize
+import difflib
+
+# === Setup: load all 3 models ===
+
+# 1. Base HuBERT (public ASR checkpoint; note this model is character-level,
+# so its decoded output is English text rather than true phonemes)
+base_proc = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
+base_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").eval()
+
+# 2. Whisper + phonemizer
+whisper_proc = WhisperProcessor.from_pretrained("openai/whisper-base")
+whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").eval()
+
+# 3. My fine-tuned HuBERT model (HF_TOKEN is read from the environment so the
+# hosted Space can access the repo)
+token = os.environ.get("HF_TOKEN")
+your_proc = Wav2Vec2Processor.from_pretrained("tecasoftai/hubert-finetune", token=token)
+your_model = HubertForCTC.from_pretrained("tecasoftai/hubert-finetune", token=token).eval()
+
+# === Helpers ===
+
+def load_audio(filepath):
+    # Load the file and resample to the 16 kHz rate all three models expect.
+    waveform, sr = torchaudio.load(filepath)
+    if sr != 16000:
+        waveform = torchaudio.functional.resample(waveform, sr, 16000)
+    # Squeeze to 1-D (assumes mono input) and return a numpy array, which
+    # all three processors accept.
+    return waveform.squeeze().numpy()
+
+def calc_per(pred, ref):
+    # Phoneme Error Rate: edit distance between the predicted and reference
+    # phoneme sequences, as a percentage of the reference length.
+    pred_list = pred.strip().split()
+    ref_list = ref.strip().split()
+    if len(ref_list) == 0:
+        return 0.0
+    sm = difflib.SequenceMatcher(None, ref_list, pred_list)
+    # Each non-'equal' opcode covers spans (i1, i2) in ref and (j1, j2) in
+    # pred; the larger span length is the number of edits it represents.
+    dist = sum(max(i2 - i1, j2 - j1)
+               for tag, i1, i2, j1, j2 in sm.get_opcodes()
+               if tag != 'equal')
+    return round(100 * dist / len(ref_list), 2)
+
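+# Worked example (hypothetical strings): with ref = "f ə n ə m" and
+# pred = "f ə n", SequenceMatcher reports one 'delete' opcode spanning the
+# 2 missing reference phonemes, so calc_per returns 100 * 2 / 5 = 40.0.
+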
+# === Inference functions ===
+
+def run_hubert_base(wav):
+    # Greedy CTC decoding: argmax over the logits, then collapse to a string.
+    start = time.time()
+    inputs = base_proc(wav, sampling_rate=16000, return_tensors="pt")
+    with torch.no_grad():
+        logits = base_model(**inputs).logits
+    ids = torch.argmax(logits, dim=-1)
+    phonemes = base_proc.batch_decode(ids)[0]
+    return phonemes, time.time() - start
+
+def run_whisper(wav):
+    # Whisper transcribes to text first; the text is then converted to
+    # phonemes with the phonemizer library.
+    start = time.time()
+    inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt")
+    with torch.no_grad():
+        ids = whisper_model.generate(inputs["input_features"])
+    text = whisper_proc.batch_decode(ids, skip_special_tokens=True)[0]
+    phonemes = phonemize(text, language='en-us', backend='espeak')
+    return phonemes, time.time() - start
+
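+# Note: the 'espeak' backend of phonemize() requires espeak-ng (or the older
+# espeak) to be installed on the system. As a rough illustration,
+# phonemize("hello", language='en-us', backend='espeak') yields something
+# like "həloʊ".
+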
+def run_your_model(wav):
+    # Same greedy CTC decoding as the base model, but with the fine-tuned
+    # checkpoint loaded above.
+    start = time.time()
+    inputs = your_proc(wav, sampling_rate=16000, return_tensors="pt")
+    with torch.no_grad():
+        logits = your_model(**inputs).logits
+    ids = torch.argmax(logits, dim=-1)
+    phonemes = your_proc.batch_decode(ids)[0]
+    return phonemes, time.time() - start
+
+# === Main Gradio function ===
+
+def benchmark_all(audio_path, reference_phoneme):
+    wav = load_audio(audio_path)
+
+    results = []
+
+    # 1. HuBERT base
+    phonemes, dur = run_hubert_base(wav)
+    per = calc_per(phonemes, reference_phoneme)
+    results.append(["HuBERT-Base", phonemes, f"{dur:.2f}s", f"{per}%"])
+
+    # 2. Whisper + phonemizer
+    phonemes, dur = run_whisper(wav)
+    per = calc_per(phonemes, reference_phoneme)
+    results.append(["Whisper + Phonemizer", phonemes, f"{dur:.2f}s", f"{per}%"])
+
+    # 3. My fine-tuned HuBERT model
+    phonemes, dur = run_your_model(wav)
+    per = calc_per(phonemes, reference_phoneme)
+    results.append(["Your HuBERT (fine-tuned)", phonemes, f"{dur:.2f}s", f"{per}%"])
+
+    return results
+
+# === UI ===
+
+demo = gr.Interface(
+    fn=benchmark_all,
+    inputs=[
+        gr.Audio(type="filepath", label="Upload Audio"),
+        gr.Textbox(label="Ground-truth Phonemes (space-separated)", placeholder="f ə n ə m aɪ z")
+    ],
+    outputs=gr.Dataframe(headers=["Model", "Phoneme Output", "Inference Time", "PER (%)"]),
+    title="Phoneme Recognition Benchmark",
+    description="Compare HuBERT-Base, Whisper, and your fine-tuned model on phoneme recognition."
+)
+
+if __name__ == "__main__":
+    demo.launch()
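+
+# A minimal sketch of driving the benchmark without the UI (assumes a local
+# audio file "sample.wav"; the reference string below is hypothetical):
+#
+#   rows = benchmark_all("sample.wav", "f ə n ə m aɪ z")
+#   for model_name, phonemes, dur, per in rows:
+#       print(model_name, phonemes, dur, per)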