"""Model manager for generation model"""

import logging
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

logger = logging.getLogger(__name__)


class GenerateModelManager:
    """Manages generation model loading and predictions"""
    
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
    
    def load_model(self, model_id: str, api_key: Optional[str] = None):
        """Load model and tokenizer from Hugging Face"""
        if self.model_loaded:
            logger.info("Generation model already loaded")
            return
        
        try:
            logger.info(f"Loading generation model from Hugging Face: {model_id}")
            
            # Determine device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")
            
            # Prepare token for authentication if API key is provided
            token = api_key or None
            
            # Load tokenizer and model from Hugging Face
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True
            )
            
            logger.info("Loading model...")
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True
            )
            self.model.to(self.device)
            self.model.eval()
            
            self.model_loaded = True
            logger.info("✓ Generation model loaded successfully from Hugging Face!")
            
        except Exception as e:
            logger.error(f"Error loading generation model: {e}")
            raise RuntimeError(f"Failed to load generation model: {e}") from e
    
    def _format_input(self, topic: str, position: str) -> str:
        """Format input for the model"""
        return f"topic: {topic} stance: {position}"

    def generate(self, topic: str, position: str, max_length: int = 128, num_beams: int = 4) -> str:
        """Generate argument for a topic and position"""
        if not self.model_loaded:
            raise RuntimeError("Generation model not loaded")
        
        input_text = self._format_input(topic, position)
        
        # Tokenize
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        ).to(self.device)
        
        # Generate with beam search; n-gram blocking and a repetition penalty
        # discourage degenerate repeated phrases
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=2.5,
                length_penalty=1.0
            )
            
        # Decode
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        return generated_text

    def batch_generate(self, items: list[dict[str, str]], max_length: int = 128, num_beams: int = 4) -> list[str]:
        """Generate arguments for a batch of {"topic": ..., "position": ...} dicts"""
        if not self.model_loaded:
            raise RuntimeError("Generation model not loaded")
        
        # Prepare inputs
        input_texts = [self._format_input(item["topic"], item["position"]) for item in items]
        
        # Tokenize batch
        inputs = self.tokenizer(
            input_texts,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        ).to(self.device)
        
        # Generate batch
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=2.5,
                length_penalty=1.0
            )
            
        # Decode batch
        generated_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        
        return generated_texts


# Initialize singleton instance
generate_model_manager = GenerateModelManager()
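
# Example usage (a minimal sketch, not part of the module's API; the model id
# "google/flan-t5-base" below is an illustrative placeholder, not necessarily
# the model this project actually deploys):
#
#   generate_model_manager.load_model("google/flan-t5-base")
#   print(generate_model_manager.generate("school uniforms", "pro"))
#   print(generate_model_manager.batch_generate([
#       {"topic": "school uniforms", "position": "pro"},
#       {"topic": "school uniforms", "position": "con"},
#   ]))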