Yassine Mhirsi committed on
Commit
9db766f
·
1 Parent(s): 9001f9e

first test

Browse files
.dockerignore ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__pycache__
*.pyc
*.pyo
*.pyd
.Python
*.so
*.egg
*.egg-info
dist
build
.git
.gitignore
.env
.venv
venv/
ENV/
env/
*.log
.DS_Store
README.md
# NOTE: models/ must NOT be ignored. It is the Python package that holds the
# Pydantic schemas imported by the route modules ("from models import ...").
# Excluding it from the build context makes "COPY . ." produce an image that
# fails at import time. The model weights themselves are loaded from the
# Hugging Face Hub at startup, not from this directory.
*.md
Dockerfile ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Use Python 3.10 slim image for smaller size
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install system dependencies.
# --no-install-recommends keeps optional packages out of the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose port (7860 is the Hugging Face Spaces default; PORT can override it)
EXPOSE 7860

# Set environment variables
ENV HOST=0.0.0.0
ENV PORT=7860
ENV RELOAD=False

# Run the application.
# Exec-form CMD plus `exec` makes uvicorn replace the shell as PID 1, so it
# receives SIGTERM directly and can shut down gracefully. The shell is still
# needed to expand ${PORT:-7860} (falls back to 7860 when PORT is unset).
CMD ["sh", "-c", "exec uvicorn main:app --host 0.0.0.0 --port ${PORT:-7860}"]
config.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration settings for the API"""
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from dotenv import load_dotenv
6
+
7
+ # Load environment variables from .env file
8
+ load_dotenv()
9
+
# Get project root directory
API_DIR = Path(__file__).parent
PROJECT_ROOT = API_DIR.parent

# Hugging Face configuration
# The API key is optional; an empty string means "no authentication token".
HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
HUGGINGFACE_MODEL_ID = os.getenv("HUGGINGFACE_MODEL_ID", "yassine-mhirsi/debertav3-stance-detection")

# Stance detection model configuration
# Use Hugging Face model ID instead of local path
STANCE_MODEL_ID = HUGGINGFACE_MODEL_ID

# API configuration
API_TITLE = "NLP Project API"
API_DESCRIPTION = "API for various NLP models including stance detection and more"
API_VERSION = "1.0.0"

# Server configuration
HOST = os.getenv("HOST", "0.0.0.0")  # Use 0.0.0.0 for Docker/Spaces
PORT = int(os.getenv("PORT", "7860"))  # Default 7860 for Hugging Face Spaces
RELOAD = os.getenv("RELOAD", "False").lower() == "true"  # Set to False in production

# CORS configuration
# CORS_ORIGINS may be supplied as a comma-separated environment variable
# (e.g. "https://a.example,https://b.example"); it defaults to the wildcard.
CORS_ORIGINS = [
    origin.strip()
    for origin in (os.getenv("CORS_ORIGINS") or "*").split(",")
    if origin.strip()
] or ["*"]
# Credentialed requests must not be combined with a wildcard origin, so
# credentials are only honoured when explicitly requested via
# CORS_CREDENTIALS=true AND the origin list is explicit.
CORS_CREDENTIALS = (
    os.getenv("CORS_CREDENTIALS", "False").lower() == "true"
    and "*" not in CORS_ORIGINS
)
CORS_METHODS = ["*"]
CORS_HEADERS = ["*"]
main.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main FastAPI application entry point"""
2
+
3
+ from contextlib import asynccontextmanager
4
+ from fastapi import FastAPI
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ import uvicorn
7
+ import logging
8
+
9
+ from config import (
10
+ API_TITLE,
11
+ API_DESCRIPTION,
12
+ API_VERSION,
13
+ STANCE_MODEL_ID,
14
+ HUGGINGFACE_API_KEY,
15
+ HOST,
16
+ PORT,
17
+ RELOAD,
18
+ CORS_ORIGINS,
19
+ CORS_CREDENTIALS,
20
+ CORS_METHODS,
21
+ CORS_HEADERS,
22
+ )
23
+ from services import stance_model_manager
24
+ from routes import api_router
25
+
26
+ # Configure logging
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup and cleanup on shutdown.

    A failure to load the stance model is logged but does not abort startup,
    so the rest of the API stays reachable.
    """
    # Startup: Load all models
    logger.info("Loading models on startup...")

    # Load stance detection model
    try:
        logger.info("Loading stance model from Hugging Face: %s", STANCE_MODEL_ID)
        stance_model_manager.load_model(STANCE_MODEL_ID, HUGGINGFACE_API_KEY)
    except Exception as e:
        # Deliberate best effort: keep the server up, but record the full
        # traceback so the cause of the failed load can be diagnosed.
        logger.exception("✗ Failed to load stance model: %s", e)
        logger.error("⚠️ Stance detection endpoints will not work!")

    logger.info("✓ API startup complete")

    yield  # Application runs here

    # Shutdown: Cleanup (if needed)
    # Currently no cleanup needed, but you can add it here if necessary
51
+
52
+
# Create FastAPI application
# Title/description/version come from config; `lifespan` runs the model
# loading before the app starts serving requests.
app = FastAPI(
    title=API_TITLE,
    description=API_DESCRIPTION,
    version=API_VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# Add CORS middleware
# All CORS settings are taken from config so they can be changed in one place.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=CORS_CREDENTIALS,
    allow_methods=CORS_METHODS,
    allow_headers=CORS_HEADERS,
)

# Include API routes
# `api_router` aggregates the routers exported by the `routes` package.
app.include_router(api_router)
74
+
75
+
if __name__ == "__main__":
    # Run the API server directly (e.g. `python main.py`).
    # Host, port and reload mode are taken from config (HOST, PORT, RELOAD).
    # With the default PORT this is reachable at: http://localhost:7860
    # API docs at: http://localhost:7860/docs
    uvicorn.run(
        "main:app",
        host=HOST,
        port=PORT,
        reload=RELOAD,
        log_level="info"
    )
models/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models and schemas for request/response validation"""
2
+
3
+ # Import stance-related schemas
4
+ from .stance import (
5
+ StanceRequest,
6
+ StanceResponse,
7
+ BatchStanceRequest,
8
+ BatchStanceResponse,
9
+ )
10
+
11
+ # Import health-related schemas
12
+ from .health import (
13
+ HealthResponse,
14
+ )
15
+
16
+ __all__ = [
17
+ # Stance schemas
18
+ "StanceRequest",
19
+ "StanceResponse",
20
+ "BatchStanceRequest",
21
+ "BatchStanceResponse",
22
+ # Health schemas
23
+ "HealthResponse",
24
+ ]
models/health.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic schemas for health check endpoints"""
2
+
3
+ from pydantic import BaseModel
4
+
5
+
class HealthResponse(BaseModel):
    """Health check response"""
    # Overall service state as a free-form string
    status: str
    # Whether the stance model has been loaded
    model_loaded: bool
    # String form of the device the model runs on
    device: str
    # Time the response was produced, as a string
    timestamp: str
12
+
models/stance.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic schemas for stance detection endpoints"""
2
+
3
+ from pydantic import BaseModel, Field, ConfigDict
4
+ from typing import List
5
+
6
+
class StanceRequest(BaseModel):
    """Request model for stance prediction"""
    # Sample payload attached to the JSON schema generated for this model
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "topic": "Assisted suicide should be a criminal offence",
                "argument": "People have the right to choose how they end their lives"
            }
        }
    )

    # Both fields are required (`...`) and length-bounded by validation:
    # topic 5-500 characters, argument 5-1000 characters.
    topic: str = Field(..., min_length=5, max_length=500,
                      description="The debate topic or statement")
    argument: str = Field(..., min_length=5, max_length=1000,
                         description="The argument text to classify")
