File size: 7,720 Bytes
03977cf 5fb4696 674469e c45f0d6 5fb4696 d43ba60 674469e 2da4544 674469e 5fb4696 03977cf d43ba60 03977cf 2da4544 5fb4696 d43ba60 9db766f 674469e 2da4544 c45f0d6 9db766f c45f0d6 d43ba60 c45f0d6 d43ba60 9db766f 674469e c45f0d6 674469e c45f0d6 674469e c45f0d6 d43ba60 c45f0d6 674469e d43ba60 c45f0d6 674469e 2da4544 d43ba60 c45f0d6 674469e d43ba60 c45f0d6 674469e 2da4544 9db766f c45f0d6 2da4544 c45f0d6 2da4544 d43ba60 9db766f c45f0d6 9db766f d43ba60 9db766f d43ba60 c45f0d6 56dc677 2da4544 c45f0d6 56dc677 2da4544 56dc677 c45f0d6 674469e 2da4544 c45f0d6 674469e 2da4544 674469e c45f0d6 674469e 2da4544 c45f0d6 674469e 2da4544 1d46e48 d43ba60 c45f0d6 2da4544 c45f0d6 674469e c45f0d6 5fb4696 c45f0d6 5fb4696 2da4544 c45f0d6 5fb4696 9db766f d43ba60 9db766f c45f0d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 |
import atexit
import logging
import shutil
import sys
from contextlib import asynccontextmanager
from pathlib import Path

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
# --- Logging ---
# Configure the root logger once for the process: INFO level, with each record
# rendered as "timestamp - logger name - level - message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# --- Add the app directory to sys.path ---
# Inserted at index 0 so the directory containing this file is searched first
# by the import statements that follow.
app_dir = Path(__file__).parent
sys.path.insert(0, str(app_dir))
# --- Config ---
from config import (
API_TITLE, API_DESCRIPTION, API_VERSION,
HUGGINGFACE_API_KEY, HUGGINGFACE_STANCE_MODEL_ID, HUGGINGFACE_LABEL_MODEL_ID,
HOST, PORT, RELOAD,
CORS_ORIGINS, CORS_METHODS, CORS_HEADERS, CORS_CREDENTIALS,
PRELOAD_MODELS_ON_STARTUP, LOAD_STANCE_MODEL, LOAD_KPA_MODEL,
GROQ_API_KEY, GROQ_STT_MODEL, GROQ_TTS_MODEL, GROQ_CHAT_MODEL
)
# --- Cleanup helper ---
def cleanup_temp_files():
    """Delete the ``temp_audio`` directory left behind by a previous run.

    The path is relative, so it is resolved against the current working
    directory. Any ``Exception`` raised while deleting is logged as a warning
    instead of propagating.
    """
    audio_scratch = Path("temp_audio")

    # Nothing to do when no scratch directory exists.
    if not audio_scratch.exists():
        return

    try:
        shutil.rmtree(audio_scratch)
        logger.info("✓ Cleaned up previous temp audio files")
    except Exception as err:
        logger.warning(f"⚠ Could not clean temp directory: {err}")


# Run once when the module is loaded, i.e. at application start
cleanup_temp_files()
# Register the cleanup to run when the interpreter exits
@atexit.register
def cleanup_on_exit():
    """Best-effort removal of the ``temp_audio`` directory.

    Registered with ``atexit`` so it runs at interpreter shutdown; it can also
    be called directly. The path is relative to the current working directory.

    Filesystem errors (directory missing, file locked, permission denied) are
    ignored because nothing useful can be done about them at shutdown. Unlike
    a bare ``except``, this does not swallow ``KeyboardInterrupt`` or
    ``SystemExit``.
    """
    # ignore_errors=True makes rmtree itself skip failures, which also covers
    # the case where the directory disappears between a check and the delete.
    shutil.rmtree(Path("temp_audio"), ignore_errors=True)
# --- Import the service singletons ---
# The model managers are optional: when either import fails, both names are
# bound to None so the rest of the module can test them for truthiness.
try:
    from services.stance_model_manager import stance_model_manager
    from services.label_model_manager import kpa_model_manager
    logger.info("✓ Model managers imported")
