S01Nour
feat: Introduce FastAPI endpoints for single and batch text generation with Pydantic models and Hugging Face model management.
306b243
| """Generation specific endpoints""" | |
| from fastapi import APIRouter, HTTPException | |
| from datetime import datetime | |
| import logging | |
| from services import generate_model_manager | |
| from models.generate import GenerateRequest, GenerateResponse, BatchGenerateRequest, BatchGenerateResponse | |
# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
# Router object for this module's generation endpoints; how it is mounted
# on the application is not visible here.
router = APIRouter()
# NOTE(review): no `@router...` route decorator is visible on this function
# in this view, and `router` is not otherwise used here -- confirm the
# endpoint is actually registered.
# NOTE(review): `generate_model_manager.generate` is called without `await`,
# so if it is a synchronous model call it runs on (and blocks) the event loop
# for the whole generation -- consider offloading it once the manager's
# thread-safety is confirmed.
async def generate_argument(request: GenerateRequest):
    """
    Generate an argument for a given topic and position

    - **topic**: The debate topic
    - **position**: The stance (e.g. "positive", "negative")

    Returns the topic and position echoed back together with the generated
    `argument` and an ISO-8601 `timestamp`. Any failure is reported as
    HTTP 500 with detail `"Generation failed: <error>"`.
    """
    try:
        result = generate_model_manager.generate(
            topic=request.topic,
            position=request.position,
        )
        response = GenerateResponse(
            topic=request.topic,
            position=request.position,
            argument=result,
            # datetime.now() without a tz argument yields naive local time.
            timestamp=datetime.now().isoformat(),
        )
        logger.info("Generated argument: %s...", result[:50])
        return response
    except Exception as e:
        # logger.exception also records the traceback; `from e` keeps the
        # original error attached as __cause__ for debugging.
        logger.exception("Generation error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Generation failed: {e}"
        ) from e
async def batch_generate_argument(request: BatchGenerateRequest):
    """
    Generate arguments for multiple topic-position pairs

    Every returned item, and the batch itself, shares one ISO-8601
    `timestamp`. Any failure -- including the model manager returning a
    different number of results than items were sent -- is reported as
    HTTP 500 with detail `"Batch generation failed: <error>"`.
    """
    try:
        items_data = [
            {"topic": item.topic, "position": item.position}
            for item in request.items
        ]
        results = generate_model_manager.batch_generate(items=items_data)

        # Results are paired with request items by position. Without this
        # check, too few results surface as an opaque IndexError and extra
        # results are silently dropped. An empty request is left alone so it
        # still yields an empty batch, whatever batch_generate returns.
        expected = len(request.items)
        if expected and len(results) != expected:
            raise ValueError(
                f"batch_generate returned {len(results)} results "
                f"for {expected} items"
            )

        # datetime.now() without a tz argument yields naive local time.
        timestamp = datetime.now().isoformat()
        response_items = [
            GenerateResponse(
                topic=item.topic,
                position=item.position,
                argument=argument,
                timestamp=timestamp,
            )
            for item, argument in zip(request.items, results)
        ]
        return BatchGenerateResponse(
            results=response_items,
            timestamp=timestamp,
        )
    except Exception as e:
        # logger.exception also records the traceback; `from e` keeps the
        # original error attached as __cause__ for debugging.
        logger.exception("Batch generation error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Batch generation failed: {e}"
        ) from e