# LutherYTT — "Update app.py" (commit ee55fba, verified)
import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, BertModel, BertPreTrainedModel
from safetensors.torch import load_file
import gc
# Hugging Face Hub id of the pretrained checkpoint used for the tokenizer,
# the model config and the encoder weights.
model_name = "hfl/chinese-roberta-wwm-ext"

# Release memory before loading the model: run a full garbage-collection pass,
# then ask PyTorch to release cached CUDA allocator blocks.
gc.collect()
torch.cuda.empty_cache()
class MultiTaskRoBert(BertPreTrainedModel):
    """BERT encoder shared by a sentiment-classification head and a regression head.

    The pooled ``[CLS]`` representation (``pooler_output``) feeds two linear heads:

    * ``classifier`` produces ``num_sentiment_classes`` logits (3 by default).
    * ``regressor`` produces ``num_regression_outputs`` scores (5 by default).
    """

    def __init__(self, config, model_name=None, *, num_sentiment_classes=3,
                 num_regression_outputs=5):
        """Build the encoder and both task heads.

        Args:
            config: Transformers config describing the encoder (``hidden_size`` is
                used to size both heads).
            model_name: Optional checkpoint name/path. When given, the encoder starts
                from that checkpoint's pretrained weights; when ``None`` it is built
                from ``config`` alone (random init), which is enough when a full
                state dict is loaded into the model afterwards.
            num_sentiment_classes: Output size of the classification head.
            num_regression_outputs: Output size of the regression head.
        """
        super().__init__(config)
        if model_name is None:
            self.bert = BertModel(config)
        else:
            # Load backbone with pretrained weights.
            self.bert = BertModel.from_pretrained(model_name, config=config)
        self.classifier = nn.Linear(config.hidden_size, num_sentiment_classes)
        self.regressor = nn.Linear(config.hidden_size, num_regression_outputs)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the encoder once and apply both heads to the pooled output.

        Returns:
            dict with ``"logits"`` (classification head output) and
            ``"regression_outputs"`` (regression head output).
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            # Force a ModelOutput so ``.pooler_output`` works even if the config
            # disables return_dict (which would otherwise yield a plain tuple).
            return_dict=True,
        )
        pooled = outputs.pooler_output
        return {
            "logits": self.classifier(pooled),
            "regression_outputs": self.regressor(pooled),
        }
# Inference runs on the CPU.
device = "cpu"
print(f"Device: {device}")
# Fine-tuned weights for the whole multi-task model, in safetensors format.
model_path = "model1.safetensors"
# Load tokenizer and config for the backbone checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
model = MultiTaskRoBert(config, model_name).to(device)
state_dict = load_file(model_path, device="cpu")
# Strict loading (the default): keys in the file must match the model exactly.
model.load_state_dict(state_dict)
# load_state_dict copies the values into the model's own parameters, so the
# loaded tensors are no longer needed; drop the module-level reference so they
# can be released instead of staying alive for the whole process.
del state_dict
# Switch to inference mode (disables dropout).
model.eval()
def predict(text: str):
    """Classify the sentiment of *text* and score its five emotion dimensions.

    Args:
        text: Input text; tokenized with truncation and padding to 128 tokens.

    Returns:
        On success, a dict with display labels as keys:
        ``"情感"`` -> ``"正面"``, ``"負面"`` or ``"中立"`` (argmax of the
        classification logits); ``"強度"``, ``"喜"``, ``"怒"``, ``"哀"``,
        ``"樂"`` -> the five regression outputs, clamped to [0, 5] and rounded
        to 2 decimals, as built-in Python floats.
        On any exception, ``{"错误": "处理失败: <message>"}`` is returned
        instead of raising, so the error appears in the UI's JSON output.
    """
    try:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            out = model(**inputs)
        pred_class = torch.argmax(out["logits"], dim=-1).item()
        sentiment_map = {0: "正面", 1: "負面", 2: "中立"}
        # Clamp to the [0, 5] score range and convert to built-in floats.
        # numpy float32 scalars are not serializable by the stdlib json module,
        # and once rounded they widen to values like 3.140000104904175.
        scores = out["regression_outputs"][0].clamp(0.0, 5.0).tolist()
        rating, delight, anger, sorrow, happiness = (round(s, 2) for s in scores)
        return {
            "情感": sentiment_map[pred_class],
            "強度": rating,
            "喜": delight,
            "怒": anger,
            "哀": sorrow,
            "樂": happiness,
        }
    except Exception as e:
        # UI boundary: report the failure in the output panel rather than crash.
        return {"错误": f"处理失败: {str(e)}"}
# Footer text rendered under the demo: author credit plus repository link.
article = "Author: Lu Yuk Tong, [Github Link](https://github.com/LutherYTT/Cantonese-Sentiment-Analysis-System-Multitasking-Learning-on-Scarce-Data)"

# Sample inputs offered as one-click examples; the interface has a single
# input field, so each example is wrapped in a one-element list below.
_example_texts = (
    "呢個plan聽落唔錯,我哋試下先啦。",
    "份proposal 你send 咗俾client未?Deadline 係EOD呀。",
    "返工返到好攰,但係見到同事就feel better啲。",
    "你今次嘅presentation做得唔錯,我好 impressed!",
    "夜晚聽到嗰啲聲,我唔敢出房門。",
    "個client 真係好 difficult 囉,改咗n 次 requirements,仲要urgent,chur 到痴線!",
    "我尋日冇乜特別事做,就係喺屋企睇電視。",
    "Weekend 去staycation,間酒店個view 正到爆!",
    "做乜嘢都冇意義。",
    "今朝遲到咗,差啲miss咗個重要meeting",
)

# Gradio interface: a three-line text box in, the dict returned by predict out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="請輸入粵語文本...", label="粵語文本"),
    outputs=gr.JSON(label="分析結果"),
    title="粵語情感與情緒分析",
    description="輸入粵語文本,分析情感(正面/負面/中立)和五種情緒評分",
    examples=[[example] for example in _example_texts],
    article=article,
)
if __name__ == "__main__":
    # Start the Gradio server when run as a script. show_error=True displays
    # handler errors in the browser; share=True asks Gradio for a public share
    # link (NOTE(review): this exposes the demo publicly when run locally).
    launch_options = {"share": True, "show_error": True}
    iface.launch(**launch_options)