except ImportError as import_err:
    logger.warning(f"⚠ Could not import model managers: {import_err}")
    stance_model_manager = kpa_model_manager = None
# --- Lifespan / API startup ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the startup checks, hand control to the app, then clean up.

    Startup logs whether the Groq and Hugging Face API keys are set and, when
    ``PRELOAD_MODELS_ON_STARTUP`` is truthy, loads every enabled model through
    its manager. A model that fails to load is logged as an error and does not
    stop the startup sequence.

    After the ``yield`` (shutdown) it logs a message and calls
    ``cleanup_on_exit()``.
    """
    separator = "=" * 60

    logger.info(separator)
    logger.info("🚀 API STARTUP - Loading models and checking APIs...")
    logger.info(separator)

    # Report the API key configuration
    if GROQ_API_KEY:
        logger.info("✓ GROQ_API_KEY is configured")
    else:
        logger.warning("⚠ GROQ_API_KEY is not set. STT/TTS features may not work.")

    if HUGGINGFACE_API_KEY:
        logger.info("✓ HUGGINGFACE_API_KEY is configured")
    else:
        logger.warning("⚠ HUGGINGFACE_API_KEY is not set. Local models may not work.")

    # Preload the Hugging Face models when configured to do so
    if PRELOAD_MODELS_ON_STARTUP:
        # (enabled flag, manager, model id, success message, failure prefix)
        preload_plan = (
            (
                LOAD_STANCE_MODEL,
                stance_model_manager,
                HUGGINGFACE_STANCE_MODEL_ID,
                "✓ Stance model loaded successfully",
                "✗ Failed loading stance model",
            ),
            (
                LOAD_KPA_MODEL,
                kpa_model_manager,
                HUGGINGFACE_LABEL_MODEL_ID,
                "✓ KPA model loaded successfully",
                "✗ Failed loading KPA model",
            ),
        )
        for enabled, manager, model_id, ok_msg, fail_prefix in preload_plan:
            # Skip when disabled, when the manager import failed, or when no
            # model id is configured.
            if not (enabled and manager and model_id):
                continue
            try:
                manager.load_model(model_id, HUGGINGFACE_API_KEY)
                logger.info(ok_msg)
            except Exception as load_err:
                logger.error(f"{fail_prefix}: {load_err}")

    logger.info(separator)
    logger.info("✓ Startup complete. API ready to receive requests.")
    logger.info(f" STT Model: {GROQ_STT_MODEL}")
    logger.info(f" TTS Model: {GROQ_TTS_MODEL}")
    logger.info(f" Chat Model: {GROQ_CHAT_MODEL}")
    logger.info(separator)

    yield

    logger.info("🛑 Shutting down API...")
    # Final cleanup
    cleanup_on_exit()
# --- FastAPI app ---
# Application instance. Title, description and version come from config;
# startup and shutdown work is delegated to the `lifespan` context manager.
# The interactive docs, ReDoc page and OpenAPI schema are served at the paths
# given below.
app = FastAPI(
    title=API_TITLE,
    description=API_DESCRIPTION,
    version=API_VERSION,
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_url="/openapi.json"
)
# --- CORS ---
# Allowed origins, credentials, methods and headers are all taken from config.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=CORS_CREDENTIALS,
    allow_methods=CORS_METHODS,
    allow_headers=CORS_HEADERS,
)
# --- Routes ---
# Each router is imported and registered inside its own try/except, so a
# missing or broken route module only produces a warning and the remaining
# routers are still registered.

# STT routes
try:
    from routes.stt_routes import router as stt_router
    app.include_router(stt_router, prefix="/api/v1/stt", tags=["Speech To Text"])
    logger.info("✓ STT route loaded (Groq Whisper)")
except ImportError as e:
    logger.warning(f"⚠ STT route not found: {e}")
except Exception as e:
    logger.warning(f"⚠ Failed loading STT route: {e}")

# TTS routes
try:
    from routes.tts_routes import router as tts_router
    app.include_router(tts_router, prefix="/api/v1/tts", tags=["Text To Speech"])
    logger.info("✓ TTS route loaded (Groq PlayAI TTS)")
except ImportError as e:
    logger.warning(f"⚠ TTS route not found: {e}")
except Exception as e:
    logger.warning(f"⚠ Failed loading TTS route: {e}")

# Main API routes (registered without an extra prefix)
try:
    from routes import api_router
    app.include_router(api_router)
    logger.info("✓ Main API routes loaded")
except ImportError as e:
    logger.warning(f"⚠ Main API routes not found: {e}")
except Exception as e:
    logger.warning(f"⚠ Failed loading main API routes: {e}")

# Voice chat routes (in main.py, after the other routes)
try:
    from routes.voice_chat_routes import router as voice_chat_router
    app.include_router(voice_chat_router, tags=["Voice Chat"])
    logger.info("✓ Voice Chat route loaded")
except ImportError as e:
    logger.warning(f"⚠ Voice Chat route not found: {e}")
except Exception as e:
    logger.warning(f"⚠ Failed loading Voice Chat route: {e}")
# --- Basic routes ---
@app.get("/health", tags=["Health"])
async def health():
    """Health check endpoint"""

    def _model_state(manager):
        # "loaded" only when the manager was imported and holds a model.
        if manager and manager.model is not None:
            return "loaded"
        return "not loaded"

    # The Groq-backed features report their model id only when a key is set.
    groq_enabled = bool(GROQ_API_KEY)
    groq_models = (
        ("stt", GROQ_STT_MODEL),
        ("tts", GROQ_TTS_MODEL),
        ("chat", GROQ_CHAT_MODEL),
    )
    features = {
        feature: (model_id if groq_enabled else "disabled")
        for feature, model_id in groq_models
    }
    features["stance_model"] = _model_state(stance_model_manager)
    features["kpa_model"] = _model_state(kpa_model_manager)

    return {
        "status": "healthy",
        "service": "NLP Debater + Groq Voice",
        "features": features,
    }
@app.get("/", tags=["Root"])
async def root():
    """Root endpoint with API information"""
    # Paths a client can use to explore the API.
    endpoint_index = {
        "docs": "/docs",
        "redoc": "/redoc",
        "health": "/health",
        "stt": "/api/v1/stt/",
        "tts": "/api/v1/tts/",
    }
    # Configured Groq model identifiers, reported as-is from config.
    model_index = {
        "stt": GROQ_STT_MODEL,
        "tts": GROQ_TTS_MODEL,
        "chat": GROQ_CHAT_MODEL,
    }

    payload = {"message": "NLP Debater API with Groq Voice Support"}
    payload["version"] = API_VERSION
    payload["endpoints"] = endpoint_index
    payload["models"] = model_index
    return payload
# --- Error handlers ---
@app.exception_handler(404)
async def not_found_handler(request, exc):
    """Return a JSON 404 response that points the client at the main endpoints.

    Args:
        request: The incoming request; its URL is echoed in the message.
        exc: The exception that triggered the handler (unused).

    Returns:
        A ``JSONResponse`` with status code 404. Exception handlers must
        return a response object: a plain dict is not a valid ASGI response,
        so returning one turns every 404 into a server error.
    """
    return JSONResponse(
        status_code=404,
        content={
            "error": "Not Found",
            "message": f"The requested URL {request.url} was not found",
            "available_endpoints": {
                "GET /": "API information",
                "GET /health": "Health check",
                "POST /api/v1/stt/": "Speech to text",
                "POST /api/v1/tts/": "Text to speech",
            },
        },
    )
# --- Run server ---
# Only runs when the file is executed directly, not when it is imported.
if __name__ == "__main__":
    logger.info("="*60)
    logger.info(f"Starting server on {HOST}:{PORT}")
    logger.info(f"Reload mode: {RELOAD}")
    logger.info("="*60)
    # The app is passed as the import string "main:app" rather than the object.
    # NOTE(review): this assumes the file is named main.py - confirm.
    uvicorn.run(
        "main:app",
        host=HOST,
        port=PORT,
        reload=RELOAD,
        log_level="info"
    )