# Xenobd's picture
# Update app.py
# 4d1a87b verified
import os
import gradio as gr
import numpy as np
import soundfile as sf
import tempfile
import time
import json
from typing import List, Tuple
# --------------------------
# Repo ensure (clone + LFS pull)
# --------------------------
# Remote repository that is cloned at startup; the TTS code, ONNX models and
# voice styles used below are all read from the resulting checkout.
REPO_URL = "https://huggingface.co/Supertone/supertonic"
# Local directory the repository is cloned into (relative to the current
# working directory).
TARGET_DIR = "supertonic" # folder name after clone
def run_cmd(cmd: str) -> int:
    """Echo *cmd*, run it through the system shell and return its status.

    Args:
        cmd: Complete shell command line. It is handed to the shell verbatim,
            so it must only ever be built from trusted values.

    Returns:
        The value returned by ``os.system``: ``0`` on success, non-zero on
        failure (the exact encoding of a failure is platform specific).
    """
    # Flush explicitly: the child process writes straight to the terminal/log,
    # while Python buffers stdout when it is not a TTY. Without the flush the
    # "[CMD]" line can show up *after* the output of the command it announces.
    print(f"[CMD] {cmd}", flush=True)
    return os.system(cmd)
print("=== Checking Supertonic repo ===")
if not os.path.exists(TARGET_DIR):
    print("[+] Cloning repo (LFS pointers only)...")
    # Not fatal on its own (the clone below still works without the LFS
    # hooks), but worth surfacing because the later "git lfs pull" needs it.
    if run_cmd("git lfs install") != 0:
        print("[!] 'git lfs install' failed - is git-lfs installed?")
    # GIT_LFS_SKIP_SMUDGE=1 keeps the clone fast: only LFS pointer files are
    # checked out here, the real binaries are fetched by the pull below.
    ret = run_cmd(f"GIT_LFS_SKIP_SMUDGE=1 git clone {REPO_URL} {TARGET_DIR}")
    if ret != 0:
        raise RuntimeError("git clone failed")
else:
    print("[✓] Repo already exists. Skipping clone.")
# Pull LFS assets (real ONNX files). If user doesn't want this, remove this line.
print("[+] Pulling LFS files (this will download ONNX models; skip if already pulled)...")
# Kept best-effort (the assets may already be present from an earlier run),
# but a failure is reported instead of being silently dropped so that a later
# model-loading error can be traced back to missing LFS files.
if run_cmd(f"cd {TARGET_DIR} && git lfs pull") != 0:
    print("[!] 'git lfs pull' failed - ONNX model files may be missing or still LFS pointers.")
# --------------------------
# Make repo importable
# --------------------------
import sys
# Insert the checkout at the front of the module search path so that the
# `tts_model` import below resolves to the module inside the cloned repo.
sys.path.insert(0, os.path.abspath(TARGET_DIR))
# --------------------------
# Import your TTS code from repo
# --------------------------
from tts_model import (
load_text_to_speech,
load_voice_style,
sanitize_filename,
chunk_text,
)
# --------------------------
# Discover available voice styles
# --------------------------
# Directory inside the checkout that is scanned for voice style ".json" files.
VOICE_STYLES_DIR = os.path.join(TARGET_DIR, "voice_styles")
def list_voice_styles(styles_dir: str = VOICE_STYLES_DIR) -> List[str]:
    """Return the sorted ``.json`` file names found directly in *styles_dir*.

    Args:
        styles_dir: Directory to scan. Defaults to ``VOICE_STYLES_DIR``.

    Returns:
        File names (not full paths) whose extension is ``.json``
        (case-insensitive), sorted alphabetically. An empty list is returned
        when *styles_dir* does not exist or is not a directory.
    """
    # Ask forgiveness instead of checking first: an exists() check followed by
    # listdir() still blows up when the path is a regular file, and leaves a
    # window in which the directory can disappear between the two calls.
    try:
        entries = os.listdir(styles_dir)
    except (FileNotFoundError, NotADirectoryError):
        return []
    return sorted(name for name in entries if name.lower().endswith(".json"))
# Scan the styles directory once at import time; the UI dropdown is populated
# from this list.
available_styles = list_voice_styles()
if available_styles:
    print("Found voice styles:", available_styles)
else:
    print("No voice styles found in", VOICE_STYLES_DIR)
