Update app.py
Browse files
app.py
CHANGED
|
@@ -23,6 +23,22 @@ model_trans = AutoModelForCausalLM.from_pretrained(
|
|
| 23 |
device_map="cuda"
|
| 24 |
)
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# --- Sửa đổi hàm translate_text ---
|
| 27 |
# Thêm tham số needStreaming
|
| 28 |
def translate_text(text, lang=None, needStreaming=False, progress=gr.Progress(track_tqdm=True)):
|
|
@@ -79,18 +95,13 @@ def translate_text(text, lang=None, needStreaming=False, progress=gr.Progress(tr
|
|
| 79 |
yield full_text
|
| 80 |
else:
|
| 81 |
# --- Logic cũ (Blocking) ---
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
pad_token_id=tokenizer_trans.eos_token_id
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
response = tokenizer_trans.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True)
|
| 93 |
-
print("response",response)
|
| 94 |
return response
|
| 95 |
|
| 96 |
# Tải model và tokenizer 1 LẦN DUY NHẤT
|
|
@@ -126,15 +137,29 @@ def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, la
|
|
| 126 |
else:
|
| 127 |
# Dịch prompt input của user sang tiếng Anh (không cần streaming input này)
|
| 128 |
if lang is not None:
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
#
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
else:
|
| 134 |
actual_prompt_for_model = prompt
|
| 135 |
print("prompt",prompt)
|
|
|
|
| 136 |
print("actual_prompt_for_model",actual_prompt_for_model)
|
| 137 |
-
actual_prompt_for_model =
|
| 138 |
print("prompt for model: " + actual_prompt_for_model)
|
| 139 |
|
| 140 |
if chatbot_display is None:
|
|
|
|
| 23 |
device_map="cuda"
|
| 24 |
)
|
| 25 |
|
| 26 |
+
def generate_blocking(tokenized_chat, max_tokens, temperature, top_p):
|
| 27 |
+
"""Thực hiện generate MỘT LẦN và trả về CHUỖI kết quả."""
|
| 28 |
+
with torch.no_grad():
|
| 29 |
+
# Đảm bảo bạn đang sử dụng đúng model: model_trans
|
| 30 |
+
outputs = model_trans.generate(
|
| 31 |
+
tokenized_chat, # Đã loại bỏ .to(model.device) vì nó đã ở model_trans.device
|
| 32 |
+
max_new_tokens=max_tokens,
|
| 33 |
+
temperature=temperature,
|
| 34 |
+
top_p=top_p,
|
| 35 |
+
do_sample=True if temperature > 0 else False,
|
| 36 |
+
pad_token_id=tokenizer_trans.eos_token_id
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
response = tokenizer_trans.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True)
|
| 40 |
+
return response
|
| 41 |
+
|
| 42 |
# --- Sửa đổi hàm translate_text ---
|
| 43 |
# Thêm tham số needStreaming
|
| 44 |
def translate_text(text, lang=None, needStreaming=False, progress=gr.Progress(track_tqdm=True)):
|
|
|
|
| 95 |
yield full_text
|
| 96 |
else:
|
| 97 |
# --- Logic cũ (Blocking) ---
|
| 98 |
+
response = generate_blocking(
|
| 99 |
+
tokenized_chat.to(model_trans.device), # Đảm bảo chuyển token sang đúng device
|
| 100 |
+
max_tokens,
|
| 101 |
+
temperature,
|
| 102 |
+
top_p
|
| 103 |
+
)
|
| 104 |
+
print("response (Blocking)", response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
return response
|
| 106 |
|
| 107 |
# Tải model và tokenizer 1 LẦN DUY NHẤT
|
|
|
|
| 137 |
else:
|
| 138 |
# Dịch prompt input của user sang tiếng Anh (không cần streaming input này)
|
| 139 |
if lang is not None:
|
| 140 |
+
generator_obj = translate_text(prompt, "English", needStreaming=False) # Lấy Generator
|
| 141 |
+
|
| 142 |
+
# Lặp qua Generator để lấy chuỗi (dù bạn không muốn stream,
|
| 143 |
+
# generator vẫn là cách duy nhất để lấy giá trị nếu hàm translate_text không được sửa)
|
| 144 |
+
# Tuy nhiên, nếu bạn đã sửa hàm translate_text như BƯỚC 1, nó sẽ trả về chuỗi.
|
| 145 |
+
|
| 146 |
+
# GIẢ ĐỊNH bạn chưa sửa hàm translate_text, bạn cần lấy giá trị đầu tiên:
|
| 147 |
+
try:
|
| 148 |
+
# Vì bạn gọi với needStreaming=False, logic blocking của bạn
|
| 149 |
+
# sẽ trả về chuỗi, KHÔNG phải generator.
|
| 150 |
+
prompt_translated = generator_obj
|
| 151 |
+
|
| 152 |
+
# Nếu đã sửa translate_text (BƯỚC 1), không cần làm gì thêm, nó là chuỗi rồi.
|
| 153 |
+
# Nếu chưa sửa (và nó vẫn trả về generator dù là blocking), thì đây là lý do lỗi.
|
| 154 |
+
|
| 155 |
+
except StopIteration:
|
| 156 |
+
prompt_translated = "" # Generator rỗng
|
| 157 |
else:
|
| 158 |
actual_prompt_for_model = prompt
|
| 159 |
print("prompt",prompt)
|
| 160 |
+
print("prompt_translated",prompt_translated)
|
| 161 |
print("actual_prompt_for_model",actual_prompt_for_model)
|
| 162 |
+
actual_prompt_for_model = actual_prompt_for_model + " [Detailed description of the physical actions and expressions.]"
|
| 163 |
print("prompt for model: " + actual_prompt_for_model)
|
| 164 |
|
| 165 |
if chatbot_display is None:
|