Spaces:
mrbui1990
/
Running on Zero

mrbui1990 commited on
Commit
baeaebc
·
verified ·
1 Parent(s): 8c0a6f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -52
app.py CHANGED
@@ -75,65 +75,72 @@ def translate_blocking(text, lang):
75
 
76
  # --- HÀM DỊCH STREAMING/WRAPPER ---
77
  def translate_text(text, lang=None, needStreaming=False, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
78
  print("lang", lang)
 
 
79
  if lang is None:
80
- if needStreaming:
81
- yield text
82
- else:
83
- return text
84
 
85
  # Khối xây dựng thông điệp dịch
86
- messages = build_translation_messages(text, lang)
87
  max_tokens = 2048
88
  temperature = 0.5
89
  top_p = 0.05
90
 
91
- if needStreaming:
92
- # Xây dựng tokenized_chat generation_kwargs cho Streaming
93
- tokenized_chat = tokenizer_trans.apply_chat_template(
94
- messages,
95
- tokenize=True,
96
- add_generation_prompt=True,
97
- return_tensors="pt"
98
- ).to(model_trans.device)
99
 
100
- generation_kwargs = dict(
101
- inputs=tokenized_chat,
102
- max_new_tokens=max_tokens,
103
- temperature=temperature,
104
- top_p=top_p,
105
- do_sample=True if temperature > 0 else False,
106
- pad_token_id=tokenizer_trans.eos_token_id
107
- )
108
 
109
- # --- Logic Streaming ---
 
110
  streamer = TextIteratorStreamer(tokenizer_trans, skip_prompt=True, skip_special_tokens=True)
111
  generation_kwargs["streamer"] = streamer
112
 
113
- # Chạy generate trong một thread riêng biệt
114
  thread = Thread(target=model_trans.generate, kwargs=generation_kwargs)
115
  thread.start()
116
 
117
  full_text = ""
118
  for new_text in streamer:
119
  full_text += new_text
120
- # [FIX] Cần yield full text (đã được sửa trong code gốc)
121
- yield full_text.strip() # Làm sạch khi stream (tùy chọn)
122
-
123
  else:
124
- # --- Logic Blocking (Gọi hàm đã tách) ---
125
- response = translate_blocking(text, lang)
126
- print("response (Blocking)", response)
127
- return response
 
 
 
 
 
 
 
128
 
129
 
130
  # --- HÀM CHÍNH CHAT ---
131
  @spaces.GPU(duration=60)
132
  def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, lang, gender, progress=gr.Progress(track_tqdm=True)):
133
- """
134
- Hàm này nhận prompt mới, system_prompt, lịch sử hiển thị (của gr.Chatbot)
135
- và lịch sử nội bộ (của gr.State). Trả về dạng Streaming.
136
- """
137
  expected_key = os.environ.get("hf_key")
138
  if expected_key and expected_key not in prompt:
139
  print("❌ Invalid key.")
@@ -144,27 +151,24 @@ def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, la
144
  prompt = prompt.replace(expected_key, "")
145
 
146
  isAuto = False
147
-
148
- # [FIX] Đảm bảo actual_prompt_for_model luôn được định nghĩa
149
- actual_prompt_for_model = prompt
150
 
151
  if "[AUTO]" in prompt:
152
  prompt = prompt.replace("[AUTO]", "")
153
  isAuto = True
154
- # Gán lại prompt sau khi loại bỏ [AUTO]
155
  actual_prompt_for_model = prompt
156
  else:
157
  # Dịch prompt input của user sang tiếng Anh (không cần streaming input này)
158
  if lang is not None:
159
- # GỌI HÀM, NÓ VẪN TRẢ VỀ GENERATOR OBJECT, DÙ CHẠY BLOCKING LOGIC
160
- generator_obj = translate_text(prompt, "English", needStreaming=False)
161
  print("prompt_translated (generator)", generator_obj)
162
 
163
- # [FIX CỐ ĐỊNH] Dùng next() để buộc Generator chạy và trích xuất chuỗi:
164
  try:
165
  prompt_translated = next(generator_obj)
166
  except (StopIteration, TypeError):
167
- # Xử lý trường hợp generator rỗng hoặc đã bị trích xuất
168
  prompt_translated = ""
169
 
170
  # Gán và làm sạch chuỗi
@@ -174,12 +178,13 @@ def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, la
174
 
175
  print("prompt",prompt)
176
  print("actual_prompt_for_model",actual_prompt_for_model)
177
-
178
 
179
- # [FIX] Bỏ thao tác vô nghĩa "".join(list(...))
 
180
  actual_prompt_for_model = actual_prompt_for_model + " [Detailed description of the physical actions and expressions.]"
181
  print("prompt for model: " + actual_prompt_for_model)
182
 
 
183
  if chatbot_display is None:
184
  chatbot_display = []
