"""Model manager for generation model"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import logging
logger = logging.getLogger(__name__)
class GenerateModelManager:
"""Manages generation model loading and predictions"""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, model_id: str, api_key: str = None):
        """Load model and tokenizer from Hugging Face"""
        if self.model_loaded:
            logger.info("Generation model already loaded")
            return

        try:
            logger.info(f"Loading generation model from Hugging Face: {model_id}")

            # Determine device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")

            # Prepare the authentication token if an API key is provided
            token = api_key if api_key else None

            # Load tokenizer and model from Hugging Face
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True
            )

            logger.info("Loading model...")
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True
            )
            self.model.to(self.device)
            self.model.eval()

            self.model_loaded = True
            logger.info("✓ Generation model loaded successfully from Hugging Face!")
        except Exception as e:
            logger.error(f"Error loading generation model: {str(e)}")
            raise RuntimeError(f"Failed to load generation model: {str(e)}") from e

    def _format_input(self, topic: str, position: str) -> str:
        """Format input for the model"""
        return f"topic: {topic} stance: {position}"

    def generate(self, topic: str, position: str, max_length: int = 128, num_beams: int = 4) -> str:
        """Generate an argument for a topic and position"""
        if not self.model_loaded:
            raise RuntimeError("Generation model not loaded")

        input_text = self._format_input(topic, position)

        # Tokenize
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        ).to(self.device)

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=2.5,
                length_penalty=1.0
            )

        # Decode
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text

    def batch_generate(self, items: list[dict], max_length: int = 128, num_beams: int = 4) -> list[str]:
        """Batch-generate arguments for a list of dicts with "topic" and "position" keys"""
        if not self.model_loaded:
            raise RuntimeError("Generation model not loaded")

        # Prepare inputs
        input_texts = [self._format_input(item["topic"], item["position"]) for item in items]

        # Tokenize batch
        inputs = self.tokenizer(
            input_texts,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        ).to(self.device)

        # Generate batch
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=2.5,
                length_penalty=1.0
            )

        # Decode batch
        generated_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return generated_texts


# Initialize singleton instance
generate_model_manager = GenerateModelManager()
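

# Example usage (illustrative sketch only; the import path and the model ID below are
# assumptions for demonstration, not the checkpoint this project actually deploys --
# substitute your own seq2seq model and module path):
#
#   from generate_model_manager import generate_model_manager
#
#   generate_model_manager.load_model("google/flan-t5-base")
#   argument = generate_model_manager.generate(topic="school uniforms", position="pro")
#   print(argument)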