# Page header captured along with this file (not part of the program):
# LutherYTT's picture | Rename app.py to app2.py | ec2e9bd | raw | history blame | 3.57 kB
import gc
import logging

import gradio as gr
import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
# Best-effort cleanup before the model is loaded: run Python's garbage
# collector, then ask PyTorch to drop cached CUDA allocator blocks (only
# meaningful when a GPU is present).
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
class MultiTaskRoberta(nn.Module):
    """Encoder with a classification head and a regression head on a shared pooled vector.

    The pooled vector is the encoder output at the first token position
    (the [CLS] token for BERT-style tokenizers).

    Args:
        base_model: encoder module, called as
            ``base_model(input_ids=..., attention_mask=...)``, whose output
            exposes ``last_hidden_state``.
        hidden_size: width of the encoder output. When omitted it is read from
            ``base_model.config.hidden_size`` if available, otherwise 768
            (the value previously hard-coded).
        num_classes: number of classification logits (default 3).
        num_regression: number of regression outputs (default 5).
    """

    def __init__(self, base_model, hidden_size=None, num_classes=3, num_regression=5):
        super().__init__()
        if hidden_size is None:
            # Derive the width from the encoder's config so other backbones work too.
            hidden_size = getattr(getattr(base_model, "config", None), "hidden_size", 768)
        # Attribute names are kept as-is: they determine the state-dict keys
        # ("roberta.*", "classifier.*", "regressor.*") used when loading weights.
        self.roberta = base_model
        self.classifier = nn.Linear(hidden_size, num_classes)
        self.regressor = nn.Linear(hidden_size, num_regression)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Run the encoder and both heads.

        Extra keyword arguments (e.g. ``token_type_ids`` produced by the
        tokenizer) are accepted and ignored.

        Returns:
            dict with ``"logits"`` (batch, num_classes) and
            ``"regression_outputs"`` (batch, num_regression).
        """
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # assumes last_hidden_state is (batch, seq, hidden); take position 0 of each sequence
        pooled = outputs.last_hidden_state[:, 0]
        return {
            "logits": self.classifier(pooled),
            "regression_outputs": self.regressor(pooled),
        }
# Prefer the GPU when one is available; report the choice on startup.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# The tokenizer and the encoder backbone come from the same pretrained checkpoint.
BASE_MODEL_NAME = "hfl/chinese-roberta-wwm-ext"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
base_model = AutoModel.from_pretrained(BASE_MODEL_NAME)
model = MultiTaskRoberta(base_model)

# Trained weights for the whole model (strict load: every key must match).
model_path = "model1.safetensors"
state_dict = load_file(model_path, device="cpu")
model.load_state_dict(state_dict)
# load_state_dict copies the values into the model's own parameters; keeping
# this module-level dict would pin a second CPU copy of every weight for the
# lifetime of the process.
del state_dict
model.to(device)
model.eval()  # inference mode (e.g. disables dropout)

# NOTE: model.half() (half precision to reduce memory usage) is left disabled;
# predict() uses CUDA autocast for mixed precision instead.
# Classifier output index -> sentiment label (positive / negative / neutral).
# NOTE(review): assumes this index order matches the labels used in training.
_SENTIMENT_LABELS = {0: "正面", 1: "負面", 2: "中立"}


def _format_outputs(logits, regression_outputs):
    """Turn the model's raw outputs for the first batch item into the result dict.

    Args:
        logits: classification logits; the arg-max over the last dimension of
            item 0 selects the sentiment label.
        regression_outputs: regression scores; item 0 must hold exactly five
            values, in the order rating, delight, anger, sorrow, happiness.

    Returns:
        dict with the sentiment label under "情感" followed by the five scores,
        each rounded to two decimals as a built-in Python float.
    """
    pred_class = torch.argmax(logits, dim=-1)[0].item()
    # .float() lifts autocast's float16 back to float32 and .tolist() yields
    # built-in floats. Rounding numpy float16/float32 scalars would keep the
    # numpy type, which is not JSON-serializable and rounds at float16 precision.
    rating, delight, anger, sorrow, happiness = (
        round(score, 2) for score in regression_outputs[0].float().cpu().tolist()
    )
    return {
        "情感": _SENTIMENT_LABELS[pred_class],
        "評分": rating,
        "喜悅": delight,
        "憤怒": anger,
        "悲傷": sorrow,
        "快樂": happiness,
    }


def predict(text: str):
    """Predict sentiment and five emotion scores for one piece of Cantonese text.

    Args:
        text: the user's input from the textbox.

    Returns:
        On success, the dict built by ``_format_outputs``. On any failure,
        including empty or whitespace-only input, a single-key dict
        ``{"错误": message}`` so the UI shows the problem instead of crashing.
    """
    try:
        if text is None or not text.strip():
            return {"错误": "处理失败: 输入文本为空"}
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        # Mixed precision only on CUDA; with enabled=False (CPU) autocast is a no-op.
        use_amp = device.type == "cuda"
        with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_amp):
            out = model(**inputs)
        return _format_outputs(out["logits"], out["regression_outputs"])
    except Exception as e:
        # UI boundary: keep the full traceback in the server log and show a
        # readable message to the user.
        logging.getLogger(__name__).exception("predict failed")
        return {"错误": f"处理失败: {str(e)}"}
# Sample inputs listed under the textbox: everyday Cantonese sentences, many
# of them code-switched with English.
EXAMPLE_TEXTS = [
    "呢個plan聽落唔錯,我哋試下先啦。",
    "份proposal 你send 咗俾client未?Deadline 係EOD呀。",
    "返工返到好攰,但係見到同事就feel better啲。",
    "你今次嘅presentation做得唔錯,我好 impressed!",
    "夜晚聽到嗰啲聲,我唔敢出房門。",
    "個client 真係好 difficult 囉,改咗n 次 requirements,仲要urgent,chur 到痴線!",
    "我尋日冇乜特別事做,就係喺屋企睇電視。",
    "Weekend 去staycation,間酒店個view 正到爆!",
    "做乜嘢都冇意義。",
    "今朝遲到咗,差啲miss咗個重要meeting",
]

# Input and output components for the single-function UI.
text_box = gr.Textbox(lines=3, placeholder="請輸入粵語文本...", label="粵語文本")
result_view = gr.JSON(label="分析結果")

# One text box in; the dict returned by predict() rendered as JSON out.
iface = gr.Interface(
    fn=predict,
    inputs=text_box,
    outputs=result_view,
    title="粵語情感與情緒分析",
    description="輸入粵語文本,分析情感(正面/負面/中立)和五種情緒評分",
    # Interface expects one row per example, with one value per input component.
    examples=[[sample] for sample in EXAMPLE_TEXTS],
)

if __name__ == "__main__":
    # share=True requests a public share link; show_error=True surfaces
    # exceptions raised inside the app in the browser UI.
    iface.launch(share=True, show_error=True)