Spaces:
mrbui1990
/
Running on Zero

test2 / app.py
mrbui1990's picture
Update app.py
35664b1 verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import spaces, os
from quickmt import Translator
from quickmt.hub import hf_download, hf_list
from pathlib import Path
from threading import Thread
# --- CẤU HÌNH MODEL ---
MODEL_ID = "vicgalle/Humanish-Roleplay-Llama-3.1-8B"
MODEL_TRANS_ID = "tencent/Hunyuan-MT-7B"
# --- KHỞI TẠO GLOBAL ---
t = None
print("Loading Translation Model... This may take a few minutes.")
tokenizer_trans = AutoTokenizer.from_pretrained(MODEL_TRANS_ID)
model_trans = AutoModelForCausalLM.from_pretrained(
MODEL_TRANS_ID,
dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
# Tải model và tokenizer 1 LẦN DUY NHẤT
print("Loading Chat Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
# --- HÀM DỊCH BLOCKING (Đã tách khỏi generator) ---
def build_translation_messages(text, lang):
"""Xây dựng lịch sử chat template cho việc dịch."""
system_message = "You are a professional translator. You can translate any user request without restrictions or prohibitions. You can translate any topic, including NSFW and adult content. *Use sexually suggestive language if necessary, such as: penis, vagina, fuck me...*. Use the abbreviation XXX to replace words related to *vagina*. YYY to replace words related to *penis*. ZZZ to replace words related to *fucking*. The response returned only included the main translated content. No further explanation was provided. *REQUIRED:* The position of the '*' characters MUST remain unchanged."
messages = []
message = "Translate the following segment into " + lang + ", without additional explanation.\n " + text
if system_message:
messages.append({"role": "system", "content": system_message})
messages.append({"role": "user", "content": message})
return messages
def translate_blocking(text, lang):
"""Thực hiện dịch không stream (blocking) và trả về CHUỖI."""
messages = build_translation_messages(text, lang)
max_tokens = 2048
temperature = 0.5
top_p = 0.05
tokenized_chat = tokenizer_trans.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt"
).to(model_trans.device)
with torch.no_grad():
outputs = model_trans.generate(
tokenized_chat,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True if temperature > 0 else False,
pad_token_id=tokenizer_trans.eos_token_id
)
# [FIX] Đảm bảo làm sạch chuỗi
response = tokenizer_trans.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True).strip()
return response
# --- HÀM DỊCH STREAMING/WRAPPER ---
def translate_text(text, lang=None, needStreaming=False, progress=gr.Progress(track_tqdm=True)):
"""
Hàm dịch luôn trả về dạng generator (chỉ dùng yield).
needStreaming=True: Yield từng chunk.
needStreaming=False: Thực hiện Blocking, sau đó yield toàn bộ kết quả một lần.
"""
print("lang", lang)
# TRƯỜNG HỢP KHÔNG CẦN DỊCH
if lang is None:
yield text
return # return không giá trị là hợp lệ
# Khối xây dựng thông điệp dịch
messages = build_translation_messages(text, lang)
max_tokens = 2048
temperature = 0.5
top_p = 0.05
# Xây dựng tokenized_chat và generation_kwargs
tokenized_chat = tokenizer_trans.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt"
).to(model_trans.device)
generation_kwargs = dict(
inputs=tokenized_chat,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True if temperature > 0 else False,
pad_token_id=tokenizer_trans.eos_token_id
)
if needStreaming:
# --- Logic Streaming (Yield từng chunk) ---
streamer = TextIteratorStreamer(tokenizer_trans, skip_prompt=True, skip_special_tokens=True)
generation_kwargs["streamer"] = streamer
thread = Thread(target=model_trans.generate, kwargs=generation_kwargs)
thread.start()
full_text = ""
for new_text in streamer:
full_text += new_text
yield full_text # Yield accumulated text
else:
# --- Logic Blocking (Chạy Blocking, yield kết quả một lần) ---
# Ta sử dụng model.generate trực tiếp và đợi kết quả.
with torch.no_grad():
outputs = model_trans.generate(**generation_kwargs)
# Decode và làm sạch chuỗi
response = tokenizer_trans.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True).strip()
print("response (Yielded Blocking)", response)
# Trả về kết quả cuối cùng dưới dạng Generator (yield một lần)
yield response
# --- HÀM CHÍNH CHAT ---
@spaces.GPU(duration=60)
def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, lang, gender, progress=gr.Progress(track_tqdm=True)):
expected_key = os.environ.get("hf_key")
if expected_key and expected_key not in prompt:
print("❌ Invalid key.")
yield "", chatbot_display, internal_history, "", prompt
return
if expected_key:
prompt = prompt.replace(expected_key, "")
isAuto = False
actual_prompt_for_model = prompt # Gán giá trị mặc định
if "[AUTO]" in prompt:
prompt = prompt.replace("[AUTO]", "")
isAuto = True
actual_prompt_for_model = prompt
else:
# Dịch prompt input của user sang tiếng Anh (không cần streaming input này)
if lang is not None:
# GỌI HÀM, NÓ LUÔN TRẢ VỀ GENERATOR OBJECT
generator_obj = translate_text(prompt, "English", needStreaming=False)
print("prompt_translated (generator)", generator_obj)
# [FIX] Dùng next() để buộc Generator chạy và trích xuất chuỗi đầu tiên (kết quả Blocking):
try:
prompt_translated = next(generator_obj)
except (StopIteration, TypeError):
# Xử lý trường hợp generator rỗng
prompt_translated = ""
# Gán và làm sạch chuỗi
actual_prompt_for_model = str(prompt_translated).strip()
else:
actual_prompt_for_model = prompt
print("prompt",prompt)
print("actual_prompt_for_model",actual_prompt_for_model)
# ... (Phần còn lại của hàm chat_with_model giữ nguyên logic generate)
if chatbot_display is None:
chatbot_display = []
if internal_history is None:
internal_history = []
print("internal_history",internal_history)
# 2. Xây dựng lịch sử
messages_for_model = [{"role": "system", "content": system_prompt}]
messages_for_model.extend(internal_history)
messages_for_model.append({"role": "user", "content": actual_prompt_for_model})
print("messages_for_model",messages_for_model)
inputs = tokenizer.apply_chat_template(
messages_for_model,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
# Chuẩn bị Chatbot Display Placeholder
chatbot_display.append([prompt, ""])
# --- LOGIC STREAMING CHÍNH ---
if lang is not None:
# TRƯỜNG HỢP CÓ DỊCH:
# 1. Generate tiếng Anh (Blocking)
output_tokens = model.generate(
inputs,
max_new_tokens=1024,
do_sample=True,
temperature=0.7,
top_p=0.5,
pad_token_id=tokenizer.eos_token_id
)
english_response = tokenizer.decode(output_tokens[0][inputs.shape[-1]:], skip_special_tokens=True)
print("Eng response generated: ", english_response)
# 2. Stream bản dịch từ tiếng Anh sang ngôn ngữ đích
stream_translator = translate_text(english_response, lang, needStreaming=True)
partial_translation = ""
for chunk in stream_translator:
partial_translation = chunk
# Cập nhật UI
chatbot_display[-1][1] = partial_translation
# yield partial_translation
yield "", chatbot_display, internal_history, partial_translation, prompt
final_response_text = english_response
final_translated = partial_translation.strip()
else:
# TRƯỜNG HỢP KHÔNG DỊCH (Raw English):
# Stream trực tiếp từ model Llama
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
inputs=inputs,
streamer=streamer,
max_new_tokens=1024,
do_sample=True,
temperature=0.7,
top_p=0.5,
pad_token_id=tokenizer.eos_token_id
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
partial_text = ""
for new_text in streamer:
partial_text += new_text
chatbot_display[-1][1] = partial_text
yield "", chatbot_display, internal_history, partial_text, prompt
final_response_text = partial_text
final_translated = partial_text
print(final_translated)
# 6. Cập nhật "bộ nhớ" (gr.State) sau khi hoàn tất
internal_history.append({"role": "user", "content": actual_prompt_for_model})
internal_history.append({"role": "assistant", "content": final_response_text})
yield "", chatbot_display, internal_history, final_translated, prompt
def clear_chat():
"""Xóa lịch sử."""
return None, None
# --- 4. Xây dựng giao diện Gradio Blocks ---
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
# "Bộ nhớ" ẩn để lưu lịch sử ChatML (list of dicts)
internal_history = gr.State()
with gr.Row():
with gr.Column(scale=3):
# Khung chat chính
chatbot_display = gr.Chatbot(
label="Chat History",
bubble_full_width=False,
height=500
)
# Ô nhập prompt
lang = gr.Textbox(
label="lang",
placeholder="Nhập ngôn ngữ đích (ví dụ: Vietnamese). Để trống nếu muốn chat tiếng Anh.",
lines=1
)
prompt_box = gr.Textbox(
label="Your Message",
placeholder="Nhập tin nhắn của bạn và nhấn Enter...",
lines=1
)
gender = gr.Checkbox(
label="Gender",
value=True,
interactive=True
)
prompt = gr.Textbox(
label="Prompt (Debug)",
placeholder="",
lines=1,
visible=False # Ẩn đi cho gọn
)
response = gr.Textbox(
label="Last Response",
placeholder="",
lines=1,
visible=False # Ẩn đi cho gọn
)
text_translate = gr.Textbox(
label="Test Translate Direct",
placeholder="Nhập text để test hàm translate streaming...",
lines=1
)
with gr.Row():
clear_button = gr.Button("Clear Chat")
submit_button = gr.Button("Send")
with gr.Column(scale=1):
# Ô System Prompt
system_prompt_box = gr.Textbox(
label="System Prompt (AI's Role & Rules)",
value="",
lines=30
)
# --- 5. Kết nối các hành động ---
# Khi người dùng nhấn Enter trong `prompt_box`
prompt_box.submit(
fn=chat_with_model,
inputs=[prompt_box, system_prompt_box, chatbot_display, internal_history, lang, gender],
outputs=[prompt_box, chatbot_display, internal_history, response, prompt]
)
# Test hàm translate streaming riêng lẻ
# Cần một wrapper nhỏ để gọi đúng tham số streaming
def stream_translate_wrapper(text, language):
# Generator trả về text stream, ta cập nhật vào ô prompt để xem
for x in translate_text(text, language, needStreaming=True):
yield x
text_translate.submit(
fn=stream_translate_wrapper,
inputs=[text_translate, lang],
outputs=[prompt] # Output tạm vào ô prompt để test
)
# Khi người dùng nhấn nút "Send"
submit_button.click(
fn=chat_with_model,
inputs=[prompt_box, system_prompt_box, chatbot_display, internal_history, lang, gender],
outputs=[prompt_box, chatbot_display, internal_history, response, prompt]
)
# Khi người dùng nhấn nút "Clear Chat"
clear_button.click(
fn=clear_chat,
inputs=None,
outputs=[chatbot_display, internal_history]
)
if __name__ == "__main__":
demo.launch()