mrbui1990 commited on
Commit
e4d70e3
·
verified ·
1 Parent(s): 38d03a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -17
app.py CHANGED
@@ -23,6 +23,22 @@ model_trans = AutoModelForCausalLM.from_pretrained(
23
  device_map="cuda"
24
  )
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # --- Sửa đổi hàm translate_text ---
27
  # Thêm tham số needStreaming
28
  def translate_text(text, lang=None, needStreaming=False, progress=gr.Progress(track_tqdm=True)):
@@ -79,18 +95,13 @@ def translate_text(text, lang=None, needStreaming=False, progress=gr.Progress(tr
79
  yield full_text
80
  else:
81
  # --- Logic cũ (Blocking) ---
82
- with torch.no_grad():
83
- outputs = model_trans.generate(
84
- tokenized_chat.to(model.device),
85
- max_new_tokens=max_tokens,
86
- temperature=temperature,
87
- top_p=top_p,
88
- do_sample=True if temperature > 0 else False,
89
- pad_token_id=tokenizer_trans.eos_token_id
90
- )
91
-
92
- response = tokenizer_trans.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True)
93
- print("response",response)
94
  return response
95
 
96
  # Tải model và tokenizer 1 LẦN DUY NHẤT
@@ -126,15 +137,29 @@ def chat_with_model(prompt, system_prompt, chatbot_display, internal_history, la
126
  else:
127
  # Dịch prompt input của user sang tiếng Anh (không cần streaming input này)
128
  if lang is not None:
129
- prompt_translated = translate_text(prompt, "English", needStreaming=False)
130
- print("prompt_translated",prompt_translated)
131
- # Lưu ý: Prompt gốc của user dùng để hiển thị, prompt translated dùng để đưa vào model
132
- actual_prompt_for_model = prompt_translated
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  else:
134
  actual_prompt_for_model = prompt
135
  print("prompt",prompt)
 
136
  print("actual_prompt_for_model",actual_prompt_for_model)
137
- actual_prompt_for_model = "".join(list(actual_prompt_for_model)) + " [Detailed description of the physical actions and expressions.]"
138
  print("prompt for model: " + actual_prompt_for_model)
139
 
140
  if chatbot_display is None:
 
23
  device_map="cuda"
24
  )
25
 
26
+ def generate_blocking(tokenized_chat, max_tokens, temperature, top_p):
27
+ """Thực hiện generate MỘT LẦN và trả về CHUỖI kết quả."""
28
+ with torch.no_grad():
29
+ # Đảm bảo bạn đang sử dụng đúng model: model_trans
30
+ outputs = model_trans.generate(
31
+ tokenized_chat, # Đã loại bỏ .to(model.device) vì nó đã ở model_trans.device
32
+ max_new_tokens=max_tokens,
33
+ temperature=temperature,
34
+ top_p=top_p,
35
+ do_sample=True if temperature > 0 else False,
36
+ pad_token_id=tokenizer_trans.eos_token_id
37
+ )
38
+
39
+ response = tokenizer_trans.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True)
40
+ return response
41
+
42
  # --- Sửa đổi hàm translate_text ---
43
  # Thêm tham số needStreaming
44
  def translate_text(text, lang=None, needStreaming=False, progress=gr.Progress(track_tqdm=True)):
 
95
  yield full_text
96
  else:
97
  # --- Logic cũ (Blocking) ---
98
+ response = generate_blocking(
99
+ tokenized_chat.to(model_trans.device), # Đảm bảo chuyển token sang đúng device
100
+ max_tokens,
101
+ temperature,
102
+ top_p
103
+ )
104
+ print("response (Blocking)", response)
 
 
 
 
 
105
  return response
106
 
107
  # Tải model và tokenizer 1 LẦN DUY NHẤT
 
137
  else:
138
  # Dịch prompt input của user sang tiếng Anh (không cần streaming input này)
139
  if lang is not None:
140
+ generator_obj = translate_text(prompt, "English", needStreaming=False) # Lấy Generator
141
+
142
+ # Lặp qua Generator để lấy chuỗi (dù bạn không muốn stream,
143
+ # generator vẫn là cách duy nhất để lấy giá trị nếu hàm translate_text không được sửa)
144
+ # Tuy nhiên, nếu bạn đã sửa hàm translate_text như BƯỚC 1, nó sẽ trả về chuỗi.
145
+
146
+ # GIẢ ĐỊNH bạn chưa sửa hàm translate_text, bạn cần lấy giá trị đầu tiên:
147
+ try:
148
+ # Vì bạn gọi với needStreaming=False, logic blocking của bạn
149
+ # sẽ trả về chuỗi, KHÔNG phải generator.
150
+ prompt_translated = generator_obj
151
+
152
+ # Nếu đã sửa translate_text (BƯỚC 1), không cần làm gì thêm, nó là chuỗi rồi.
153
+ # Nếu chưa sửa (và nó vẫn trả về generator dù là blocking), thì đây là lý do lỗi.
154
+
155
+ except StopIteration:
156
+ prompt_translated = "" # Generator rỗng
157
  else:
158
  actual_prompt_for_model = prompt
159
  print("prompt",prompt)
160
+ print("prompt_translated",prompt_translated)
161
  print("actual_prompt_for_model",actual_prompt_for_model)
162
+ actual_prompt_for_model = actual_prompt_for_model + " [Detailed description of the physical actions and expressions.]"
163
  print("prompt for model: " + actual_prompt_for_model)
164
 
165
  if chatbot_display is None: