Yassine Mhirsi committed on
Commit
2380f6f
·
1 Parent(s): 682062a

feat: Add topic-related schemas and API routes for topic management, along with LangChain dependencies in requirements.

Browse files
models/__init__.py CHANGED
@@ -28,6 +28,14 @@ from .generate import (
28
  GenerateResponse,
29
  )
30
 
 
 
 
 
 
 
 
 
31
  # Import MCP-related schemas
32
  from .mcp_models import (
33
  ToolCallRequest,
@@ -60,6 +68,11 @@ __all__ = [
60
  # Generate schemas
61
  "GenerateRequest",
62
  "GenerateResponse",
 
 
 
 
 
63
  # MCP schemas
64
  "ToolCallRequest",
65
  "ToolCallResponse",
 
28
  GenerateResponse,
29
  )
30
 
31
+ # Import topic-related schemas
32
+ from .topic import (
33
+ TopicRequest,
34
+ TopicResponse,
35
+ BatchTopicRequest,
36
+ BatchTopicResponse,
37
+ )
38
+
39
  # Import MCP-related schemas
40
  from .mcp_models import (
41
  ToolCallRequest,
 
68
  # Generate schemas
69
  "GenerateRequest",
70
  "GenerateResponse",
71
+ # Topic schemas
72
+ "TopicRequest",
73
+ "TopicResponse",
74
+ "BatchTopicRequest",
75
+ "BatchTopicResponse",
76
  # MCP schemas
77
  "ToolCallRequest",
78
  "ToolCallResponse",
models/topic.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for topic extraction endpoints"""
2
+
3
+ from pydantic import BaseModel, Field, ConfigDict
4
+ from typing import List, Optional
5
+
6
+
7
class TopicRequest(BaseModel):
    """Input payload for single-text topic extraction."""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "text": "Social media companies must NOT be allowed to track people across websites."
            }
        }
    )

    # Length bounds keep prompts non-trivial while staying within model limits.
    text: str = Field(
        ...,
        min_length=5,
        max_length=5000,
        description="The text/argument to extract topic from",
    )
21
+
22
+
23
class TopicResponse(BaseModel):
    """Result of extracting a topic from a single input text."""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "text": "Social media companies must NOT be allowed to track people across websites.",
                "topic": "social media tracking and cross-website user privacy",
                "timestamp": "2024-01-01T12:00:00Z",
            }
        }
    )

    # Echo of the input, the extracted topic, and when extraction happened.
    text: str = Field(..., description="The original input text")
    topic: str = Field(..., description="The extracted topic")
    timestamp: str = Field(..., description="Timestamp of the extraction")
38
+
39
+
40
class BatchTopicRequest(BaseModel):
    """Input payload for extracting topics from several texts at once."""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "texts": [
                    "Social media companies must NOT be allowed to track people across websites.",
                    "I don't think universal basic income is a good idea — it'll disincentivize work.",
                    "We must invest in renewable energy to combat climate change.",
                ]
            }
        }
    )

    # Between 1 and 50 texts per request to bound LLM usage per call.
    texts: List[str] = Field(
        ...,
        min_length=1,
        max_length=50,
        description="List of texts to extract topics from (max 50)",
    )
58
+
59
+
60
class BatchTopicResponse(BaseModel):
    """Aggregated results for a batch topic-extraction request."""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "results": [
                    {
                        "text": "Social media companies must NOT be allowed to track people across websites.",
                        "topic": "social media tracking and cross-website user privacy",
                        "timestamp": "2024-01-01T12:00:00Z",
                    },
                    {
                        "text": "I don't think universal basic income is a good idea — it'll disincentivize work.",
                        "topic": "universal basic income and its impact on work incentives",
                        "timestamp": "2024-01-01T12:00:00Z",
                    },
                ],
                "total_processed": 2,
                "timestamp": "2024-01-01T12:00:00Z",
            }
        }
    )

    # One TopicResponse per successfully processed text; failures are omitted,
    # so total_processed may be smaller than the number of submitted texts.
    results: List[TopicResponse] = Field(..., description="List of topic extraction results")
    total_processed: int = Field(..., description="Number of texts processed")
    timestamp: str = Field(..., description="Timestamp of the batch extraction")
86
+
requirements.txt CHANGED
@@ -8,6 +8,12 @@ pydantic>=2.5.0
8
  requests>=2.31.0
9
  groq>=0.9.0
10
 
 
 
 
 
 
 
11
  # Audio processing (optionnel si vous avez besoin de traitement local)
12
  soundfile>=0.12.1
13
 
 
8
  requests>=2.31.0
9
  groq>=0.9.0
10
 
11
+ # LangChain
12
+ langchain>=0.1.0
13
+ langchain-core>=0.1.0
14
+ langchain-groq>=0.1.0
15
+ langsmith>=0.1.0
16
+
17
  # Audio processing (optionnel si vous avez besoin de traitement local)
18
  soundfile>=0.12.1
19
 
routes/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
  """API route handlers"""
2
 
3
  from fastapi import APIRouter
4
- from . import root, health, stance, label, generate
5
  from routes.tts_routes import router as audio_router
6
  # Create main router
7
  api_router = APIRouter()
@@ -12,6 +12,7 @@ api_router.include_router(health.router)
12
  api_router.include_router(stance.router, prefix="/stance")
13
  api_router.include_router(label.router, prefix="/label")
14
  api_router.include_router(generate.router, prefix="/generate")
 
15
  api_router.include_router(audio_router)
16
 
17
  __all__ = ["api_router"]
 
1
  """API route handlers"""
2
 
3
  from fastapi import APIRouter
4
+ from . import root, health, stance, label, generate, topic
5
  from routes.tts_routes import router as audio_router
6
  # Create main router
7
  api_router = APIRouter()
 
12
  api_router.include_router(stance.router, prefix="/stance")
13
  api_router.include_router(label.router, prefix="/label")
14
  api_router.include_router(generate.router, prefix="/generate")
15
+ api_router.include_router(topic.router, prefix="/topic")
16
  api_router.include_router(audio_router)
17
 
18
  __all__ = ["api_router"]
routes/topic.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Topic extraction endpoints"""
2
+
3
+ from fastapi import APIRouter, HTTPException
4
+ from datetime import datetime
5
+ import logging
6
+
7
+ from services.topic_service import topic_service
8
+ from models.topic import (
9
+ TopicRequest,
10
+ TopicResponse,
11
+ BatchTopicRequest,
12
+ BatchTopicResponse,
13
+ )
14
+
15
+ router = APIRouter()
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
@router.post("/extract", response_model=TopicResponse, tags=["Topic Extraction"])
async def extract_topic(request: TopicRequest):
    """
    Extract a topic from a given text/argument

    - **text**: The input text or argument to extract topic from (5-5000 chars)

    Returns the extracted topic description

    Raises HTTP 400 on validation failures from the service layer and
    HTTP 500 on any other extraction error.
    """
    # Local import keeps this fix self-contained; `datetime` itself is
    # already imported at the top of this module.
    from datetime import timezone

    try:
        # Delegate the actual LLM call to the service layer.
        topic = topic_service.extract_topic(request.text)

        # Timezone-aware UTC timestamp: the documented examples use a UTC
        # ("Z"-style) time, so a naive local time would be inconsistent.
        response = TopicResponse(
            text=request.text,
            topic=topic,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )

        # Lazy %-formatting avoids building the message when INFO is disabled.
        logger.info("Topic extracted: %s...", topic[:50])
        return response

    except ValueError as e:
        # Bad input rejected by the service -> client error, with the cause chained.
        logger.error("Validation error: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Anything else (LLM/network failure) -> server error; chain the
        # original exception so tracebacks stay debuggable.
        logger.error("Topic extraction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Topic extraction failed: {str(e)}") from e
48
+
49
+
50
@router.post("/batch-extract", response_model=BatchTopicResponse, tags=["Topic Extraction"])
async def batch_extract_topics(request: BatchTopicRequest):
    """
    Extract topics from multiple texts/arguments

    - **texts**: List of texts to extract topics from (max 50)

    Returns extracted topics for all texts; texts whose extraction failed
    are omitted from the results, so `total_processed` may be smaller than
    the number of submitted texts.

    Raises HTTP 400 on validation failures and HTTP 500 on other errors.
    """
    # Local import keeps this fix self-contained; `datetime` itself is
    # already imported at the top of this module.
    from datetime import timezone

    try:
        # The service returns one entry per input text; failed extractions
        # come back as None.
        topics = topic_service.batch_extract_topics(request.texts)

        # Single timezone-aware UTC timestamp shared by the whole batch,
        # matching the documented "Z"-style examples.
        timestamp = datetime.now(timezone.utc).isoformat()

        # Pair inputs with outputs and keep only the successful ones.
        results = [
            TopicResponse(text=text, topic=topic, timestamp=timestamp)
            for text, topic in zip(request.texts, topics)
            if topic is not None
        ]

        failed = len(request.texts) - len(results)
        if failed:
            logger.warning("Failed to extract topics for %d of %d texts", failed, len(request.texts))

        logger.info(
            "Batch topic extraction completed: %d/%d successful",
            len(results),
            len(request.texts),
        )

        return BatchTopicResponse(
            results=results,
            total_processed=len(results),
            timestamp=timestamp,
        )

    except ValueError as e:
        # Bad input rejected by the service -> client error, cause chained.
        logger.error("Validation error: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Unexpected failure -> server error; chain the original exception.
        logger.error("Batch topic extraction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Batch topic extraction failed: {str(e)}") from e
94
+
services/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .generate_model_manager import GenerateModelManager, generate_model_manager
7
  # NEW imports
8
  from .stt_service import speech_to_text
9
  from .tts_service import text_to_speech
 
10
 
11
  __all__ = [
12
  "StanceModelManager",
@@ -15,6 +16,8 @@ __all__ = [
15
  "kpa_model_manager",
16
  "GenerateModelManager",
17
  "generate_model_manager",
 
 
18
 
19
  # NEW exports
20
  "speech_to_text",
 
7
  # NEW imports
8
  from .stt_service import speech_to_text
9
  from .tts_service import text_to_speech
10
+ from .topic_service import TopicService, topic_service
11
 
12
  __all__ = [
13
  "StanceModelManager",
 
16
  "kpa_model_manager",
17
  "GenerateModelManager",
18
  "generate_model_manager",
19
+ "TopicService",
20
+ "topic_service",
21
 
22
  # NEW exports
23
  "speech_to_text",
services/topic_service.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Service for topic extraction from text using LangChain Groq"""
2
+
3
+ import logging
4
+ from typing import Optional, List
5
+ from langchain_core.messages import HumanMessage, SystemMessage
6
+ from langchain_groq import ChatGroq
7
+ from pydantic import BaseModel, Field
8
+ from langsmith import traceable
9
+
10
+ from config import GROQ_API_KEY
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class TopicOutput(BaseModel):
    """Structured-output schema the LLM is asked to produce."""

    # Single string field: the topic distilled from the input argument.
    topic: str = Field(..., description="A specific, detailed topic description")
18
+
19
+
20
class TopicService:
    """Service for extracting topics from text arguments.

    The Groq LLM client is created lazily on first use so that importing
    this module does not require a configured GROQ_API_KEY.
    """

    def __init__(self):
        # Structured-output LLM; populated by initialize().
        self.llm = None
        self.model_name = "openai/gpt-oss-safeguard-20b"  # Default model
        self.initialized = False

    def initialize(self, model_name: Optional[str] = None):
        """Initialize the Groq LLM with structured output.

        Args:
            model_name: Optional override for the default Groq model.

        Raises:
            ValueError: If GROQ_API_KEY is not configured.
            RuntimeError: If the LLM client cannot be created.
        """
        if self.initialized:
            logger.info("Topic service already initialized")
            return

        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        if model_name:
            self.model_name = model_name

        try:
            logger.info("Initializing topic extraction service with model: %s", self.model_name)

            llm = ChatGroq(
                model=self.model_name,
                api_key=GROQ_API_KEY,
                temperature=0.0,
                max_tokens=512,
            )

            # Bind structured output directly to the model so responses are
            # parsed into TopicOutput instances.
            self.llm = llm.with_structured_output(TopicOutput)
            self.initialized = True

            logger.info("✓ Topic extraction service initialized successfully")

        except Exception as e:
            logger.error("Error initializing topic service: %s", e)
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to initialize topic service: {str(e)}") from e

    @traceable(name="extract_topic")
    def extract_topic(self, text: str) -> str:
        """
        Extract a topic from the given text/argument

        Args:
            text: The input text/argument to extract topic from

        Returns:
            The extracted topic string

        Raises:
            ValueError: If `text` is empty or not a string.
            RuntimeError: If the LLM call fails.
        """
        if not self.initialized:
            self.initialize()

        if not text or not isinstance(text, str):
            raise ValueError("Text must be a non-empty string")

        text = text.strip()
        if len(text) == 0:
            raise ValueError("Text cannot be empty")

        system_message = """You are an information extraction model.
Extract a detailed topic from the user text.

Examples:
- Text: "Governments should subsidize electric cars to encourage adoption."
  Output: topic="government subsidies for electric vehicle adoption"

- Text: "Raising the minimum wage will hurt small businesses and cost jobs."
  Output: topic="raising the minimum wage and its economic impact on small businesses"
"""

        try:
            result = self.llm.invoke(
                [
                    SystemMessage(content=system_message),
                    HumanMessage(content=text),
                ]
            )

            return result.topic

        except Exception as e:
            logger.error("Error extracting topic: %s", e)
            # Chain the cause so callers can inspect the underlying failure.
            raise RuntimeError(f"Topic extraction failed: {str(e)}") from e

    def batch_extract_topics(self, texts: List[str]) -> List[Optional[str]]:
        """
        Extract topics from multiple texts

        Args:
            texts: List of input texts/arguments

        Returns:
            List of extracted topics, one entry per input text; entries are
            None for texts whose extraction failed. (The annotation is
            Optional[str] because failures are recorded as None rather than
            raised.)

        Raises:
            ValueError: If `texts` is empty or not a list.
        """
        if not self.initialized:
            self.initialize()

        if not texts or not isinstance(texts, list):
            raise ValueError("Texts must be a non-empty list")

        results: List[Optional[str]] = []
        for text in texts:
            try:
                results.append(self.extract_topic(text))
            except Exception as e:
                # Best-effort batch semantics: log and record the failure,
                # keep processing the remaining texts.
                logger.error("Error extracting topic for text '%s...': %s", text[:50], e)
                results.append(None)

        return results


# Initialize singleton instance shared by the route handlers
topic_service = TopicService()
136
+