"""Pydantic schemas for key-point matching prediction endpoints

Defines the request/response models for single and batch
argument/key-point predictions, the API health check, and the
model-information endpoint.
"""

from typing import Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field


class PredictionRequest(BaseModel):
    """Request model for single key-point/argument prediction"""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "argument": "Apples are good for health",
                "key_point": "Fruits are healthy",
            }
        }
    )

    # Length bounds are enforced by pydantic: values outside them fail
    # validation before reaching any endpoint logic.
    argument: str = Field(
        ...,
        min_length=5,
        max_length=1000,
        description="The argument text to evaluate",
    )
    key_point: str = Field(
        ...,
        min_length=5,
        max_length=500,
        description="The key point used for comparison",
    )


class PredictionResponse(BaseModel):
    """Response model for single prediction"""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "prediction": 1,
                "confidence": 0.956,
                "label": "apparie",
                "probabilities": {
                    "non_apparie": 0.044,
                    "apparie": 0.956,
                },
            }
        }
    )

    # Binary class id. Bounded to 0..1 so the schema enforces the
    # documented encoding instead of accepting any integer.
    prediction: int = Field(
        ..., ge=0, le=1, description="1 = apparie, 0 = non_apparie"
    )
    confidence: float = Field(
        ..., ge=0.0, le=1.0, description="Confidence score of the prediction"
    )
    label: str = Field(..., description="apparie or non_apparie")
    # Keyed by class label ("non_apparie" / "apparie" in the example).
    probabilities: Dict[str, float] = Field(
        ..., description="Dictionary of class probabilities"
    )


class BatchPredictionRequest(BaseModel):
    """Request model for batch predictions"""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "pairs": [
                    {
                        "argument": "Apples are good for health",
                        "key_point": "Fruits are healthy",
                    },
                    {
                        "argument": "Dogs make great pets",
                        "key_point": "Cats are better than dogs",
                    },
                    {
                        "argument": "Exercise is important",
                        "key_point": "Sports are good for you",
                    },
                    {
                        "argument": "Reading books is fun",
                        "key_point": "We should build more roads",
                    },
                    {
                        "argument": "Water is essential for life",
                        "key_point": "Drinking water is important",
                    },
                ]
            }
        }
    )

    # Each pair is validated with PredictionRequest's length rules; the
    # list itself is capped at 100 items (pydantic v2 `max_length`).
    pairs: List[PredictionRequest] = Field(
        ..., max_length=100, description="List of argument-keypoint pairs (max 100)"
    )


class BatchPredictionResponse(BaseModel):
    """Response model for batch key-point predictions"""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "predictions": [
                    {
                        "prediction": 1,
                        "confidence": 0.956,
                        "label": "apparie",
                        "probabilities": {
                            "non_apparie": 0.044,
                            "apparie": 0.956,
                        },
                    },
                    {
                        "prediction": 0,
                        "confidence": 0.892,
                        "label": "non_apparie",
                        "probabilities": {
                            "non_apparie": 0.892,
                            "apparie": 0.108,
                        },
                    },
                    {
                        "prediction": 1,
                        "confidence": 0.934,
                        "label": "apparie",
                        "probabilities": {
                            "non_apparie": 0.066,
                            "apparie": 0.934,
                        },
                    },
                    {
                        "prediction": 0,
                        "confidence": 0.995,
                        "label": "non_apparie",
                        "probabilities": {
                            "non_apparie": 0.995,
                            "apparie": 0.005,
                        },
                    },
                    {
                        "prediction": 1,
                        "confidence": 0.967,
                        "label": "apparie",
                        "probabilities": {
                            "non_apparie": 0.033,
                            "apparie": 0.967,
                        },
                    },
                ],
                "total_processed": 5,
                "summary": {
                    "total_apparie": 3,
                    "total_non_apparie": 2,
                    "average_confidence": 0.9488,
                    "successful_predictions": 5,
                    "failed_predictions": 0,
                },
            }
        }
    )

    predictions: List[PredictionResponse]
    # A processed-item count can never be negative.
    total_processed: int = Field(..., ge=0, description="Number of processed items")
    # The summary mixes integer counts (e.g. "total_apparie") with float
    # statistics (e.g. "average_confidence"). A plain Dict[str, float]
    # would coerce the counts to floats (3 -> 3.0). `int` is listed first
    # so that, in pydantic's union matching, integer inputs stay int and
    # float inputs fall through to `float`.
    summary: Dict[str, Union[int, float]] = Field(
        default_factory=dict,
        description="Summary statistics of the batch prediction",
    )


class HealthResponse(BaseModel):
    """Health check model for the API"""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "status": "healthy",
                "model_loaded": True,
                "device": "cpu",
                "model_name": "NLP-Debater-Project/distilBert-keypoint-matching",
                "timestamp": "2024-01-01T12:00:00Z",
            }
        }
    )

    status: str = Field(..., description="API health status")
    model_loaded: bool = Field(..., description="Whether the model is loaded")
    device: str = Field(..., description="Device used for inference (cpu/cuda)")
    # Optional: may be None (for example, when no model name is available).
    model_name: Optional[str] = Field(None, description="Name of the loaded model")
    # Kept as a plain string; the example uses an ISO-8601 UTC form.
    timestamp: str = Field(..., description="Timestamp of the health check")


class ModelInfoResponse(BaseModel):
    """Detailed model information response"""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "model_name": "NLP-Debater-Project/distilBert-keypoint-matching",
                "device": "cpu",
                "max_length": 256,
                "num_labels": 2,
                "loaded": True,
                "performance": {
                    "accuracy": 0.9285,
                    "f1_score": 0.8836,
                    "f1_apparie": 0.8113,
                    "f1_non_apparie": 0.9559,
                },
                "description": "DistilBERT model for key point - argument semantic matching",
            }
        }
    )

    model_name: str
    device: str
    max_length: int
    num_labels: int
    loaded: bool
    # Metric name -> score (see the example for the expected keys).
    performance: Dict[str, float] = Field(
        ..., description="Model performance metrics"
    )
    description: str = Field(..., description="Model description")