| from __future__ import annotations |
|
|
| import contextlib |
| import io |
| import os |
| from pathlib import Path |
| from typing import List, Tuple |
|
|
| import gradio as gr |
| import torch |
| import spaces |
|
|
| from dia2 import Dia2, GenerationConfig, SamplingConfig |
|
|
# Repository identifier handed to ``Dia2.from_repo``; the DIA2_DEFAULT_REPO
# environment variable overrides the built-in default.
DEFAULT_REPO = os.getenv("DIA2_DEFAULT_REPO", "nari-labs/Dia2-2B")

# Total number of turn textboxes the UI creates (hidden ones are toggled on
# and off rather than created dynamically).
MAX_TURNS = 10
# Number of turn textboxes visible when the page first loads.
INITIAL_TURNS = 2


# Process-wide model instance, created on first use by ``_get_dia``.
_dia: Dia2 | None = None
|
|
|
|
def _get_dia() -> Dia2:
    """Return the shared Dia2 instance, creating it on the first call.

    The instance is cached in the module-level ``_dia`` so that
    ``Dia2.from_repo`` runs at most once per process; every later call
    returns the cached object unchanged.
    """
    global _dia
    if _dia is not None:
        return _dia
    _dia = Dia2.from_repo(DEFAULT_REPO, device="cuda", dtype="bfloat16")
    return _dia
|
|
|
|
def _concat_script(turn_count: int, turn_values: List[str]) -> str:
    """Join the active turns into a newline-separated, speaker-tagged script.

    Only the first ``turn_count`` entries of ``turn_values`` are considered.
    Each non-blank entry becomes one line prefixed with ``[S1]`` (even
    positions) or ``[S2]`` (odd positions). Blank or ``None`` entries are
    dropped, and the tag of every remaining turn still follows its original
    position.
    """
    # Clamp at zero so a negative count selects nothing instead of slicing
    # from the end of the list.
    active = turn_values[: max(turn_count, 0)]
    tagged = (
        ("[S2]" if position % 2 else "[S1]", (raw or "").strip())
        for position, raw in enumerate(active)
    )
    return "\n".join(f"{tag} {body}" for tag, body in tagged if body)
|
|
|
|
# Voice prompt paths used by every bundled example. They are plain relative
# paths; nothing in this table checks that the files exist.
_EXAMPLE_VOICES: dict[str, str | None] = {
    "voice_s1": "example_prefix1.wav",
    "voice_s2": "example_prefix2.wav",
}

# Preset dialogues keyed by the name shown in the "Examples" dropdown.
# "turns" lists the lines in speaking order (alternating [S1]/[S2] by
# position); "voice_s1"/"voice_s2" are the voice prompt paths to load.
EXAMPLES: dict[str, dict[str, List[str] | str | None]] = {
    "Intro": {
        "turns": [
            "Hello Dia2 fans! Today we're unveiling the new open TTS model.",
            "Sounds exciting. Can you show a sample right now?",
            "Absolutely. (laughs) Just press generate.",
        ],
        **_EXAMPLE_VOICES,
    },
    "Customer Support": {
        "turns": [
            "Thanks for calling. How can I help you today?",
            "My parcel never arrived and it's been two weeks.",
            "I'm sorry about that. Let me check your tracking number.",
            "Appreciate it. I really need that package soon.",
        ],
        **_EXAMPLE_VOICES,
    },
}
|
|
|
|
def _apply_turn_visibility(count: int) -> List[gr.Update]:
    """Build one visibility update per turn box.

    Returns ``MAX_TURNS`` updates in box order; the first ``count`` are
    marked visible and the rest hidden.
    """
    shown_flags = (slot < count for slot in range(MAX_TURNS))
    return [gr.update(visible=flag) for flag in shown_flags]
|
|
|
|
def _add_turn(count: int):
    """Show one more turn box, never exceeding ``MAX_TURNS``.

    Returns ``(new_count, *visibility_updates)`` in the order expected by the
    event outputs (turn state first, then every turn box).
    """
    grown = min(count + 1, MAX_TURNS)
    visibility = _apply_turn_visibility(grown)
    return (grown, *visibility)
|
|
|
|
def _remove_turn(count: int):
    """Hide the last visible turn box, always keeping at least one.

    Returns ``(new_count, *visibility_updates)`` in the order expected by the
    event outputs (turn state first, then every turn box).
    """
    shrunk = max(1, count - 1)
    visibility = _apply_turn_visibility(shrunk)
    return (shrunk, *visibility)
