S01Nour committed
Commit 306b243 · 1 Parent(s): 2f878ea

feat: Introduce FastAPI endpoints for single and batch text generation with Pydantic models and Hugging Face model management.

models/generate.py CHANGED
@@ -1,14 +1,51 @@
 """Pydantic models for text generation"""
 
-from pydantic import BaseModel
-from typing import Optional
+from pydantic import BaseModel, ConfigDict, Field
+from typing import Optional, List
 
 class GenerateRequest(BaseModel):
-    input_text: str
-    max_length: Optional[int] = 128
-    num_beams: Optional[int] = 4
+    """Request model for argument generation"""
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example": {
+                "topic": "Assisted suicide should be a criminal offence",
+                "position": "positive"  # "positive" or "negative"
+            }
+        }
+    )
+
+    topic: str = Field(..., min_length=5, max_length=1000,
+                       description="The debate topic or statement")
+    position: str = Field(..., min_length=5, max_length=50,
+                          description="The stance to take")
+
 
 class GenerateResponse(BaseModel):
-    input_text: str
-    generated_text: str
+    """Response model for argument generation"""
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example": {
+                "topic": "Assisted suicide should be a criminal offence",
+                "position": "positive",  # "positive" or "negative"
+                "argument": "People have the right to choose how they end their lives",
+                "timestamp": "2024-11-15T10:30:00"
+            }
+        }
+    )
+
+    topic: str
+    position: str
+    argument: str
+    timestamp: str
+
+
+class BatchGenerateRequest(BaseModel):
+    """Request model for batch argument generation"""
+    items: List[GenerateRequest]
+
+
+class BatchGenerateResponse(BaseModel):
+    """Response model for batch argument generation"""
+    results: List[GenerateResponse]
+    model_info: Optional[str] = "KPA T5 Generation Model"
     timestamp: str
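
For reference, a minimal sketch of how the reworked request model behaves under pydantic v2 validation; the import path models.generate is from this repo, everything else below is illustrative:

from pydantic import ValidationError

from models.generate import GenerateRequest

# A well-formed request passes validation.
req = GenerateRequest(
    topic="Assisted suicide should be a criminal offence",
    position="positive",
)
print(req.model_dump())

# Field constraints reject malformed input, e.g. a position shorter
# than the declared min_length=5.
try:
    GenerateRequest(topic="Some debate topic", position="pos")
except ValidationError as e:
    print(e.errors()[0]["type"])  # "string_too_short"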
routes/generate.py CHANGED
@@ -5,41 +5,75 @@ from datetime import datetime
 import logging
 
 from services import generate_model_manager
-from models.generate import GenerateRequest, GenerateResponse
+from models.generate import GenerateRequest, GenerateResponse, BatchGenerateRequest, BatchGenerateResponse
 
 router = APIRouter()
 logger = logging.getLogger(__name__)
 
 
-@router.post("/generate", response_model=GenerateResponse, tags=["Text Generation"])
-async def generate_text(request: GenerateRequest):
+@router.post("/predict", response_model=GenerateResponse, tags=["Text Generation"])
+async def generate_argument(request: GenerateRequest):
     """
-    Generate text using the T5 model
+    Generate an argument for a given topic and position
 
-    - **input_text**: The input text for generation
-    - **max_length**: Maximum length of generated text (default: 128)
-    - **num_beams**: Number of beams for beam search (default: 4)
-
-    Returns generated text
+    - **topic**: The debate topic
+    - **position**: The stance (e.g. "positive", "negative")
     """
     try:
         # Generate text
         result = generate_model_manager.generate(
-            request.input_text,
-            max_length=request.max_length,
-            num_beams=request.num_beams
+            topic=request.topic,
+            position=request.position
         )
 
         # Build response
         response = GenerateResponse(
-            input_text=request.input_text,
-            generated_text=result,
+            topic=request.topic,
+            position=request.position,
+            argument=result,
             timestamp=datetime.now().isoformat()
         )
 
-        logger.info(f"Generated text: {result}")
+        logger.info(f"Generated argument: {result[:50]}...")
         return response
 
     except Exception as e:
         logger.error(f"Generation error: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
+
+
+@router.post("/batch-predict", response_model=BatchGenerateResponse, tags=["Text Generation"])
+async def batch_generate_argument(request: BatchGenerateRequest):
+    """
+    Generate arguments for multiple topic-position pairs
+    """
+    try:
+        items_data = [{"topic": item.topic, "position": item.position} for item in request.items]
+
+        # Batch generate
+        results = generate_model_manager.batch_generate(
+            items=items_data
+        )
+
+        # Build response
+        response_items = []
+        timestamp = datetime.now().isoformat()
+
+        for i, item in enumerate(request.items):
+            response_items.append(
+                GenerateResponse(
+                    topic=item.topic,
+                    position=item.position,
+                    argument=results[i],
+                    timestamp=timestamp
+                )
+            )
+
+        return BatchGenerateResponse(
+            results=response_items,
+            timestamp=timestamp
+        )
+
+    except Exception as e:
+        logger.error(f"Batch generation error: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")
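
A hypothetical smoke test for the two new endpoints, assuming the router is included on a FastAPI app without a prefix; the main.py app module is an assumption, not part of this commit:

from fastapi.testclient import TestClient

from main import app  # hypothetical app module; adjust to the real entry point

client = TestClient(app)

# Single prediction: topic + position in, argument + timestamp out.
resp = client.post("/predict", json={
    "topic": "Assisted suicide should be a criminal offence",
    "position": "positive",
})
print(resp.status_code, resp.json()["argument"])

# Batch prediction: topic/position pairs under "items".
resp = client.post("/batch-predict", json={
    "items": [
        {"topic": "Homeschooling should be banned", "position": "positive"},
        {"topic": "Homeschooling should be banned", "position": "negative"},
    ],
})
print([r["argument"] for r in resp.json()["results"]])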
services/generate_model_manager.py CHANGED
@@ -56,11 +56,18 @@ class GenerateModelManager:
             logger.error(f"Error loading generation model: {str(e)}")
             raise RuntimeError(f"Failed to load generation model: {str(e)}")
 
-    def generate(self, input_text: str, max_length: int = 128, num_beams: int = 4) -> str:
-        """Generate text from input"""
+    def _format_input(self, topic: str, position: str) -> str:
+        """Format input for the model"""
+        # Standard format for argument generation
+        return f"topic: {topic} stance: {position}"
+
+    def generate(self, topic: str, position: str, max_length: int = 128, num_beams: int = 4) -> str:
+        """Generate argument for a topic and position"""
         if not self.model_loaded:
             raise RuntimeError("Generation model not loaded")
 
+        input_text = self._format_input(topic, position)
+
         # Tokenize
         inputs = self.tokenizer(
             input_text,
@@ -84,6 +91,37 @@ class GenerateModelManager:
 
         return generated_text
 
+    def batch_generate(self, items: list[dict], max_length: int = 128, num_beams: int = 4) -> list[str]:
+        """Batch generate arguments"""
+        if not self.model_loaded:
+            raise RuntimeError("Generation model not loaded")
+
+        # Prepare inputs
+        input_texts = [self._format_input(item["topic"], item["position"]) for item in items]
+
+        # Tokenize batch
+        inputs = self.tokenizer(
+            input_texts,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512,
+            padding=True
+        ).to(self.device)
+
+        # Generate batch
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_length=max_length,
+                num_beams=num_beams,
+                early_stopping=True
+            )
+
+        # Decode batch
+        generated_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+        return generated_texts
+
 
 # Initialize singleton instance
 generate_model_manager = GenerateModelManager()
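
A rough usage sketch for the manager changes, assuming the T5 model and tokenizer loaded at startup. It shows the prompt that _format_input builds and that batch_generate mirrors generate for a one-item batch (beam search decoding is deterministic, so the outputs should match):

from services import generate_model_manager

# The helper builds the "topic: ... stance: ..." prompt fed to T5.
prompt = generate_model_manager._format_input(
    topic="Homeschooling should be banned", position="negative"
)
print(prompt)  # topic: Homeschooling should be banned stance: negative

# Single and batch generation share the same prompt format and decoding
# settings, so a one-item batch should reproduce the single result.
single = generate_model_manager.generate(
    topic="Homeschooling should be banned", position="negative"
)
batch = generate_model_manager.batch_generate(
    items=[{"topic": "Homeschooling should be banned", "position": "negative"}]
)
assert batch[0] == single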