# --------------------------
# Load TTS model once
# --------------------------
# Directory inside the checkout that is handed to the model loader.
ONNX_DIR = os.path.join(TARGET_DIR, "onnx")
# Passed as the step-count argument on every synthesis call in this file
# (presumably the number of inference steps - confirm against tts_model).
TOTAL_STEP = 15
print("Loading TTS model...")
# Loaded once at import time; both request handlers use this shared instance.
tts = load_text_to_speech(ONNX_DIR) # may take a while
# --------------------------
# Helper: load a single style by filename (returns Style)
# --------------------------
def load_style_by_name(filename: str):
    """Load a single voice style from ``VOICE_STYLES_DIR`` by file name.

    Args:
        filename: Bare file name of the style (for example ``"F1.json"``).

    Returns:
        Whatever ``load_voice_style`` returns for a one-element list holding
        the resolved path.

    Raises:
        ValueError: If *filename* is empty, or contains a directory component.
        FileNotFoundError: If no regular file with that name exists in
            ``VOICE_STYLES_DIR``.
    """
    if not filename:
        raise ValueError("No style selected")
    # The name arrives from the UI/API. Reject anything that is not a plain
    # file name so a crafted value such as "../../x.json" cannot make the
    # loader read outside the styles directory.
    if os.path.basename(filename) != filename:
        raise ValueError(f"Invalid style name: {filename!r}")
    path = os.path.join(VOICE_STYLES_DIR, filename)
    # isfile() rather than exists(): a directory (e.g. "..") is not a style.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Style file not found: {path}")
    # load_voice_style expects list of paths
    return load_voice_style([path])
# --------------------------
# Voice style descriptions
# --------------------------
# Human-readable descriptions shown in the UI, keyed by voice style file name.
# Styles without an entry fall back to a default text at the lookup sites.
VOICE_DESCRIPTIONS = {
    "F1.json": "Female Voice 1 - Clear and professional",
    "F2.json": "Female Voice 2 - Warm and expressive",
    "M1.json": "Male Voice 1 - Deep and authoritative",
    "M2.json": "Male Voice 2 - Casual and friendly"
}
# --------------------------
# Continuous Streaming TTS Generator
# --------------------------
def run_tts_stream(text: str, speed: float, style_name: str):
    """Synthesize *text* chunk by chunk and yield the audio produced so far.

    Generator intended as a Gradio event handler with two outputs.

    Args:
        text: Text to synthesize; it is split with ``chunk_text``.
        speed: Speaking-speed factor, converted with ``float()`` and passed
            to the model.
        style_name: Voice style file name, resolved by ``load_style_by_name``.

    Yields:
        ``(audio, status)`` tuples. ``audio`` is ``None`` (clear the player),
        ``gr.update()`` (leave the player untouched) or the path of a WAV file
        containing every chunk generated so far. ``status`` is a
        human-readable progress or error message. Errors are reported through
        ``status`` rather than raised.
    """
    try:
        if not text or not text.strip():
            yield None, "❌ Text cannot be empty."
            return
        try:
            style = load_style_by_name(style_name)
        except Exception as e:
            yield None, f"❌ Failed to load voice style: {e}"
            return
        chunks = chunk_text(text)
        total_chunks = len(chunks)
        # Without this guard the loop below never runs and the final summary
        # would fail on an unbound variable.
        if total_chunks == 0:
            yield None, "❌ Nothing to synthesize after splitting the text."
            return
        yield None, f"🟡 Starting generation: {total_chunks} chunk(s) to process..."
        # The directory is removed when the generator finishes or is closed.
        # This relies on the consumer having processed each yielded file
        # before it resumes the generator (Gradio copies output files into
        # its own cache while post-processing a yielded value).
        with tempfile.TemporaryDirectory() as temp_dir:
            stream_file = os.path.join(temp_dir, "stream_audio.wav")
            all_audio_chunks = []
            current_audio = None
            for idx, chunk in enumerate(chunks, start=1):
                # gr.update() leaves the player as it is. Yielding None here
                # would wipe the audio delivered for the previous chunk the
                # moment the next chunk starts generating.
                yield gr.update(), f"⏳ Generating chunk {idx}/{total_chunks}..."
                # Generate the audio chunk
                wav, dur = tts._infer([chunk], style, TOTAL_STEP, float(speed))
                all_audio_chunks.append(wav.squeeze())
                # Extract duration as scalar. Going through a flattened array
                # copes with plain numbers, sequences and 0-d arrays alike
                # (a 0-d array has __len__ but len() on it raises).
                dur_values = np.asarray(dur, dtype=float).ravel()
                dur_scalar = float(dur_values[0]) if dur_values.size else 0.0
                # Concatenate all chunks so far and rewrite the snapshot file.
                current_audio = np.concatenate(all_audio_chunks)
                sf.write(stream_file, current_audio, tts.sample_rate)
                yield stream_file, f"🔊 Playing... Chunk {idx}/{total_chunks} ready ({dur_scalar:.1f}s)"
            # Final update with complete audio
            total_duration = len(current_audio) / tts.sample_rate
            yield stream_file, f"🎉 Generation complete! Total duration: {total_duration:.1f}s"
    except Exception as e:
        # UI boundary: surface any failure as a status message.
        yield None, f"❌ Error: {type(e).__name__}: {e}"
