S01Nour
feat: Introduce FastAPI endpoints for single and batch text generation with Pydantic models and Hugging Face model management.
306b243
raw
history blame
2.52 kB
"""Generation specific endpoints"""
from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging
from services import generate_model_manager
from models.generate import GenerateRequest, GenerateResponse, BatchGenerateRequest, BatchGenerateResponse
# Logger named after this module, so logging configuration can target it.
logger = logging.getLogger(__name__)
# Router on which the text-generation endpoints in this module are registered.
router = APIRouter()
@router.post("/predict", response_model=GenerateResponse, tags=["Text Generation"])
async def generate_argument(request: GenerateRequest):
    """
    Generate an argument for a given topic and position.

    - **topic**: The debate topic
    - **position**: The stance (e.g. "positive", "negative")

    Returns a ``GenerateResponse`` echoing the request's topic and position,
    with the generated ``argument`` and an ISO-8601 ``timestamp`` taken from
    ``datetime.now()`` (naive local time).

    Raises:
        HTTPException: 500 when generation or response construction fails.
            An ``HTTPException`` raised by the model manager itself is
            re-raised unchanged instead of being masked as a 500.
    """
    try:
        # NOTE(review): generate() is a synchronous call inside an async
        # endpoint, so it blocks the event loop for the whole generation.
        # Consider run_in_threadpool / a plain `def` endpoint once the model
        # manager is confirmed safe to call from worker threads.
        result = generate_model_manager.generate(
            topic=request.topic,
            position=request.position,
        )
        response = GenerateResponse(
            topic=request.topic,
            position=request.position,
            argument=result,
            timestamp=datetime.now().isoformat(),
        )
    except HTTPException:
        # Deliberate HTTP errors already carry the intended status code.
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e

    # Logged after the try so a logging/slicing issue can never turn a
    # successful generation into a 500; str() guards non-string results.
    logger.info("Generated argument: %s...", str(result)[:50])
    return response
@router.post("/batch-predict", response_model=BatchGenerateResponse, tags=["Text Generation"])
async def batch_generate_argument(request: BatchGenerateRequest):
    """
    Generate arguments for multiple topic-position pairs.

    Every entry of ``request.items`` is passed to the model manager in a
    single ``batch_generate`` call as a ``{"topic": ..., "position": ...}``
    dict. The returned arguments are paired with the request items by
    position, so results come back in request order. All results share one
    ISO-8601 ``timestamp`` from ``datetime.now()`` (naive local time).

    Raises:
        HTTPException: 500 when generation fails, or when the model manager
            returns a different number of results than items requested.
            An ``HTTPException`` raised by the model manager itself is
            re-raised unchanged.
    """
    try:
        items_data = [{"topic": item.topic, "position": item.position} for item in request.items]
        # Skip the model call entirely for an empty batch rather than handing
        # the manager an empty list it may not support.
        # NOTE(review): like the single endpoint, this synchronous call
        # blocks the event loop while it runs.
        results = generate_model_manager.batch_generate(items=items_data) if items_data else []

        # Results are matched to items by position. A length mismatch would
        # otherwise raise an opaque IndexError (too few) or silently drop
        # extra outputs (too many), so reject it explicitly.
        if len(results) != len(request.items):
            raise ValueError(
                f"expected {len(request.items)} results, got {len(results)}"
            )

        timestamp = datetime.now().isoformat()
        response_items = [
            GenerateResponse(
                topic=item.topic,
                position=item.position,
                argument=argument,
                timestamp=timestamp,
            )
            for item, argument in zip(request.items, results)
        ]
        response = BatchGenerateResponse(
            results=response_items,
            timestamp=timestamp,
        )
    except HTTPException:
        # Deliberate HTTP errors already carry the intended status code.
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Batch generation error: %s", e)
        raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}") from e

    logger.info("Generated %d arguments in batch", len(response_items))
    return response