22
+
23
+
class StanceResponse(BaseModel):
    """Response model for stance prediction"""
    # Sample payload attached to the JSON schema generated for this model
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "topic": "Assisted suicide should be a criminal offence",
                "argument": "People have the right to choose how they end their lives",
                "predicted_stance": "CON",
                "confidence": 0.9234,
                "probability_con": 0.9234,
                "probability_pro": 0.0766,
                "timestamp": "2024-11-15T10:30:00"
            }
        }
    )

    # Topic and argument fields of the response
    topic: str
    argument: str
    # Predicted label; described as "PRO" or "CON" but not enforced by the type
    predicted_stance: str = Field(..., description="PRO or CON")
    # Validated to lie in [0.0, 1.0]
    confidence: float = Field(..., ge=0.0, le=1.0)
    # Per-class probabilities (no range validation is applied here)
    probability_con: float
    probability_pro: float
    # Response timestamp as a string
    timestamp: str
47
+
48
+
class BatchStanceRequest(BaseModel):
    """Request model for batch predictions"""
    # Required list of single-prediction requests, capped at 50 entries;
    # each entry is validated with the StanceRequest rules.
    items: List[StanceRequest] = Field(..., max_length=50,
                                      description="List of topic-argument pairs (max 50)")
53
+
54
+
class BatchStanceResponse(BaseModel):
    """Response model for batch predictions"""
    # List of per-item prediction results
    results: List[StanceResponse]
    # Integer count of processed items
    total_processed: int
