import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
import pandas as pd
import io
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
# Load the model and tokenizer
MODEL_NAME = "beomi/kcbert-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)  # 3 classes: male / female / neutral
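# Note: num_labels=3 attaches a freshly initialized classification head on top of the
# pretrained KcBERT encoder; its outputs are only meaningful after train_model() runs below.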
# Dataset class definition
class CustomDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len=128):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data.iloc[index]
        description = str(item['description'])
        label = item['label']
        encoding = self.tokenizer.encode_plus(
            description,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
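# Each item is a dict of fixed-length (max_len=128) input_ids / attention_mask tensors plus an
# integer label, so the DataLoader below can stack items into uniform batches without extra collation.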
# Prepare the training data and train the model
def train_model():
    # Tiny in-line training set: one-sentence Korean descriptions labeled
    # 남자 (male), 여자 (female), or 중성 (neutral)
    csv_data = """description,gender
"그는 축구를 정말 좋아하고, 근육질의 몸매를 가졌다.",남자
"그녀는 긴 머리를 가졌고, 분홍색 원피스를 입었다.",여자
"짧은 머리에 정장을 입은 그는 회의에 참석했다.",남자
"아름다운 목소리로 노래하는 그녀는 가수다.",여자
"그의 취미는 자동차 정비와 컴퓨터 게임이다.",남자
"그녀는 섬세한 손길로 아기 인형을 만들었다.",여자
"군대에서 막 제대한 그는 씩씩해 보였다.",남자
"그녀는 친구들과 수다 떠는 것을 좋아한다.",여자
"강력한 리더십으로 팀을 이끄는 모습이 인상적이었다.",남자
"자신이 직접 만든 쿠키를 주변에 나누어주곤 했다.",여자
"그들은 책 읽기를 좋아하고 조용한 성격이다.",중성
"키가 크고 체격이 좋으며 운동을 즐긴다.",중성
"요리와 청소를 모두 잘하며 집안일을 도맡아 한다.",중성
"컴퓨터 프로그래밍과 뜨개질을 모두 취미로 한다.",중성
"차분한 성격으로 상담을 잘해주는 편이다.",중성
"독서와 영화감상을 즐기는 문화 애호가이다.",중성
"""
    data = pd.read_csv(io.StringIO(csv_data))
    # Map labels to 3 classes: 남자 (male)=0, 여자 (female)=1, 중성 (neutral)=2
    data['label'] = data['gender'].apply(lambda x: 0 if x == '남자' else (1 if x == '여자' else 2))
    # Keep 80% for training; the held-out 20% is not evaluated in this demo
    train_data, _ = train_test_split(data, test_size=0.2, random_state=42)
    train_dataset = CustomDataset(train_data, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=5e-5)

    print("Starting model training...")
    model.train()
    for epoch in range(3):
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1} finished")
    print("Model training finished!")
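# The fine-tuned weights live only in memory for the lifetime of the running app. If you wanted to
# persist them between restarts (not done here), a minimal sketch would be:
#   model.save_pretrained("./gender-clf")      # hypothetical output directory
#   tokenizer.save_pretrained("./gender-clf")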
# Prediction function
def predict_gender(text):
    if not text.strip():
        return "텍스트를 입력해주세요."  # "Please enter some text."
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
        prediction = torch.argmax(outputs.logits, dim=1).flatten().item()
        confidence = probabilities[0][prediction].item()
    # 3-class mapping: 0 = 남자 (male), 1 = 여자 (female), 2 = 중성 (neutral)
    gender_map = {0: "남자", 1: "여자", 2: "중성"}
    gender = gender_map[prediction]
    return f"예측 성별: {gender} (신뢰도: {confidence:.2%})"  # "Predicted gender: ... (confidence: ...)"
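# Example of calling the prediction function directly (illustrative; the confidence value
# depends on the particular training run):
#   predict_gender("그는 축구를 정말 좋아하고, 근육질의 몸매를 가졌다.")
#   -> "예측 성별: 남자 (신뢰도: ...)"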
# Train the model at app startup
print("Initializing app...")
train_model()
# Create the Gradio interface
iface = gr.Interface(
    fn=predict_gender,
    inputs=gr.Textbox(
        lines=3,
        placeholder="성별을 예측할 텍스트를 입력하세요.\n예: '그는 축구를 좋아하고 근육질이다.'",
        label="텍스트 입력"
    ),
    outputs=gr.Textbox(label="예측 결과"),
    title="🤖 AI 성별 예측기 (3분류)",
    description="입력된 텍스트를 바탕으로 성별을 예측합니다. (남자/여자/중성)",
    # The example inputs below reuse sentences from the training set
    examples=[
        ["그는 축구를 정말 좋아하고, 근육질의 몸매를 가졌다."],
        ["그녀는 긴 머리를 가졌고, 분홍색 원피스를 입었다."],
        ["짧은 머리에 정장을 입은 그는 회의에 참석했다."],
        ["아름다운 목소리로 노래하는 그녀는 가수다."],
        ["그들은 책 읽기를 좋아하고 조용한 성격이다."],
        ["요리와 청소를 모두 잘하며 집안일을 도맡아 한다."]
    ],
    theme=gr.themes.Soft(),
    # Custom HTML head entry for Google site verification
    head="""
    <meta name="google-site-verification" content="9owJnk1eK0CZKk6u6slBQwC6ts3e1GUAm_ohwPtE2BI" />
    """
)
# Run the app
if __name__ == "__main__":
    iface.launch()
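# On Hugging Face Spaces, launch() needs no arguments; for local debugging you could instead
# call iface.launch(share=True) to expose a temporary public URL (share is a standard Gradio
# launch() parameter).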