LutherYTT's picture
Update app.py
31c7300
raw
history blame
3.77 kB
import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, BertModel, BertPreTrainedModel
from safetensors.torch import load_file
import gc
# Base checkpoint shared by the tokenizer, config and backbone loaded below.
model_name = "hfl/chinese-roberta-wwm-ext"

# Run the garbage collector and ask PyTorch's CUDA caching allocator to
# release unoccupied cached blocks before the model is built.
gc.collect()
torch.cuda.empty_cache()
class MultiTaskRoBert(BertPreTrainedModel):
    """BERT-style encoder with two heads on the shared pooled output.

    * ``classifier``: sentiment logits (``num_sentiment_classes`` wide).
    * ``regressor``: continuous scores (``num_regression_targets`` wide).

    Args:
        config: Backbone configuration (e.g. from ``AutoConfig.from_pretrained``).
        model_name_or_path: Checkpoint passed to ``BertModel.from_pretrained``
            for the backbone weights. If ``None``, the backbone is constructed
            directly from ``config`` with freshly initialised weights, which
            avoids a pretrained download when a full state dict is loaded
            over the model afterwards anyway.
        num_sentiment_classes: Output width of the classification head
            (default 3, as originally hard-coded).
        num_regression_targets: Output width of the regression head
            (default 5, as originally hard-coded).
    """

    def __init__(
        self,
        config,
        model_name_or_path=None,
        num_sentiment_classes=3,
        num_regression_targets=5,
    ):
        super().__init__(config)
        if model_name_or_path is None:
            # Random-initialised backbone; expected to be overwritten by load_state_dict.
            self.bert = BertModel(config)
        else:
            # Load backbone with pretrained weights.
            self.bert = BertModel.from_pretrained(model_name_or_path, config=config)
        self.classifier = nn.Linear(config.hidden_size, num_sentiment_classes)
        self.regressor = nn.Linear(config.hidden_size, num_regression_targets)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the encoder and both heads.

        Returns:
            A tuple ``(sentiment_logits, regression_outputs)``, both computed
            from the backbone's ``pooler_output`` (presumably shaped
            ``(batch, hidden_size)`` per the transformers BertModel docs), so
            each head output has a leading batch dimension.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled = outputs.pooler_output
        sentiment_logits = self.classifier(pooled)
        regression_outputs = self.regressor(pooled)
        return sentiment_logits, regression_outputs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
model_path = "model1.safetensors"

# Load tokenizer and config from the same base checkpoint as the backbone.
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

# Build the multi-task model, then overwrite every weight with the
# fine-tuned checkpoint (strict load: keys must match exactly).
model = MultiTaskRoBert(config, model_name).to(device)
# safetensors documents load_file's `device` as a str/int ("cpu", "cuda",
# "cuda:0"), so pass the string form rather than a torch.device object.
state_dict = load_file(model_path, device=str(device))
model.load_state_dict(state_dict)
model.eval()
def predict(text: str) -> dict:
    """Classify sentiment and score five emotions for a single text.

    Args:
        text: Raw input text (from the Gradio textbox).

    Returns:
        A JSON-serialisable dict with the sentiment label under "情感" and
        the five regression scores (rounded to 2 decimals) under
        "評分", "喜悅", "憤怒", "悲傷", "快樂". On any failure, a dict with
        a single "错误" key carrying the error message; the traceback is
        printed to stderr so the failure is visible in the server logs.
    """
    try:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Mixed precision only on CUDA; on CPU autocast is disabled, so this
        # is a plain forward pass. torch.autocast replaces the deprecated
        # torch.cuda.amp.autocast and removes the duplicated branch.
        with torch.no_grad(), torch.autocast(
            device_type=device.type, enabled=device.type == "cuda"
        ):
            # MultiTaskRoBert.forward returns a tuple, not a dict.
            sentiment_logits, regression_outputs = model(**inputs)

        pred_class = torch.argmax(sentiment_logits, dim=-1).item()
        sentiment_map = {0: "正面", 1: "負面", 2: "中立"}

        # .float() undoes any fp16 produced under autocast; .tolist() yields
        # plain Python floats, which the standard json module can encode
        # (numpy float32 scalars cannot).
        rating, delight, anger, sorrow, happiness = (
            regression_outputs[0].float().cpu().tolist()
        )
        return {
            "情感": sentiment_map[pred_class],
            "評分": round(rating, 2),
            "喜悅": round(delight, 2),
            "憤怒": round(anger, 2),
            "悲傷": round(sorrow, 2),
            "快樂": round(happiness, 2),
        }
    except Exception as e:
        # UI boundary: report the error to the user, but keep the traceback
        # in the logs instead of swallowing it silently.
        import traceback

        traceback.print_exc()
        return {"错误": f"处理失败: {str(e)}"}
# Cantonese sample sentences offered as one-click examples in the UI.
_EXAMPLE_TEXTS = (
    "呢個plan聽落唔錯,我哋試下先啦。",
    "份proposal 你send 咗俾client未?Deadline 係EOD呀。",
    "返工返到好攰,但係見到同事就feel better啲。",
    "你今次嘅presentation做得唔錯,我好 impressed!",
    "夜晚聽到嗰啲聲,我唔敢出房門。",
    "個client 真係好 difficult 囉,改咗n 次 requirements,仲要urgent,chur 到痴線!",
    "我尋日冇乜特別事做,就係喺屋企睇電視。",
    "Weekend 去staycation,間酒店個view 正到爆!",
    "做乜嘢都冇意義。",
    "今朝遲到咗,差啲miss咗個重要meeting",
)

# Web UI: a 3-line text input wired to predict(), with the result dict
# rendered as JSON.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="粵語文本", placeholder="請輸入粵語文本...", lines=3),
    outputs=gr.JSON(label="分析結果"),
    title="粵語情感與情緒分析",
    description="輸入粵語文本,分析情感(正面/負面/中立)和五種情緒評分",
    # One inner list per example, one entry per input component.
    examples=[[sample] for sample in _EXAMPLE_TEXTS],
)
if __name__ == "__main__":
    # Start the server; show_error surfaces handler exceptions in the UI and
    # share requests a public share link.
    iface.launch(show_error=True, share=True)