59
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ pydantic==2.5.0
4
+ python-dotenv==1.0.0
5
+ torch>=2.0.0
6
+ transformers>=4.35.0
7
+ accelerate>=0.24.0
8
+
routes/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """API route handlers"""
2
+
3
+ from fastapi import APIRouter
4
+ from . import root, health, stance
5
+
# Main router that aggregates every route module of this package
api_router = APIRouter()

# Register the sub-routers in a fixed order: root, health, stance
for _route_module in (root, health, stance):
    api_router.include_router(_route_module.router)
13
+
14
+ __all__ = ["api_router"]
15
+
routes/health.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Health check endpoint"""
2
+
3
+ from fastapi import APIRouter
4
+ from datetime import datetime
5
+ from models import HealthResponse
6
+ from services import stance_model_manager
7
+
8
+ router = APIRouter()
9
+
10
+
11
+ @router.get("/health", response_model=HealthResponse, tags=["General"])
12
+ async def health_check():
13
+ """Health check endpoint"""
14
+ return HealthResponse(
15
+ status="healthy" if stance_model_manager.model_loaded else "unhealthy",
16
+ model_loaded=stance_model_manager.model_loaded,
17
+ device=str(stance_model_manager.device) if stance_model_manager.device else "unknown",
18
+ timestamp=datetime.now().isoformat()
19
+ )
20
+
routes/root.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Root endpoint for API information"""
2
+
3
+ from fastapi import APIRouter
4
+
5
+ router = APIRouter()
6
+
7
+
8
+ @router.get("/", response_model=dict, tags=["General"])
9
+ async def root():
10
+ """Root endpoint with API information"""
11
+ return {
12
+ "message": "NLP Project API",
13
+ "version": "1.0.0",
14
+ "features": {
15
+ "stance_detection": {
16
+ "predict": "/predict",
17
+ "batch_predict": "/batch-predict"
18
+ }
19
+ },
20
+ "endpoints": {
21
+ "health": "/health",
22
+ "docs": "/docs"
23
+ }
24
+ }
25
+
routes/stance.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stance detection endpoints"""
2
+
3
+ from fastapi import APIRouter, HTTPException
4
+ from datetime import datetime
5
+ import logging
6
+
7
+ from models import (
8
+ StanceRequest,
9
+ StanceResponse,
10
+ BatchStanceRequest,
11
+ BatchStanceResponse,
12
+ )
13
+ from services import stance_model_manager
14
+
15
+ router = APIRouter()
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ @router.post("/predict", response_model=StanceResponse, tags=["Stance Detection"])
20
+ async def predict_stance(request: StanceRequest):
21
+ """
22
+ Predict stance for a single topic-argument pair
23
+
24
+ - **topic**: The debate topic or statement (5-500 chars)
25
+ - **argument**: The argument to classify (5-1000 chars)
26
+
27
+ Returns predicted stance (PRO/CON) with confidence scores
28
+ """
29
+ try:
30
+ # Make prediction
31
+ result = stance_model_manager.predict(request.topic, request.argument)
32
+
33
+ # Build response
34
+ response = StanceResponse(
35
+ topic=request.topic,
36
+ argument=request.argument,
37
+ predicted_stance=result["predicted_stance"],
38
+ confidence=result["confidence"],
39
+ probability_con=result["probability_con"],
40
+ probability_pro=result["probability_pro"],
41
+ timestamp=datetime.now().isoformat()
42
+ )
43
+
44
+ logger.info(f"Prediction: {result['predicted_stance']} ({result['confidence']:.4f})")
45
+ return response
46
+
47
+ except Exception as e:
48
+ logger.error(f"Prediction error: {str(e)}")
49
+ raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
50
+
51
+
52
+ @router.post("/batch-predict", response_model=BatchStanceResponse, tags=["Stance Detection"])
53
+ async def batch_predict_stance(request: BatchStanceRequest):
54
+ """
55
+ Predict stance for multiple topic-argument pairs
56
+
57
+ - **items**: List of topic-argument pairs (max 50)
58
+
59
+ Returns predictions for all items
60
+ """
61
+ try:
62
+ results = []
63
+
64
+ # Process each item
65
+ for item in request.items:
66
+ result = stance_model_manager.predict(item.topic, item.argument)
67
+
68
+ response = StanceResponse(
69
+ topic=item.topic,
70
+ argument=item.argument,
71
+ predicted_stance=result["predicted_stance"],
72
+ confidence=result["confidence"],
73
+ probability_con=result["probability_con"],
74
+ probability_pro=result["probability_pro"],
75
+ timestamp=datetime.now().isoformat()
76
+ )
77
+ results.append(response)
78
+
79
+ logger.info(f"Batch prediction completed: {len(results)} items")
80
+
81
+ return BatchStanceResponse(
82
+ results=results,
83
+ total_processed=len(results)
84
+ )
85
+
86
+ except Exception as e:
87
+ logger.error(f"Batch prediction error: {str(e)}")
88
+ raise HTTPException(status_code=500, detail=f"Batch prediction failed: {str(e)}")
89
+
services/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """Services for business logic and external integrations"""
2
+
3
+ from .stance_model_manager import StanceModelManager, stance_model_manager
4
+
5
+ __all__ = [
6
+ "StanceModelManager",
7
+ "stance_model_manager",
8
+ ]
services/stance_model_manager.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model manager for stance detection model"""
2
+
3
+ import os
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
class StanceModelManager:
    """Manages stance detection model loading and predictions"""

    def __init__(self):
        # Everything stays unset until load_model() succeeds.
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, model_id: str, api_key: str = None,
                   trust_remote_code: bool = True):
        """Load model and tokenizer from Hugging Face.

        Args:
            model_id: Hugging Face model identifier to load.
            api_key: Optional access token; empty or None means unauthenticated.
            trust_remote_code: Forwarded to ``from_pretrained``. Defaults to
                True to preserve the existing behaviour.

        Raises:
            RuntimeError: If the tokenizer or model cannot be loaded.
        """
        if self.model_loaded:
            logger.info("Stance model already loaded")
            return

        try:
            logger.info("Loading stance model from Hugging Face: %s", model_id)

            # Determine device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", self.device)

            # An empty API key is treated the same as no token at all.
            token = api_key or None

            # SECURITY NOTE(review): trust_remote_code=True lets the model
            # repository execute its own Python code on load. Pass False for
            # repositories that only use built-in architectures.
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=trust_remote_code
            )

            logger.info("Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=trust_remote_code
            )
            self.model.to(self.device)
            self.model.eval()

            self.model_loaded = True
            logger.info("✓ Stance model loaded successfully from Hugging Face!")

        except Exception as e:
            # Drop any half-initialised objects so a failed load does not keep
            # a tokenizer or partially set-up model alive.
            self.model = None
            self.tokenizer = None
            logger.error("Error loading stance model: %s", e)
            raise RuntimeError(f"Failed to load stance model: {str(e)}") from e

    def predict(self, topic: str, argument: str) -> dict:
        """Make a single stance prediction.

        Returns:
            A dict with ``predicted_stance`` ("PRO"/"CON"), ``confidence`` and
            the two class probabilities ``probability_con``/``probability_pro``.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self.model_loaded:
            raise RuntimeError("Stance model not loaded")

        # Format input
        text = f"Topic: {topic} [SEP] Argument: {argument}"

        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        ).to(self.device)

        # Predict
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()

        # Only one text is submitted, so row 0 holds its class probabilities.
        # NOTE(review): assumes label index 0 = CON and 1 = PRO -- confirm
        # against the model's label mapping.
        row = probabilities[0]
        prob_con = row[0].item()
        prob_pro = row[1].item()

        # Determine stance
        stance = "PRO" if predicted_class == 1 else "CON"
        confidence = row[predicted_class].item()

        return {
            "predicted_stance": stance,
            "confidence": confidence,
            "probability_con": prob_con,
            "probability_pro": prob_pro
        }


# Initialize singleton instance
stance_model_manager = StanceModelManager()
101
+