|
|
|
|
def _load_example(name: str, count: int):
    """Fill the turn boxes and voice prompts from a named example.

    Args:
        name: Key into ``EXAMPLES``. Any other value (for instance the
            dropdown's placeholder entry) is treated as "no example".
        count: Number of turns currently shown; reused unchanged when
            ``name`` is not a known example.

    Returns:
        A tuple ``(turn_count, *turn_updates, voice_s1, voice_s2)`` holding
        exactly ``MAX_TURNS`` turn updates, matching the event's outputs.
    """
    data = EXAMPLES.get(name)
    if not data:
        # Unknown name: keep the current layout; both voice prompts are reset.
        return (count, *_apply_turn_visibility(count), None, None)

    # ``or []`` also covers an example whose "turns" entry is None.
    turns = (data.get("turns") or [])[:MAX_TURNS]

    # A short example still leaves INITIAL_TURNS boxes on screen. The count
    # returned to the turn state must include those boxes, otherwise a
    # visible box would be ignored when the script is assembled and the
    # add/remove buttons would start from the wrong number.
    new_count = min(max(len(turns), INITIAL_TURNS), MAX_TURNS)

    updates: List[gr.Update] = [
        gr.update(
            value=turns[idx] if idx < len(turns) else "",
            visible=idx < new_count,
        )
        for idx in range(MAX_TURNS)
    ]
    return (new_count, *updates, data.get("voice_s1"), data.get("voice_s2"))
|
|
|
|
def _prepare_prefix(file_path: str | None) -> str | None:
    """Validate an optional voice-prompt path.

    Args:
        file_path: Path to a voice prompt, or ``None``/``""`` when no prompt
            was supplied.

    Returns:
        The path as a string when it names an existing regular file,
        otherwise ``None`` so that generation proceeds without that prompt.
    """
    if not file_path:
        return None
    candidate = Path(file_path)
    # is_file() rather than exists(): a directory "exists" but can never be
    # used as an audio prompt, so it is treated the same as a missing file.
    return str(candidate) if candidate.is_file() else None
|
|
|
|
@spaces.GPU(duration=100)
def generate_audio(
    turn_count: int,
    *inputs,
):
    """Synthesize speech for the current script and return the UI outputs.

    ``inputs`` is the flat sequence of component values wired up in
    ``build_interface``: the ``MAX_TURNS`` turn texts, followed by the two
    voice prompt paths, CFG scale, text temperature, audio temperature,
    text top-k, audio top-k and the include-prefix flag, in that order.

    Returns:
        ``((sample_rate, waveform), timestamp_rows, log_text)`` where
        ``waveform`` is the result moved to CPU as a NumPy array,
        ``timestamp_rows`` is a list of ``[word, seconds]`` pairs rounded to
        three decimals, and ``log_text`` is whatever was printed to stdout
        during generation (or a fixed fallback message when nothing was).

    Raises:
        gr.Error: If no active turn contains any text.
    """
    turn_values = list(inputs[:MAX_TURNS])

    # Everything after the turn boxes, in the positional order listed above.
    extras = inputs[MAX_TURNS:]
    voice_s1 = extras[0]
    voice_s2 = extras[1]
    cfg_scale = float(extras[2])
    text_temperature = float(extras[3])
    audio_temperature = float(extras[4])
    # Cast the top-k and flag values to the types the configs are given.
    text_top_k = int(extras[5])
    audio_top_k = int(extras[6])
    include_prefix = bool(extras[7])

    script = _concat_script(turn_count, turn_values)
    if not script.strip():
        raise gr.Error("Please enter at least one non-empty speaker turn.")

    dia = _get_dia()
    text_sampling = SamplingConfig(temperature=text_temperature, top_k=text_top_k)
    audio_sampling = SamplingConfig(temperature=audio_temperature, top_k=audio_top_k)
    config = GenerationConfig(
        cfg_scale=cfg_scale,
        text=text_sampling,
        audio=audio_sampling,
        use_cuda_graph=True,
    )
    prefix_speaker_1 = _prepare_prefix(voice_s1)
    prefix_speaker_2 = _prepare_prefix(voice_s2)

    # Capture anything generate() prints to stdout (it is called with
    # verbose=True) so it can be shown in the log box instead of the console.
    with contextlib.redirect_stdout(io.StringIO()) as captured:
        result = dia.generate(
            script,
            config=config,
            output_wav=None,
            verbose=True,
            prefix_speaker_1=prefix_speaker_1,
            prefix_speaker_2=prefix_speaker_2,
            include_prefix=include_prefix,
        )

    waveform = result.waveform.detach().cpu().numpy()
    sample_rate = result.sample_rate
    rows = [[word, round(seconds, 3)] for word, seconds in result.timestamps]
    log_text = captured.getvalue().strip() or "Generation finished."
    return (sample_rate, waveform), rows, log_text
|
|
|
|
def build_interface() -> gr.Blocks:
    """Assemble the Gradio UI and wire its events.

    The left column holds the script editor (``MAX_TURNS`` textboxes, of
    which ``INITIAL_TURNS`` start visible), the voice prompt uploads and the
    sampling controls; the right column holds the audio player, the
    timestamp table and the log box.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    # NOTE(review): the container nesting below is reconstructed from the
    # heading levels ("###" sections, "####" sub-sections); confirm it
    # matches the intended layout.
    placeholder_choice = "(select example)"

    with gr.Blocks(
        title="Dia2 TTS", css=".compact-turn textarea {min-height: 60px}"
    ) as demo:
        gr.Markdown(
            "## Dia2 — Open TTS Model\n"
            "Compose dialogue, attach optional voice prompts, and generate audio (CUDA graphs enabled by default)."
        )
        # Number of turn boxes currently shown; kept per session.
        turn_state = gr.State(INITIAL_TURNS)

        with gr.Row(equal_height=True):
            example_dropdown = gr.Dropdown(
                choices=[placeholder_choice, *EXAMPLES],
                label="Examples",
                value=placeholder_choice,
            )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Script")
                    # All boxes are created up front; add/remove only toggles
                    # their visibility.
                    turn_boxes = []
                    for slot in range(MAX_TURNS):
                        tag = "[S2]" if slot % 2 else "[S1]"
                        turn_boxes.append(
                            gr.Textbox(
                                label=f"{tag} turn {slot + 1}",
                                lines=2,
                                elem_classes=["compact-turn"],
                                placeholder=f"Enter dialogue for {tag}…",
                                visible=slot < INITIAL_TURNS,
                            )
                        )
                    with gr.Row():
                        add_btn = gr.Button("Add Turn")
                        remove_btn = gr.Button("Remove Turn")

                with gr.Group():
                    gr.Markdown("### Voice Prompts")
                    with gr.Row():
                        voice_s1 = gr.File(
                            label="[S1] voice (wav/mp3)", type="filepath"
                        )
                        voice_s2 = gr.File(
                            label="[S2] voice (wav/mp3)", type="filepath"
                        )

                with gr.Group():
                    gr.Markdown("### Sampling")
                    cfg_scale = gr.Slider(
                        1.0, 8.0, value=6.0, step=0.1, label="CFG Scale"
                    )
                    with gr.Group():
                        gr.Markdown("#### Text Sampling")
                        text_temperature = gr.Slider(
                            0.1, 1.5, value=0.6, step=0.05, label="Text Temperature"
                        )
                        text_top_k = gr.Slider(
                            1, 200, value=50, step=1, label="Text Top-K"
                        )
                    with gr.Group():
                        gr.Markdown("#### Audio Sampling")
                        audio_temperature = gr.Slider(
                            0.1, 1.5, value=0.8, step=0.05, label="Audio Temperature"
                        )
                        audio_top_k = gr.Slider(
                            1, 200, value=50, step=1, label="Audio Top-K"
                        )
                    include_prefix = gr.Checkbox(
                        label="Keep prefix audio in output", value=False
                    )

                generate_btn = gr.Button("Generate", variant="primary")

            with gr.Column(scale=1):
                gr.Markdown("### Output")
                audio_out = gr.Audio(label="Waveform", interactive=False)
                timestamps = gr.Dataframe(
                    headers=["word", "seconds"], label="Timestamps"
                )
                log_box = gr.Textbox(label="Logs", lines=8)

        # Add/remove return the new count followed by one update per box.
        add_btn.click(
            lambda current: _add_turn(current),
            inputs=turn_state,
            outputs=[turn_state, *turn_boxes],
        )
        remove_btn.click(
            lambda current: _remove_turn(current),
            inputs=turn_state,
            outputs=[turn_state, *turn_boxes],
        )
        example_dropdown.change(
            lambda choice, current: _load_example(choice, current),
            inputs=[example_dropdown, turn_state],
            outputs=[turn_state, *turn_boxes, voice_s1, voice_s2],
        )

        # Order matters: generate_audio reads these values positionally.
        sampling_inputs = [
            cfg_scale,
            text_temperature,
            audio_temperature,
            text_top_k,
            audio_top_k,
            include_prefix,
        ]
        generate_btn.click(
            generate_audio,
            inputs=[turn_state, *turn_boxes, voice_s1, voice_s2, *sampling_inputs],
            outputs=[audio_out, timestamps, log_box],
        )
    return demo
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI, enable the queue with a default
    # concurrency limit of 1 per event, and launch with share=True.
    app = build_interface()
    app.queue(default_concurrency_limit=1).launch(share=True)
|
|