185
  if internal_history is None:
@@ -197,21 +202,20 @@ def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, la
197
  ).to(model.device)
198
 
199
  # Chuẩn bị Chatbot Display Placeholder
200
- # Append một list [user_msg, None] để bắt đầu streaming câu trả lời
201
  chatbot_display.append([prompt, ""])
202
 
203
  # --- LOGIC STREAMING CHÍNH ---
204
 
205
  if lang is not None:
206
  # TRƯỜNG HỢP CÓ DỊCH:
207
- # 1. Generate tiếng Anh (nhanh/blocking) để lấy full context
208
  output_tokens = model.generate(
209
  inputs,
210
  max_new_tokens=1024,
211
  do_sample=True,
212
  temperature=0.7,
213
  top_p=0.5,
214
- pad_token_id=tokenizer.eos_token_id # Thêm pad token cho model chat
215
  )
216
  english_response = tokenizer.decode(output_tokens[0][inputs.shape[-1]:], skip_special_tokens=True)
217
  print("Eng response generated: ", english_response)
@@ -221,7 +225,6 @@ def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, la
221
 
222
  partial_translation = ""
223
  for chunk in stream_translator:
224
- # chunk ở đây là full text tích lũy từ hàm translate_text đã sửa
225
  partial_translation = chunk
226
 
227
  # Cập nhật UI
@@ -229,7 +232,7 @@ def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, la
229
  yield "", chatbot_display, internal_history, partial_translation, prompt
230
 
231
  final_response_text = english_response
232
- final_translated = partial_translation.strip() # Làm sạch lần cuối
233
 
234
  else:
235
  # TRƯỜNG HỢP KHÔNG DỊCH (Raw English):
@@ -242,7 +245,7 @@ def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, la
242
  do_sample=True,
243
  temperature=0.7,
244
  top_p=0.5,
245
- pad_token_id=tokenizer.eos_token_id # Thêm pad token cho model chat
246
  )
247
 
248
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
@@ -255,13 +258,12 @@ def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, la
255
  yield "", chatbot_display, internal_history, partial_text, prompt
256
 
257
  final_response_text = partial_text
258
- final_translated = partial_text # Giống nhau vì không dịch
259
 
260
  # 6. Cập nhật "bộ nhớ" (gr.State) sau khi hoàn tất
261
  internal_history.append({"role": "user", "content": actual_prompt_for_model})
262
  internal_history.append({"role": "assistant", "content": final_response_text})
263
 
264
- # Yield lần cuối để đảm bảo state được lưu
265
  yield "", chatbot_display, internal_history, final_translated, prompt
266
 
267
  def clear_chat():
 
75
 
76
  # --- HÀM DỊCH STREAMING/WRAPPER ---
77
  def translate_text(text, lang=None, needStreaming=False, progress=gr.Progress(track_tqdm=True)):
78
+ """
79
+ Hàm dịch luôn trả về dạng generator (chỉ dùng yield).
80
+ needStreaming=True: Yield từng chunk.
81
+ needStreaming=False: Thực hiện Blocking, sau đó yield toàn bộ kết quả một lần.
82
+ """
83
  print("lang", lang)
84
+
85
+ # TRƯỜNG HỢP KHÔNG CẦN DỊCH
86
  if lang is None:
87
+ yield text
88
+ return # return không giá trị là hợp lệ
 
 
89
 
90
  # Khối xây dựng thông điệp dịch
91
+ messages = build_translation_messages(text, lang)
92
  max_tokens = 2048
93
  temperature = 0.5
94
  top_p = 0.05
95
 
96
+ # Xây dựng tokenized_chat và generation_kwargs
97
+ tokenized_chat = tokenizer_trans.apply_chat_template(
98
+ messages,
99
+ tokenize=True,
100
+ add_generation_prompt=True,
101
+ return_tensors="pt"
102
+ ).to(model_trans.device)
 
103
 
104
+ generation_kwargs = dict(
105
+ inputs=tokenized_chat,
106
+ max_new_tokens=max_tokens,
107
+ temperature=temperature,
108
+ top_p=top_p,
109
+ do_sample=True if temperature > 0 else False,
110
+ pad_token_id=tokenizer_trans.eos_token_id
111
+ )
112
 
113
+ if needStreaming:
114
+ # --- Logic Streaming (Yield từng chunk) ---
115
  streamer = TextIteratorStreamer(tokenizer_trans, skip_prompt=True, skip_special_tokens=True)
116
  generation_kwargs["streamer"] = streamer
117
 
 
118
  thread = Thread(target=model_trans.generate, kwargs=generation_kwargs)
119
  thread.start()
120
 
121
  full_text = ""
122
  for new_text in streamer:
123
  full_text += new_text
124
+ yield full_text # Yield accumulated text
125
+
 
126
  else:
127
+ # --- Logic Blocking (Chạy Blocking, yield kết quả một lần) ---
128
+ # Ta sử dụng model.generate trực tiếp và đợi kết quả.
129
+ with torch.no_grad():
130
+ outputs = model_trans.generate(**generation_kwargs)
131
+
132
+ # Decode và làm sạch chuỗi
133
+ response = tokenizer_trans.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True).strip()
134
+ print("response (Yielded Blocking)", response)
135
+
136
+ # Trả về kết quả cuối cùng dưới dạng Generator (yield một lần)
137
+ yield response
138
 
139
 
140
  # --- HÀM CHÍNH CHAT ---
141
  @spaces.GPU(duration=60)
142
  def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, lang, gender, progress=gr.Progress(track_tqdm=True)):
143
+
 
 
 
144
  expected_key = os.environ.get("hf_key")
145
  if expected_key and expected_key not in prompt:
146
  print("❌ Invalid key.")
 
151
  prompt = prompt.replace(expected_key, "")
152
 
153
  isAuto = False
154
+ actual_prompt_for_model = prompt # Gán giá trị mặc định
 
 
155
 
156
  if "[AUTO]" in prompt:
157
  prompt = prompt.replace("[AUTO]", "")
158
  isAuto = True
 
159
  actual_prompt_for_model = prompt
160
  else:
161
  # Dịch prompt input của user sang tiếng Anh (không cần streaming input này)
162
  if lang is not None:
163
+ # GỌI HÀM, NÓ LUÔN TRẢ VỀ GENERATOR OBJECT
164
+ generator_obj = translate_text(prompt, "English", needStreaming=False)
165
  print("prompt_translated (generator)", generator_obj)
166
 
167
+ # [FIX] Dùng next() để buộc Generator chạy và trích xuất chuỗi đầu tiên (kết quả Blocking):
168
  try:
169
  prompt_translated = next(generator_obj)
170
  except (StopIteration, TypeError):
171
+ # Xử lý trường hợp generator rỗng
172
  prompt_translated = ""
173
 
174
  # Gán và làm sạch chuỗi
 
178
 
179
  print("prompt",prompt)
180
  print("actual_prompt_for_model",actual_prompt_for_model)
 
181
 
182
+ # [FIX] Bỏ thao tác vô nghĩa "".join(list(...)) và thực hiện phép cộng chuỗi
183
+ # Dòng này giờ sẽ hoạt động vì actual_prompt_for_model đã là một chuỗi (str)
184
  actual_prompt_for_model = actual_prompt_for_model + " [Detailed description of the physical actions and expressions.]"
185
  print("prompt for model: " + actual_prompt_for_model)
186
 
187
+ # ... (Phần còn lại của hàm chat_with_model giữ nguyên logic generate)
188
  if chatbot_display is None:
189
  chatbot_display = []
190
  if internal_history is None:
 
202
  ).to(model.device)
203
 
204
  # Chuẩn bị Chatbot Display Placeholder
 
205
  chatbot_display.append([prompt, ""])
206
 
207
  # --- LOGIC STREAMING CHÍNH ---
208
 
209
  if lang is not None:
210
  # TRƯỜNG HỢP CÓ DỊCH:
211
+ # 1. Generate tiếng Anh (Blocking)
212
  output_tokens = model.generate(
213
  inputs,
214
  max_new_tokens=1024,
215
  do_sample=True,
216
  temperature=0.7,
217
  top_p=0.5,
218
+ pad_token_id=tokenizer.eos_token_id
219
  )
220
  english_response = tokenizer.decode(output_tokens[0][inputs.shape[-1]:], skip_special_tokens=True)
221
  print("Eng response generated: ", english_response)
 
225
 
226
  partial_translation = ""
227
  for chunk in stream_translator:
 
228
  partial_translation = chunk
229
 
230
  # Cập nhật UI
 
232
  yield "", chatbot_display, internal_history, partial_translation, prompt
233
 
234
  final_response_text = english_response
235
+ final_translated = partial_translation.strip()
236
 
237
  else:
238
  # TRƯỜNG HỢP KHÔNG DỊCH (Raw English):
 
245
  do_sample=True,
246
  temperature=0.7,
247
  top_p=0.5,
248
+ pad_token_id=tokenizer.eos_token_id
249
  )
250
 
251
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
 
258
  yield "", chatbot_display, internal_history, partial_text, prompt
259
 
260
  final_response_text = partial_text
261
+ final_translated = partial_text
262
 
263
  # 6. Cập nhật "bộ nhớ" (gr.State) sau khi hoàn tất
264
  internal_history.append({"role": "user", "content": actual_prompt_for_model})
265
  internal_history.append({"role": "assistant", "content": final_response_text})
266
 
 
267
  yield "", chatbot_display, internal_history, final_translated, prompt
268
 
269
  def clear_chat():