# --------------------------
# Full generation endpoint
# --------------------------
def run_tts_full(text: str, speed: float, style_name: str):
    """Synthesize the whole *text* in one call and write it to a WAV file.

    Args:
        text: Text to synthesize.
        speed: Speaking-speed factor, converted with ``float()`` and passed
            to the model.
        style_name: Voice style file name, resolved by ``load_style_by_name``.

    Returns:
        ``(path, status)``: the path of the generated WAV file and a success
        message, or ``(None, message)`` when the input is empty or any step
        fails. Errors are reported through the message rather than raised.
    """
    if not text or not text.strip():
        return None, "❌ Text cannot be empty."
    try:
        style = load_style_by_name(style_name)
        wav_cat, dur_cat = tts(
            text=text,
            style=style,
            total_step=TOTAL_STEP,
            speed=float(speed),
            silence_duration=0.2,
        )
        audio = wav_cat.squeeze()
        # Flatten before indexing so plain numbers, sequences and 0-d arrays
        # are all handled (a 0-d array has __len__ but len() on it raises).
        dur_values = np.asarray(dur_cat, dtype=float).ravel()
        dur_scalar = float(dur_values[0]) if dur_values.size else 0.0
        # mkstemp creates the file atomically with a unique name; the
        # deprecated mktemp only *returns* a name, which another process
        # could claim before we write to it.
        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            sf.write(tmp_path, audio, tts.sample_rate)
        except Exception:
            # Do not leave an empty placeholder behind if writing fails.
            os.unlink(tmp_path)
            raise
        return tmp_path, f"✅ Generated successfully! Duration: {dur_scalar:.1f}s"
    except Exception as e:
        # UI boundary: surface any failure as a status message.
        return None, f"❌ Error: {type(e).__name__}: {e}"
# --------------------------
# Professional UI
# --------------------------
def ui():
    """Build and return the Gradio ``Blocks`` interface for the TTS demo.

    The page consists of injected CSS, a header, an input row (text box with
    character/word counters, voice and speed controls, two action buttons)
    and an output row (streaming player and download player, each with a
    status box). The "Start Streaming" button is wired to ``run_tts_stream``
    and the "Generate & Download" button to ``run_tts_full``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # NOTE(review): the container nesting below was reconstructed from a copy
    # of this file that had lost its indentation - confirm the exact layout.

    def update_counts(text):
        """Return ``(character count, word count)`` for the text box content."""
        chars = len(text) if text else 0
        words = len(text.split()) if text else 0
        return chars, words

    def update_voice_description(style_name):
        """Return the Markdown description line for the selected voice style."""
        description = VOICE_DESCRIPTIONS.get(style_name, "No description available")
        return f"**Voice Description:** {description}"

    # Use the same formatter for the initial text as for later dropdown
    # changes, so the description does not change format on first selection.
    if available_styles:
        initial_description = update_voice_description(available_styles[0])
    else:
        initial_description = "No voices available"

    with gr.Blocks(title="Supertonic TTS Studio") as demo:
        # Custom CSS
        gr.HTML("""
