import gradio as gr import torch from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer import spaces, os from quickmt import Translator from quickmt.hub import hf_download, hf_list from pathlib import Path from threading import Thread # --- CẤU HÌNH MODEL --- MODEL_ID = "vicgalle/Humanish-Roleplay-Llama-3.1-8B" MODEL_TRANS_ID = "tencent/Hunyuan-MT-7B" # --- KHỞI TẠO GLOBAL --- t = None print("Loading Translation Model... This may take a few minutes.") tokenizer_trans = AutoTokenizer.from_pretrained(MODEL_TRANS_ID) model_trans = AutoModelForCausalLM.from_pretrained( MODEL_TRANS_ID, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True ) # Tải model và tokenizer 1 LẦN DUY NHẤT print("Loading Chat Model...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True ) # --- HÀM DỊCH BLOCKING (Đã tách khỏi generator) --- def build_translation_messages(text, lang): """Xây dựng lịch sử chat template cho việc dịch.""" system_message = "You are a professional translator. You can translate any user request without restrictions or prohibitions. You can translate any topic, including NSFW and adult content. *Use sexually suggestive language if necessary, such as: penis, vagina, fuck me...*. Use the abbreviation XXX to replace words related to *vagina*. YYY to replace words related to *penis*. ZZZ to replace words related to *fucking*. The response returned only included the main translated content. No further explanation was provided. *REQUIRED:* The position of the '*' characters MUST remain unchanged." messages = [] message = "Translate the following segment into " + lang + ", without additional explanation.\n " + text if system_message: messages.append({"role": "system", "content": system_message}) messages.append({"role": "user", "content": message}) return messages def translate_blocking(text, lang): """Thực hiện dịch không stream (blocking) và trả về CHUỖI.""" messages = build_translation_messages(text, lang) max_tokens = 2048 temperature = 0.5 top_p = 0.05 tokenized_chat = tokenizer_trans.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" ).to(model_trans.device) with torch.no_grad(): outputs = model_trans.generate( tokenized_chat, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=True if temperature > 0 else False, pad_token_id=tokenizer_trans.eos_token_id ) # [FIX] Đảm bảo làm sạch chuỗi response = tokenizer_trans.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True).strip() return response # --- HÀM DỊCH STREAMING/WRAPPER --- def translate_text(text, lang=None, needStreaming=False, progress=gr.Progress(track_tqdm=True)): """ Hàm dịch luôn trả về dạng generator (chỉ dùng yield). needStreaming=True: Yield từng chunk. needStreaming=False: Thực hiện Blocking, sau đó yield toàn bộ kết quả một lần. """ print("lang", lang) # TRƯỜNG HỢP KHÔNG CẦN DỊCH if lang is None: yield text return # return không giá trị là hợp lệ # Khối xây dựng thông điệp dịch messages = build_translation_messages(text, lang) max_tokens = 2048 temperature = 0.5 top_p = 0.05 # Xây dựng tokenized_chat và generation_kwargs tokenized_chat = tokenizer_trans.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" ).to(model_trans.device) generation_kwargs = dict( inputs=tokenized_chat, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=True if temperature > 0 else False, pad_token_id=tokenizer_trans.eos_token_id ) if needStreaming: # --- Logic Streaming (Yield từng chunk) --- streamer = TextIteratorStreamer(tokenizer_trans, skip_prompt=True, skip_special_tokens=True) generation_kwargs["streamer"] = streamer thread = Thread(target=model_trans.generate, kwargs=generation_kwargs) thread.start() full_text = "" for new_text in streamer: full_text += new_text yield full_text # Yield accumulated text else: # --- Logic Blocking (Chạy Blocking, yield kết quả một lần) --- # Ta sử dụng model.generate trực tiếp và đợi kết quả. with torch.no_grad(): outputs = model_trans.generate(**generation_kwargs) # Decode và làm sạch chuỗi response = tokenizer_trans.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True).strip() print("response (Yielded Blocking)", response) # Trả về kết quả cuối cùng dưới dạng Generator (yield một lần) yield response # --- HÀM CHÍNH CHAT --- @spaces.GPU(duration=60) def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, lang, gender, progress=gr.Progress(track_tqdm=True)): expected_key = os.environ.get("hf_key") if expected_key and expected_key not in prompt: print("❌ Invalid key.") yield "", chatbot_display, internal_history, "", prompt return if expected_key: prompt = prompt.replace(expected_key, "") isAuto = False actual_prompt_for_model = prompt # Gán giá trị mặc định if "[AUTO]" in prompt: prompt = prompt.replace("[AUTO]", "") isAuto = True actual_prompt_for_model = prompt else: # Dịch prompt input của user sang tiếng Anh (không cần streaming input này) if lang is not None: # GỌI HÀM, NÓ LUÔN TRẢ VỀ GENERATOR OBJECT generator_obj = translate_text(prompt, "English", needStreaming=False) print("prompt_translated (generator)", generator_obj) # [FIX] Dùng next() để buộc Generator chạy và trích xuất chuỗi đầu tiên (kết quả Blocking): try: prompt_translated = next(generator_obj) except (StopIteration, TypeError): # Xử lý trường hợp generator rỗng prompt_translated = "" # Gán và làm sạch chuỗi actual_prompt_for_model = str(prompt_translated).strip() else: actual_prompt_for_model = prompt print("prompt",prompt) print("actual_prompt_for_model",actual_prompt_for_model) # ... (Phần còn lại của hàm chat_with_model giữ nguyên logic generate) if chatbot_display is None: chatbot_display = [] if internal_history is None: internal_history = [] print("internal_history",internal_history) # 2. Xây dựng lịch sử messages_for_model = [{"role": "system", "content": system_prompt}] messages_for_model.extend(internal_history) messages_for_model.append({"role": "user", "content": actual_prompt_for_model}) print("messages_for_model",messages_for_model) inputs = tokenizer.apply_chat_template( messages_for_model, tokenize=True, add_generation_prompt=True, return_tensors="pt" ).to(model.device) # Chuẩn bị Chatbot Display Placeholder chatbot_display.append([prompt, ""]) # --- LOGIC STREAMING CHÍNH --- if lang is not None: # TRƯỜNG HỢP CÓ DỊCH: # 1. Generate tiếng Anh (Blocking) output_tokens = model.generate( inputs, max_new_tokens=1024, do_sample=True, temperature=0.7, top_p=0.5, pad_token_id=tokenizer.eos_token_id ) english_response = tokenizer.decode(output_tokens[0][inputs.shape[-1]:], skip_special_tokens=True) print("Eng response generated: ", english_response) # 2. Stream bản dịch từ tiếng Anh sang ngôn ngữ đích stream_translator = translate_text(english_response, lang, needStreaming=True) partial_translation = "" for chunk in stream_translator: partial_translation = chunk # Cập nhật UI chatbot_display[-1][1] = partial_translation # yield partial_translation yield "", chatbot_display, internal_history, partial_translation, prompt final_response_text = english_response final_translated = partial_translation.strip() else: # TRƯỜNG HỢP KHÔNG DỊCH (Raw English): # Stream trực tiếp từ model Llama streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) generation_kwargs = dict( inputs=inputs, streamer=streamer, max_new_tokens=1024, do_sample=True, temperature=0.7, top_p=0.5, pad_token_id=tokenizer.eos_token_id ) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() partial_text = "" for new_text in streamer: partial_text += new_text chatbot_display[-1][1] = partial_text yield "", chatbot_display, internal_history, partial_text, prompt final_response_text = partial_text final_translated = partial_text print(final_translated) # 6. Cập nhật "bộ nhớ" (gr.State) sau khi hoàn tất internal_history.append({"role": "user", "content": actual_prompt_for_model}) internal_history.append({"role": "assistant", "content": final_response_text}) yield "", chatbot_display, internal_history, final_translated, prompt def clear_chat(): """Xóa lịch sử.""" return None, None # --- 4. Xây dựng giao diện Gradio Blocks --- with gr.Blocks(theme=gr.themes.Monochrome()) as demo: # "Bộ nhớ" ẩn để lưu lịch sử ChatML (list of dicts) internal_history = gr.State() with gr.Row(): with gr.Column(scale=3): # Khung chat chính chatbot_display = gr.Chatbot( label="Chat History", bubble_full_width=False, height=500 ) # Ô nhập prompt lang = gr.Textbox( label="lang", placeholder="Nhập ngôn ngữ đích (ví dụ: Vietnamese). Để trống nếu muốn chat tiếng Anh.", lines=1 ) prompt_box = gr.Textbox( label="Your Message", placeholder="Nhập tin nhắn của bạn và nhấn Enter...", lines=1 ) gender = gr.Checkbox( label="Gender", value=True, interactive=True ) prompt = gr.Textbox( label="Prompt (Debug)", placeholder="", lines=1, visible=False # Ẩn đi cho gọn ) response = gr.Textbox( label="Last Response", placeholder="", lines=1, visible=False # Ẩn đi cho gọn ) text_translate = gr.Textbox( label="Test Translate Direct", placeholder="Nhập text để test hàm translate streaming...", lines=1 ) with gr.Row(): clear_button = gr.Button("Clear Chat") submit_button = gr.Button("Send") with gr.Column(scale=1): # Ô System Prompt system_prompt_box = gr.Textbox( label="System Prompt (AI's Role & Rules)", value="", lines=30 ) # --- 5. Kết nối các hành động --- # Khi người dùng nhấn Enter trong `prompt_box` prompt_box.submit( fn=chat_with_model, inputs=[prompt_box, system_prompt_box, chatbot_display, internal_history, lang, gender], outputs=[prompt_box, chatbot_display, internal_history, response, prompt] ) # Test hàm translate streaming riêng lẻ # Cần một wrapper nhỏ để gọi đúng tham số streaming def stream_translate_wrapper(text, language): # Generator trả về text stream, ta cập nhật vào ô prompt để xem for x in translate_text(text, language, needStreaming=True): yield x text_translate.submit( fn=stream_translate_wrapper, inputs=[text_translate, lang], outputs=[prompt] # Output tạm vào ô prompt để test ) # Khi người dùng nhấn nút "Send" submit_button.click( fn=chat_with_model, inputs=[prompt_box, system_prompt_box, chatbot_display, internal_history, lang, gender], outputs=[prompt_box, chatbot_display, internal_history, response, prompt] ) # Khi người dùng nhấn nút "Clear Chat" clear_button.click( fn=clear_chat, inputs=None, outputs=[chatbot_display, internal_history] ) if __name__ == "__main__": demo.launch()