gen_predict / app.py
ojs595's picture
Update app.py
974f1e7 verified
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
import pandas as pd
import io
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
# ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
MODEL_NAME = "beomi/kcbert-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3) # 3๊ฐœ ํด๋ž˜์Šค๋กœ ๋ณ€๊ฒฝ
# ๋ฐ์ดํ„ฐ์…‹ ํด๋ž˜์Šค ์ •์˜
class CustomDataset(Dataset):
def __init__(self, dataframe, tokenizer, max_len=128):
self.tokenizer = tokenizer
self.data = dataframe
self.max_len = max_len
def __len__(self):
return len(self.data)
def __getitem__(self, index):
item = self.data.iloc[index]
description = str(item['description'])
label = item['label']
encoding = self.tokenizer.encode_plus(
description,
add_special_tokens=True,
max_length=self.max_len,
return_token_type_ids=False,
padding='max_length',
truncation=True,
return_attention_mask=True,
return_tensors='pt',
)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': torch.tensor(label, dtype=torch.long)
}
# ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ ์ค€๋น„ ๋ฐ ๋ชจ๋ธ ํ›ˆ๋ จ
def train_model():
csv_data = """description,gender
"๊ทธ๋Š” ์ถ•๊ตฌ๋ฅผ ์ •๋ง ์ข‹์•„ํ•˜๊ณ , ๊ทผ์œก์งˆ์˜ ๋ชธ๋งค๋ฅผ ๊ฐ€์กŒ๋‹ค.",๋‚จ์ž
"๊ทธ๋…€๋Š” ๊ธด ๋จธ๋ฆฌ๋ฅผ ๊ฐ€์กŒ๊ณ , ๋ถ„ํ™์ƒ‰ ์›ํ”ผ์Šค๋ฅผ ์ž…์—ˆ๋‹ค.",์—ฌ์ž
"์งง์€ ๋จธ๋ฆฌ์— ์ •์žฅ์„ ์ž…์€ ๊ทธ๋Š” ํšŒ์˜์— ์ฐธ์„ํ–ˆ๋‹ค.",๋‚จ์ž
"์•„๋ฆ„๋‹ค์šด ๋ชฉ์†Œ๋ฆฌ๋กœ ๋…ธ๋ž˜ํ•˜๋Š” ๊ทธ๋…€๋Š” ๊ฐ€์ˆ˜๋‹ค.",์—ฌ์ž
"๊ทธ์˜ ์ทจ๋ฏธ๋Š” ์ž๋™์ฐจ ์ •๋น„์™€ ์ปดํ“จํ„ฐ ๊ฒŒ์ž„์ด๋‹ค.",๋‚จ์ž
"๊ทธ๋…€๋Š” ์„ฌ์„ธํ•œ ์†๊ธธ๋กœ ์•„๊ธฐ ์ธํ˜•์„ ๋งŒ๋“ค์—ˆ๋‹ค.",์—ฌ์ž
"๊ตฐ๋Œ€์—์„œ ๋ง‰ ์ œ๋Œ€ํ•œ ๊ทธ๋Š” ์”ฉ์”ฉํ•ด ๋ณด์˜€๋‹ค.",๋‚จ์ž
"๊ทธ๋…€๋Š” ์นœ๊ตฌ๋“ค๊ณผ ์ˆ˜๋‹ค ๋– ๋Š” ๊ฒƒ์„ ์ข‹์•„ํ•œ๋‹ค.",์—ฌ์ž
"๊ฐ•๋ ฅํ•œ ๋ฆฌ๋”์‹ญ์œผ๋กœ ํŒ€์„ ์ด๋„๋Š” ๋ชจ์Šต์ด ์ธ์ƒ์ ์ด์—ˆ๋‹ค.",๋‚จ์ž
"์ž์‹ ์ด ์ง์ ‘ ๋งŒ๋“  ์ฟ ํ‚ค๋ฅผ ์ฃผ๋ณ€์— ๋‚˜๋ˆ„์–ด์ฃผ๊ณค ํ•œ๋‹ค.",์—ฌ์ž
"๊ทธ๋“ค์€ ์ฑ… ์ฝ๊ธฐ๋ฅผ ์ข‹์•„ํ•˜๊ณ  ์กฐ์šฉํ•œ ์„ฑ๊ฒฉ์ด๋‹ค.",์ค‘์„ฑ
"ํ‚ค๊ฐ€ ํฌ๊ณ  ์ฒด๊ฒฉ์ด ์ข‹์œผ๋ฉฐ ์šด๋™์„ ์ฆ๊ธด๋‹ค.",์ค‘์„ฑ
"์š”๋ฆฌ์™€ ์ฒญ์†Œ๋ฅผ ๋ชจ๋‘ ์ž˜ํ•˜๋ฉฐ ์ง‘์•ˆ์ผ์„ ๋„๋งก์•„ ํ•œ๋‹ค.",์ค‘์„ฑ
"์ปดํ“จํ„ฐ ํ”„๋กœ๊ทธ๋ž˜๋ฐ๊ณผ ๋œจ๊ฐœ์งˆ์„ ๋ชจ๋‘ ์ทจ๋ฏธ๋กœ ํ•œ๋‹ค.",์ค‘์„ฑ
"์ฐจ๋ถ„ํ•œ ์„ฑ๊ฒฉ์œผ๋กœ ์ƒ๋‹ด์„ ์ž˜ํ•ด์ฃผ๋Š” ํŽธ์ด๋‹ค.",์ค‘์„ฑ
"๋…์„œ์™€ ์˜ํ™”๊ฐ์ƒ์„ ์ฆ๊ธฐ๋Š” ๋ฌธํ™” ์• ํ˜ธ๊ฐ€์ด๋‹ค.",์ค‘์„ฑ
"""
data = pd.read_csv(io.StringIO(csv_data))
# 3๊ฐœ ํด๋ž˜์Šค๋กœ ๋ผ๋ฒจ ๋ณ€๊ฒฝ: ๋‚จ์ž=0, ์—ฌ์ž=1, ์ค‘์„ฑ=2
data['label'] = data['gender'].apply(lambda x: 0 if x == '๋‚จ์ž' else (1 if x == '์—ฌ์ž' else 2))
train_data, _ = train_test_split(data, test_size=0.2, random_state=42)
train_dataset = CustomDataset(train_data, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)
print("๋ชจ๋ธ ํ›ˆ๋ จ ์‹œ์ž‘...")
model.train()
for epoch in range(3):
for batch in train_loader:
optimizer.zero_grad()
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
print(f"Epoch {epoch + 1} ์™„๋ฃŒ")
print("๋ชจ๋ธ ํ›ˆ๋ จ ์™„๋ฃŒ!")
# ์˜ˆ์ธก ํ•จ์ˆ˜
def predict_gender(text):
if not text.strip():
return "ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
encoding = tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=128,
return_token_type_ids=False,
padding='max_length',
truncation=True,
return_attention_mask=True,
return_tensors='pt',
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
with torch.no_grad():
outputs = model(input_ids, attention_mask=attention_mask)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
prediction = torch.argmax(outputs.logits, dim=1).flatten().item()
confidence = probabilities[0][prediction].item()
# 3๊ฐœ ํด๋ž˜์Šค ๋งคํ•‘: 0=๋‚จ์ž, 1=์—ฌ์ž, 2=์ค‘์„ฑ
gender_map = {0: "๋‚จ์ž", 1: "์—ฌ์ž", 2: "์ค‘์„ฑ"}
gender = gender_map[prediction]
return f"์˜ˆ์ธก ์„ฑ๋ณ„: {gender} (์‹ ๋ขฐ๋„: {confidence:.2%})"
# ์•ฑ ์‹œ์ž‘ ์‹œ ๋ชจ๋ธ ํ›ˆ๋ จ
print("์•ฑ ์ดˆ๊ธฐํ™” ์ค‘...")
train_model()
# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
iface = gr.Interface(
fn=predict_gender,
inputs=gr.Textbox(
lines=3,
placeholder="์„ฑ๋ณ„์„ ์˜ˆ์ธกํ•  ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”.\n์˜ˆ: '๊ทธ๋Š” ์ถ•๊ตฌ๋ฅผ ์ข‹์•„ํ•˜๊ณ  ๊ทผ์œก์งˆ์ด๋‹ค.'",
label="ํ…์ŠคํŠธ ์ž…๋ ฅ"
),
outputs=gr.Textbox(label="์˜ˆ์ธก ๊ฒฐ๊ณผ"),
title="๐Ÿค– AI ์„ฑ๋ณ„ ์˜ˆ์ธก๊ธฐ (3๋ถ„๋ฅ˜)",
description="์ž…๋ ฅ๋œ ํ…์ŠคํŠธ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์„ฑ๋ณ„์„ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. (๋‚จ์ž/์—ฌ์ž/์ค‘์„ฑ)",
examples=[
["๊ทธ๋Š” ์ถ•๊ตฌ๋ฅผ ์ •๋ง ์ข‹์•„ํ•˜๊ณ , ๊ทผ์œก์งˆ์˜ ๋ชธ๋งค๋ฅผ ๊ฐ€์กŒ๋‹ค."],
["๊ทธ๋…€๋Š” ๊ธด ๋จธ๋ฆฌ๋ฅผ ๊ฐ€์กŒ๊ณ , ๋ถ„ํ™์ƒ‰ ์›ํ”ผ์Šค๋ฅผ ์ž…์—ˆ๋‹ค."],
["์งง์€ ๋จธ๋ฆฌ์— ์ •์žฅ์„ ์ž…์€ ๊ทธ๋Š” ํšŒ์˜์— ์ฐธ์„ํ–ˆ๋‹ค."],
["์•„๋ฆ„๋‹ค์šด ๋ชฉ์†Œ๋ฆฌ๋กœ ๋…ธ๋ž˜ํ•˜๋Š” ๊ทธ๋…€๋Š” ๊ฐ€์ˆ˜๋‹ค."],
["๊ทธ๋“ค์€ ์ฑ… ์ฝ๊ธฐ๋ฅผ ์ข‹์•„ํ•˜๊ณ  ์กฐ์šฉํ•œ ์„ฑ๊ฒฉ์ด๋‹ค."],
["์š”๋ฆฌ์™€ ์ฒญ์†Œ๋ฅผ ๋ชจ๋‘ ์ž˜ํ•˜๋ฉฐ ์ง‘์•ˆ์ผ์„ ๋„๋งก์•„ ํ•œ๋‹ค."]
],
theme=gr.themes.Soft(),
# Google ์ธ์ฆ์„ ์œ„ํ•œ ์ปค์Šคํ…€ HTML ํ—ค๋“œ ์ถ”๊ฐ€
head="""
<meta name="google-site-verification" content="9owJnk1eK0CZKk6u6slBQwC6ts3e1GUAm_ohwPtE2BI" />
"""
)
# ์•ฑ ์‹คํ–‰
if __name__ == "__main__":
iface.launch()