<style>
.gradio-container {
font-family: 'Inter', 'Segoe UI', system-ui, sans-serif;
max-width: 1200px;
margin: 0 auto;
}
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 2rem;
border-radius: 12px;
margin-bottom: 2rem;
color: white;
text-align: center;
}
.header h1 {
margin: 0;
font-size: 2.5rem;
font-weight: 700;
}
.header p {
margin: 0.5rem 0 0 0;
opacity: 0.9;
font-size: 1.1rem;
}
.card {
background: white;
padding: 1.5rem;
border-radius: 12px;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
border: 1px solid #e5e7eb;
margin-bottom: 1.5rem;
}
.card h2 {
margin-top: 0;
color: #374151;
font-size: 1.5rem;
font-weight: 600;
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
color: white;
padding: 1rem 2rem;
border-radius: 8px;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
font-size: 1.1rem;
width: 100%;
}
.btn-primary:hover {
transform: translateY(-2px);
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.4);
}
.btn-secondary {
background: #f8fafc;
border: 2px solid #e2e8f0;
color: #475569;
padding: 1rem 2rem;
border-radius: 8px;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
font-size: 1.1rem;
width: 100%;
}
.btn-secondary:hover {
background: #f1f5f9;
border-color: #cbd5e1;
transform: translateY(-1px);
}
.stats {
background: #f0f9ff;
border: 1px solid #bae6fd;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
}
.progress-info {
background: #fef3c7;
border: 1px solid #fcd34d;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
}
</style>
""")
        # Header Section
        gr.HTML("""
<div class="header">
<h1>🎙️ Supertonic TTS Studio</h1>
<p>Professional Text-to-Speech with Real-time Continuous Streaming</p>
</div>
""")
        with gr.Row():
            # Left Column - Input Controls
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("## 📝 Text Input")
                    text_input = gr.Textbox(
                        label="",
                        placeholder="Enter your text here... (Supports multiple paragraphs)",
                        lines=8,
                        max_lines=12,
                        show_label=False
                    )
                    with gr.Row():
                        char_count = gr.Number(
                            label="Character Count",
                            value=0,
                            interactive=False,
                            precision=0
                        )
                        word_count = gr.Number(
                            label="Word Count",
                            value=0,
                            interactive=False,
                            precision=0
                        )
            # Middle Column - Voice Settings
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("## 🎵 Voice Settings")
                    voice_style = gr.Dropdown(
                        label="Voice Style",
                        choices=available_styles,
                        value=available_styles[0] if available_styles else None,
                        info="Choose your preferred voice style"
                    )
                    # Voice description
                    voice_description = gr.Markdown(value=initial_description)
                    speed_slider = gr.Slider(
                        label="Speaking Speed",
                        minimum=0.5,
                        maximum=2.0,
                        value=1.05,
                        step=0.05,
                        info="Adjust the speech rate (0.5x to 2.0x)"
                    )
                with gr.Group():
                    gr.Markdown("## ⚡ Generation Mode")
                    with gr.Row():
                        stream_btn = gr.Button(
                            "🎤 Start Streaming",
                            size="lg",
                            variant="primary"
                        )
                    with gr.Row():
                        full_btn = gr.Button(
                            "💾 Generate & Download",
                            size="lg",
                            variant="secondary"
                        )
        # Output Section
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    gr.Markdown("## 🔊 Real-time Streaming Output")
                    with gr.Group():
                        streaming_info = gr.Textbox(
                            label="Generation Status",
                            value="Ready to generate...",
                            interactive=False,
                            show_label=False
                        )
                        # Streaming audio with filepath for continuous playback
                        streaming_audio = gr.Audio(
                            label="Live Continuous Stream",
                            type="filepath",
                            show_label=False,
                            autoplay=True
                        )
                with gr.Group():
                    gr.Markdown("## 📥 Download Output")
                    with gr.Group():
                        download_info = gr.Textbox(
                            label="Download Status",
                            value="Generate a full version to download",
                            interactive=False,
                            show_label=False
                        )
                        download_audio = gr.Audio(
                            label="Complete Audio File",
                            type="filepath",
                            show_label=False
                        )
        # Update character and word counts
        text_input.change(
            fn=update_counts,
            inputs=[text_input],
            outputs=[char_count, word_count]
        )
        # Update voice description
        voice_style.change(
            fn=update_voice_description,
            inputs=[voice_style],
            outputs=[voice_description]
        )
        # Connect buttons to functions
        stream_btn.click(
            fn=run_tts_stream,
            inputs=[text_input, speed_slider, voice_style],
            outputs=[streaming_audio, streaming_info],
            show_progress=False
        )
        full_btn.click(
            fn=run_tts_full,
            inputs=[text_input, speed_slider, voice_style],
            outputs=[download_audio, download_info]
        )
    return demo
# Build the interface at import time so `app` exists for importers as well.
app = ui()

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 and request a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    app.launch(**launch_options)