import os
import time

import torch
import torchaudio
from transformers import (
    Wav2Vec2Processor,
    HubertForCTC,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoModelForCTC,
)

from .cmu_process import text_to_phoneme, cmu_to_ipa, clean_cmu
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# === Helper: move all tensors to the model device ===
def to_device(batch, device):
    if isinstance(batch, dict):
        return {k: v.to(device) for k, v in batch.items()}
    elif isinstance(batch, torch.Tensor):
        return batch.to(device)
    return batch


# === Setup: load all six models ===

# 1. Base HuBERT (letter CTC output, phonemized via the CMU dictionary)
base_proc = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
base_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device).eval()

# 2. Whisper (text transcript, phonemized via the CMU dictionary)
whisper_proc = WhisperProcessor.from_pretrained("openai/whisper-base")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device).eval()

# 3. Fine-tuned HuBERT (optional HF token via env)
HF_TOKEN = os.environ.get("HF_TOKEN", None)
proc = Wav2Vec2Processor.from_pretrained("tecasoftai/hubert-finetune", token=HF_TOKEN)
model = HubertForCTC.from_pretrained("tecasoftai/hubert-finetune", token=HF_TOKEN).to(device).eval()

# 4. vitouphy/wav2vec2-xls-r-300m-timit-phoneme (TIMIT phoneme set)
timit_proc = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
timit_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme").to(device).eval()

# 5. bookbot/wav2vec2-ljspeech-gruut (direct IPA output)
gruut_processor = AutoProcessor.from_pretrained("bookbot/wav2vec2-ljspeech-gruut")
gruut_model = AutoModelForCTC.from_pretrained("bookbot/wav2vec2-ljspeech-gruut").to(device).eval()

# 6. speech31/wavlm-large-english-phoneme (WavLM-large English phoneme head)
wavlm_proc = AutoProcessor.from_pretrained("speech31/wavlm-large-english-phoneme")
wavlm_model = AutoModelForCTC.from_pretrained("speech31/wavlm-large-english-phoneme").to(device).eval()


# === Inference functions ===
# Each function takes a 16 kHz mono waveform and returns
# (phoneme string, elapsed seconds).

def run_hubert_base(wav):
    start = time.time()
    inputs = base_proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).input_values
    inputs = inputs.to(device)
    with torch.no_grad():
        logits = base_model(inputs).logits
    ids = torch.argmax(logits, dim=-1)
    text = base_proc.batch_decode(ids)[0]
    # Convert to phonemes (CMU-like string without stresses), then to IPA
    phonemes = text_to_phoneme(text)
    phonemes = cmu_to_ipa(phonemes)
    return phonemes.strip(), time.time() - start


def run_whisper(wav):
    start = time.time()
    inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features.to(device)
    attention_mask = inputs.get("attention_mask", None)
    gen_kwargs = {"language": "en"}
    if attention_mask is not None:
        gen_kwargs["attention_mask"] = attention_mask.to(device)
    with torch.no_grad():
        pred_ids = whisper_model.generate(input_features, **gen_kwargs)
    text = whisper_proc.batch_decode(pred_ids, skip_special_tokens=True)[0]
    phonemes = text_to_phoneme(text)
    phonemes = cmu_to_ipa(phonemes)
    return phonemes.strip(), time.time() - start


def run_model(wav):
    start = time.time()
    # Prepare input (BatchEncoding supports .to(device))
    inputs = proc(wav, sampling_rate=16000, return_tensors="pt").to(device)
    # Forward pass
    with torch.no_grad():
        logits = model(**inputs).logits
    # Greedy decode
    ids = torch.argmax(logits, dim=-1)
    phonemes = proc.batch_decode(ids)[0]
    phonemes = cmu_to_ipa(phonemes)
    return phonemes.strip(), time.time() - start


def run_timit(wav):
    start = time.time()
    # Preprocess the input
    inputs = timit_proc(wav, sampling_rate=16_000, return_tensors="pt", padding=True)
    inputs = inputs.to(device)
    # Forward pass
    with torch.no_grad():
        logits = timit_model(inputs.input_values, attention_mask=inputs.attention_mask).logits
    # Greedy decode ids into a phoneme string
    predicted_ids = torch.argmax(logits, dim=-1)
    phonemes = "".join(timit_proc.batch_decode(predicted_ids))
    return phonemes.strip(), time.time() - start


def run_gruut(wav):
    start = time.time()
    # Preprocess waveform → model input
    inputs = gruut_processor(wav, sampling_rate=16000, return_tensors="pt", padding=True).to(device)
    # Forward pass
    with torch.no_grad():
        logits = gruut_model(**inputs).logits
    # Greedy decode → IPA phonemes
    pred_ids = torch.argmax(logits, dim=-1)
    phonemes = gruut_processor.batch_decode(pred_ids)[0]
    return phonemes.strip(), time.time() - start


def run_wavlm_large_phoneme(wav):
    start = time.time()
    # Preprocess waveform → model input
    inputs = wavlm_proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).to(device)
    input_values = inputs.input_values
    attention_mask = inputs.get("attention_mask", None)
    # Forward pass
    with torch.no_grad():
        logits = wavlm_model(input_values, attention_mask=attention_mask).logits
    # Greedy decode → phoneme tokens
    pred_ids = torch.argmax(logits, dim=-1)
    phonemes = wavlm_proc.batch_decode(pred_ids)[0]
    return phonemes.strip(), time.time() - start
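

# === Usage: run all six models on one clip ===
# A minimal sketch, assuming a local audio file; "sample.wav" is a
# placeholder path, and every processor above expects a 16 kHz mono
# waveform, so we resample and downmix before inference.
if __name__ == "__main__":
    wav, sr = torchaudio.load("sample.wav")  # hypothetical input path
    if sr != 16000:
        # Resample to the 16 kHz rate the models were trained on
        wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(wav)
    wav = wav.mean(dim=0).numpy()  # collapse channels to a 1-D mono array

    for name, fn in [
        ("hubert-base", run_hubert_base),
        ("whisper", run_whisper),
        ("hubert-finetune", run_model),
        ("timit-phoneme", run_timit),
        ("ljspeech-gruut", run_gruut),
        ("wavlm-phoneme", run_wavlm_large_phoneme),
    ]:
        phonemes, elapsed = fn(wav)
        print(f"{name:>16}: {phonemes}  ({elapsed:.2f}s)")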