Yassine Mhirsi committed
Commit 9db766f · Parent(s): 9001f9e
first test
Files changed:
- .dockerignore +23 -0
- Dockerfile +32 -0
- config.py +36 -0
- main.py +88 -0
- models/__init__.py +24 -0
- models/health.py +12 -0
- models/stance.py +59 -0
- requirements.txt +8 -0
- routes/__init__.py +15 -0
- routes/health.py +20 -0
- routes/root.py +25 -0
- routes/stance.py +89 -0
- services/__init__.py +8 -0
- services/stance_model_manager.py +101 -0
.dockerignore
ADDED
@@ -0,0 +1,23 @@
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+.Python
+*.so
+*.egg
+*.egg-info
+dist
+build
+.git
+.gitignore
+.env
+.venv
+venv/
+ENV/
+env/
+*.log
+.DS_Store
+README.md
+# NOTE: models/ must not be ignored here: it is the app's Pydantic schema package, and COPY . . needs it
+*.md
+
Dockerfile
ADDED
@@ -0,0 +1,32 @@
+# Use a Python 3.10 slim image for smaller size
+FROM python:3.10-slim
+
+# Set working directory
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements first for better layer caching
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code
+COPY . .
+
+# Expose port (Hugging Face Spaces uses 7860 by default; the PORT env var takes precedence)
+EXPOSE 7860
+
+# Set environment variables
+ENV HOST=0.0.0.0
+ENV PORT=7860
+ENV RELOAD=False
+
+# Run the application
+# Hugging Face Spaces sets the PORT environment variable automatically
+CMD uvicorn main:app --host 0.0.0.0 --port ${PORT:-7860}
+
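Once an image is built from this Dockerfile and run with the port published (for example `docker build -t stance-api .` then `docker run -p 7860:7860 stance-api`; the image tag and mapping are hypothetical), a quick smoke test is to poll /health from the host until the model has loaded. A minimal sketch using only the Python standard library, assuming the default port mapping above:

# Smoke-test a running container: poll /health until the model has loaded.
# Assumes the container was started with -p 7860:7860 (hypothetical mapping).
import json
import time
import urllib.request

for attempt in range(30):
    try:
        with urllib.request.urlopen("http://localhost:7860/health") as resp:
            body = json.load(resp)
            print(f"status={body['status']} device={body['device']}")
            if body["model_loaded"]:
                break
    except OSError:
        pass  # server not accepting connections yet
    time.sleep(2)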
config.py
ADDED
@@ -0,0 +1,36 @@
+"""Configuration settings for the API"""
+
+import os
+from pathlib import Path
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()
+
+# Get project root directory
+API_DIR = Path(__file__).parent
+PROJECT_ROOT = API_DIR.parent
+
+# Hugging Face configuration
+HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
+HUGGINGFACE_MODEL_ID = os.getenv("HUGGINGFACE_MODEL_ID", "yassine-mhirsi/debertav3-stance-detection")
+
+# Stance detection model configuration
+# Use the Hugging Face model ID instead of a local path
+STANCE_MODEL_ID = HUGGINGFACE_MODEL_ID
+
+# API configuration
+API_TITLE = "NLP Project API"
+API_DESCRIPTION = "API for various NLP models including stance detection and more"
+API_VERSION = "1.0.0"
+
+# Server configuration
+HOST = os.getenv("HOST", "0.0.0.0")  # Use 0.0.0.0 for Docker/Spaces
+PORT = int(os.getenv("PORT", "7860"))  # Default 7860 for Hugging Face Spaces
+RELOAD = os.getenv("RELOAD", "False").lower() == "true"  # Set to False in production
+
+# CORS configuration
+CORS_ORIGINS = ["*"]  # In production, specify exact origins
+CORS_CREDENTIALS = True
+CORS_METHODS = ["*"]
+CORS_HEADERS = ["*"]
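Since every setting falls back through os.getenv, a deployment can override them without code changes; note that load_dotenv() leaves already-set variables alone by default, so process environment wins over the .env file. A small sketch of that behavior (the override values here are hypothetical):

# Environment variables set before config is imported take precedence over
# both the .env file and the hard-coded defaults.
import os

os.environ["PORT"] = "8080"    # hypothetical override of the 7860 default
os.environ["RELOAD"] = "true"  # parsed case-insensitively into a bool

import config  # import after the environment is prepared

print(config.PORT)    # -> 8080
print(config.RELOAD)  # -> True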
main.py
ADDED
@@ -0,0 +1,88 @@
+"""Main FastAPI application entry point"""
+
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+import uvicorn
+import logging
+
+from config import (
+    API_TITLE,
+    API_DESCRIPTION,
+    API_VERSION,
+    STANCE_MODEL_ID,
+    HUGGINGFACE_API_KEY,
+    HOST,
+    PORT,
+    RELOAD,
+    CORS_ORIGINS,
+    CORS_CREDENTIALS,
+    CORS_METHODS,
+    CORS_HEADERS,
+)
+from services import stance_model_manager
+from routes import api_router
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Load models on startup and clean up on shutdown"""
+    # Startup: load all models
+    logger.info("Loading models on startup...")
+
+    # Load stance detection model
+    try:
+        logger.info(f"Loading stance model from Hugging Face: {STANCE_MODEL_ID}")
+        stance_model_manager.load_model(STANCE_MODEL_ID, HUGGINGFACE_API_KEY)
+    except Exception as e:
+        logger.error(f"✗ Failed to load stance model: {str(e)}")
+        logger.error("⚠️ Stance detection endpoints will not work!")
+
+    logger.info("✓ API startup complete")
+
+    yield  # Application runs here
+
+    # Shutdown: cleanup (if needed)
+    # Currently no cleanup is needed, but it can be added here
+
+
+# Create FastAPI application
+app = FastAPI(
+    title=API_TITLE,
+    description=API_DESCRIPTION,
+    version=API_VERSION,
+    docs_url="/docs",
+    redoc_url="/redoc",
+    lifespan=lifespan,
+)
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=CORS_ORIGINS,
+    allow_credentials=CORS_CREDENTIALS,
+    allow_methods=CORS_METHODS,
+    allow_headers=CORS_HEADERS,
+)
+
+# Include API routes
+app.include_router(api_router)
+
+
+if __name__ == "__main__":
+    # Run the API server
+    # Access at: http://localhost:7860 (or the configured PORT)
+    # API docs at: http://localhost:7860/docs
+
+    # Start uvicorn with the settings from config.py
+    uvicorn.run(
+        "main:app",
+        host=HOST,
+        port=PORT,
+        reload=RELOAD,
+        log_level="info"
+    )
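Because model loading happens inside the lifespan context, an in-process test client that enters the app context triggers the same startup path as a real server. A sketch using FastAPI's TestClient; note TestClient depends on the httpx package, which is not in requirements.txt and would need to be installed separately:

# Exercise the app in-process; entering the TestClient context runs the
# lifespan startup (and therefore the model load).
from fastapi.testclient import TestClient
from main import app

with TestClient(app) as client:
    r = client.get("/health")
    print(r.status_code, r.json()["model_loaded"])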
models/__init__.py
ADDED
@@ -0,0 +1,24 @@
+"""Pydantic models and schemas for request/response validation"""
+
+# Import stance-related schemas
+from .stance import (
+    StanceRequest,
+    StanceResponse,
+    BatchStanceRequest,
+    BatchStanceResponse,
+)
+
+# Import health-related schemas
+from .health import (
+    HealthResponse,
+)
+
+__all__ = [
+    # Stance schemas
+    "StanceRequest",
+    "StanceResponse",
+    "BatchStanceRequest",
+    "BatchStanceResponse",
+    # Health schemas
+    "HealthResponse",
+]
models/health.py
ADDED
@@ -0,0 +1,12 @@
+"""Pydantic schemas for health check endpoints"""
+
+from pydantic import BaseModel
+
+
+class HealthResponse(BaseModel):
+    """Health check response"""
+    status: str
+    model_loaded: bool
+    device: str
+    timestamp: str
+
models/stance.py
ADDED
@@ -0,0 +1,59 @@
+"""Pydantic schemas for stance detection endpoints"""
+
+from pydantic import BaseModel, Field, ConfigDict
+from typing import List
+
+
+class StanceRequest(BaseModel):
+    """Request model for stance prediction"""
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example": {
+                "topic": "Assisted suicide should be a criminal offence",
+                "argument": "People have the right to choose how they end their lives"
+            }
+        }
+    )
+
+    topic: str = Field(..., min_length=5, max_length=500,
+                       description="The debate topic or statement")
+    argument: str = Field(..., min_length=5, max_length=1000,
+                          description="The argument text to classify")
+
+
+class StanceResponse(BaseModel):
+    """Response model for stance prediction"""
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example": {
+                "topic": "Assisted suicide should be a criminal offence",
+                "argument": "People have the right to choose how they end their lives",
+                "predicted_stance": "CON",
+                "confidence": 0.9234,
+                "probability_con": 0.9234,
+                "probability_pro": 0.0766,
+                "timestamp": "2024-11-15T10:30:00"
+            }
+        }
+    )
+
+    topic: str
+    argument: str
+    predicted_stance: str = Field(..., description="PRO or CON")
+    confidence: float = Field(..., ge=0.0, le=1.0)
+    probability_con: float
+    probability_pro: float
+    timestamp: str
+
+
+class BatchStanceRequest(BaseModel):
+    """Request model for batch predictions"""
+    items: List[StanceRequest] = Field(..., max_length=50,
+                                       description="List of topic-argument pairs (max 50)")
+
+
+class BatchStanceResponse(BaseModel):
+    """Response model for batch predictions"""
+    results: List[StanceResponse]
+    total_processed: int
+
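The Field constraints mean malformed input is rejected before it ever reaches the model; FastAPI turns the resulting ValidationError into a 422 response. A quick sketch of the same validation invoked directly:

# Valid input round-trips; input below the min_length bounds raises.
from pydantic import ValidationError
from models.stance import StanceRequest

req = StanceRequest(
    topic="Assisted suicide should be a criminal offence",
    argument="People have the right to choose how they end their lives",
)
print(req.model_dump())

try:
    StanceRequest(topic="Hi", argument="ok")  # both under the 5-char minimum
except ValidationError as exc:
    print(exc.error_count(), "validation errors")  # -> 2 validation errors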
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+fastapi==0.104.1
+uvicorn[standard]==0.24.0
+pydantic==2.5.0
+python-dotenv==1.0.0
+torch>=2.0.0
+transformers>=4.35.0
+accelerate>=0.24.0
+
routes/__init__.py
ADDED
@@ -0,0 +1,15 @@
+"""API route handlers"""
+
+from fastapi import APIRouter
+from . import root, health, stance
+
+# Create main router
+api_router = APIRouter()
+
+# Include all route modules
+api_router.include_router(root.router)
+api_router.include_router(health.router)
+api_router.include_router(stance.router)
+
+__all__ = ["api_router"]
+
routes/health.py
ADDED
@@ -0,0 +1,20 @@
+"""Health check endpoint"""
+
+from fastapi import APIRouter
+from datetime import datetime
+from models import HealthResponse
+from services import stance_model_manager
+
+router = APIRouter()
+
+
+@router.get("/health", response_model=HealthResponse, tags=["General"])
+async def health_check():
+    """Health check endpoint"""
+    return HealthResponse(
+        status="healthy" if stance_model_manager.model_loaded else "unhealthy",
+        model_loaded=stance_model_manager.model_loaded,
+        device=str(stance_model_manager.device) if stance_model_manager.device else "unknown",
+        timestamp=datetime.now().isoformat()
+    )
+
routes/root.py
ADDED
@@ -0,0 +1,25 @@
+"""Root endpoint for API information"""
+
+from fastapi import APIRouter
+
+router = APIRouter()
+
+
+@router.get("/", response_model=dict, tags=["General"])
+async def root():
+    """Root endpoint with API information"""
+    return {
+        "message": "NLP Project API",
+        "version": "1.0.0",
+        "features": {
+            "stance_detection": {
+                "predict": "/predict",
+                "batch_predict": "/batch-predict"
+            }
+        },
+        "endpoints": {
+            "health": "/health",
+            "docs": "/docs"
+        }
+    }
+
routes/stance.py
ADDED
@@ -0,0 +1,89 @@
+"""Stance detection endpoints"""
+
+from fastapi import APIRouter, HTTPException
+from datetime import datetime
+import logging
+
+from models import (
+    StanceRequest,
+    StanceResponse,
+    BatchStanceRequest,
+    BatchStanceResponse,
+)
+from services import stance_model_manager
+
+router = APIRouter()
+logger = logging.getLogger(__name__)
+
+
+@router.post("/predict", response_model=StanceResponse, tags=["Stance Detection"])
+async def predict_stance(request: StanceRequest):
+    """
+    Predict stance for a single topic-argument pair
+
+    - **topic**: The debate topic or statement (5-500 chars)
+    - **argument**: The argument to classify (5-1000 chars)
+
+    Returns the predicted stance (PRO/CON) with confidence scores
+    """
+    try:
+        # Make prediction
+        result = stance_model_manager.predict(request.topic, request.argument)
+
+        # Build response
+        response = StanceResponse(
+            topic=request.topic,
+            argument=request.argument,
+            predicted_stance=result["predicted_stance"],
+            confidence=result["confidence"],
+            probability_con=result["probability_con"],
+            probability_pro=result["probability_pro"],
+            timestamp=datetime.now().isoformat()
+        )
+
+        logger.info(f"Prediction: {result['predicted_stance']} ({result['confidence']:.4f})")
+        return response
+
+    except Exception as e:
+        logger.error(f"Prediction error: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
+
+
+@router.post("/batch-predict", response_model=BatchStanceResponse, tags=["Stance Detection"])
+async def batch_predict_stance(request: BatchStanceRequest):
+    """
+    Predict stance for multiple topic-argument pairs
+
+    - **items**: List of topic-argument pairs (max 50)
+
+    Returns predictions for all items
+    """
+    try:
+        results = []
+
+        # Process each item
+        for item in request.items:
+            result = stance_model_manager.predict(item.topic, item.argument)
+
+            response = StanceResponse(
+                topic=item.topic,
+                argument=item.argument,
+                predicted_stance=result["predicted_stance"],
+                confidence=result["confidence"],
+                probability_con=result["probability_con"],
+                probability_pro=result["probability_pro"],
+                timestamp=datetime.now().isoformat()
+            )
+            results.append(response)
+
+        logger.info(f"Batch prediction completed: {len(results)} items")
+
+        return BatchStanceResponse(
+            results=results,
+            total_processed=len(results)
+        )
+
+    except Exception as e:
+        logger.error(f"Batch prediction error: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {str(e)}")
+
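For reference, a standard-library client sketch against the /predict endpoint; it assumes a server running locally on the default port, so adjust the base URL for your deployment:

# POST a single topic-argument pair to /predict and print the result.
import json
import urllib.request

payload = {
    "topic": "Assisted suicide should be a criminal offence",
    "argument": "People have the right to choose how they end their lives",
}
req = urllib.request.Request(
    "http://localhost:7860/predict",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)
print(result["predicted_stance"], result["confidence"])

The /batch-predict endpoint takes the same pairs wrapped as {"items": [...]} (up to 50) and returns them under "results".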
services/__init__.py
ADDED
@@ -0,0 +1,8 @@
+"""Services for business logic and external integrations"""
+
+from .stance_model_manager import StanceModelManager, stance_model_manager
+
+__all__ = [
+    "StanceModelManager",
+    "stance_model_manager",
+]
services/stance_model_manager.py
ADDED
@@ -0,0 +1,101 @@
+"""Model manager for stance detection model"""
+
+import os
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class StanceModelManager:
+    """Manages stance detection model loading and predictions"""
+
+    def __init__(self):
+        self.model = None
+        self.tokenizer = None
+        self.device = None
+        self.model_loaded = False
+
+    def load_model(self, model_id: str, api_key: str = None):
+        """Load model and tokenizer from Hugging Face"""
+        if self.model_loaded:
+            logger.info("Stance model already loaded")
+            return
+
+        try:
+            logger.info(f"Loading stance model from Hugging Face: {model_id}")
+
+            # Determine device
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            logger.info(f"Using device: {self.device}")
+
+            # Use the API key as an auth token if one is provided
+            token = api_key if api_key else None
+
+            # Load tokenizer and model from Hugging Face
+            logger.info("Loading tokenizer...")
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                token=token,
+                trust_remote_code=True
+            )
+
+            logger.info("Loading model...")
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_id,
+                token=token,
+                trust_remote_code=True
+            )
+            self.model.to(self.device)
+            self.model.eval()
+
+            self.model_loaded = True
+            logger.info("✓ Stance model loaded successfully from Hugging Face!")
+
+        except Exception as e:
+            logger.error(f"Error loading stance model: {str(e)}")
+            raise RuntimeError(f"Failed to load stance model: {str(e)}")
+
+    def predict(self, topic: str, argument: str) -> dict:
+        """Make a single stance prediction"""
+        if not self.model_loaded:
+            raise RuntimeError("Stance model not loaded")
+
+        # Format input
+        text = f"Topic: {topic} [SEP] Argument: {argument}"
+
+        # Tokenize
+        inputs = self.tokenizer(
+            text,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512,
+            padding=True
+        ).to(self.device)
+
+        # Predict
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
+            predicted_class = torch.argmax(probabilities, dim=-1).item()
+
+        # Extract probabilities (class 0 = CON, class 1 = PRO)
+        prob_con = probabilities[0][0].item()
+        prob_pro = probabilities[0][1].item()
+
+        # Determine stance
+        stance = "PRO" if predicted_class == 1 else "CON"
+        confidence = probabilities[0][predicted_class].item()
+
+        return {
+            "predicted_stance": stance,
+            "confidence": confidence,
+            "probability_con": prob_con,
+            "probability_pro": prob_pro
+        }
+
+
+# Initialize singleton instance
+stance_model_manager = StanceModelManager()
+
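The singleton can also be driven outside FastAPI, e.g. for offline scoring. A minimal sketch run from the project root (the first load_model call downloads the checkpoint from the Hub):

# Use the manager directly: load once, then predict repeatedly.
from config import STANCE_MODEL_ID, HUGGINGFACE_API_KEY
from services import stance_model_manager

stance_model_manager.load_model(STANCE_MODEL_ID, HUGGINGFACE_API_KEY)
result = stance_model_manager.predict(
    "Assisted suicide should be a criminal offence",
    "People have the right to choose how they end their lives",
)
print(result["predicted_stance"], round(result["confidence"], 4))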