Yassine Mhirsi committed
Commit 94c2a9a · 1 Parent(s): 7218dd0

refactor: Simplify topic extraction logic in TopicService by removing the Pydantic schema, enhancing JSON response handling, and adding fuzzy matching for improved topic validation.
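For readers skimming the diff, the word-overlap fallback this commit introduces can be illustrated in isolation. The sketch below is not part of the commit; the helper name `fuzzy_match_topic` is hypothetical, and the minimum-two-shared-words threshold mirrors the logic added to `extract_topic` in the diff further down.

# Illustrative sketch only (not part of this commit): a standalone version of
# the word-overlap fallback; `fuzzy_match_topic` is a hypothetical helper name.
from typing import List, Optional


def fuzzy_match_topic(candidate: str, predefined_topics: List[str]) -> Optional[str]:
    """Return the predefined topic sharing the most words with `candidate`, if any."""
    candidate_words = set(candidate.lower().split())
    best_match: Optional[str] = None
    best_score = 0
    for topic in predefined_topics:
        overlap = len(candidate_words & set(topic.lower().split()))
        if overlap > best_score and overlap >= 2:  # at least 2 words must match
            best_score = overlap
            best_match = topic
    return best_match


# A near-miss LLM answer still resolves to a predefined topic:
print(fuzzy_match_topic(
    "Cannabis should be legalized",
    ["We should legalize cannabis", "We should subsidize student loans"],
))  # -> "We should legalize cannabis"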

Files changed (1)
  1. services/topic_service.py +54 -23
services/topic_service.py CHANGED
@@ -1,10 +1,10 @@
 """Service for topic extraction from text using LangChain Groq"""
 
 import logging
+import json
 from typing import Optional, List
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_groq import ChatGroq
-from pydantic import BaseModel, Field
 from langsmith import traceable
 
 from config import GROQ_API_KEY
@@ -57,11 +57,6 @@ PREDEFINED_TOPICS = [
 ]
 
 
-class TopicOutput(BaseModel):
-    """Pydantic schema for topic extraction output"""
-    topic: str = Field(..., description="The selected topic from the predefined list that most closely matches the input text")
-
-
 class TopicService:
     """Service for extracting topics from text arguments by matching to predefined topics"""
 
@@ -72,7 +67,7 @@ class TopicService:
         self.predefined_topics = PREDEFINED_TOPICS
 
     def initialize(self, model_name: Optional[str] = None):
-        """Initialize the Groq LLM with structured output"""
+        """Initialize the Groq LLM"""
         if self.initialized:
             logger.info("Topic service already initialized")
             return
@@ -86,15 +81,13 @@ class TopicService:
         try:
             logger.info(f"Initializing topic extraction service with model: {self.model_name}")
 
-            llm = ChatGroq(
+            self.llm = ChatGroq(
                 model=self.model_name,
                 api_key=GROQ_API_KEY,
                 temperature=0.0,
                 max_tokens=512,
             )
 
-            # Bind structured output directly to the model
-            self.llm = llm.with_structured_output(TopicOutput)
             self.initialized = True
 
             logger.info("✓ Topic extraction service initialized successfully")
@@ -111,6 +104,8 @@ class TopicService:
 
 IMPORTANT: You MUST return EXACTLY one of the predefined topics below. Do not create new topics or modify the wording.
 
+Return your response as a JSON object with a single "topic" field containing the exact topic text from the list.
+
 Predefined Topics:
 {topics_list}
 
@@ -118,17 +113,17 @@ Instructions:
 1. Analyze the user's input text carefully
 2. Identify the main theme, subject, or argument being discussed
 3. Find the topic from the predefined list that is MOST SIMILAR to the input text
-4. Return the EXACT topic text as it appears in the list above
+4. Return a JSON object with the EXACT topic text as it appears in the list above
 
 Examples:
 - Input: "I think we need to make assisted suicide illegal and punishable by law."
-  Output: "Assisted suicide should be a criminal offence"
+  Output: {{"topic": "Assisted suicide should be a criminal offence"}}
 
 - Input: "Student debt is crushing young people. The government should help pay for college."
-  Output: "We should subsidize student loans"
+  Output: {{"topic": "We should subsidize student loans"}}
 
 - Input: "Marijuana should be legal for adults to use recreationally."
-  Output: "We should legalize cannabis"
+  Output: {{"topic": "We should legalize cannabis"}}
 """
 
     @traceable(name="extract_topic")
@@ -162,7 +157,25 @@ Examples:
             ]
         )
 
-        selected_topic = result.topic.strip()
+        # Extract content from the response
+        response_content = result.content.strip()
+
+        # Try to parse as JSON first
+        try:
+            parsed_response = json.loads(response_content)
+            selected_topic = parsed_response.get("topic", "").strip()
+        except json.JSONDecodeError:
+            # If not JSON, try to extract topic from plain text
+            # Look for the topic in the response text
+            selected_topic = response_content.strip()
+            # Remove quotes if present
+            if selected_topic.startswith('"') and selected_topic.endswith('"'):
+                selected_topic = selected_topic[1:-1]
+            elif selected_topic.startswith("'") and selected_topic.endswith("'"):
+                selected_topic = selected_topic[1:-1]
+
+        if not selected_topic:
+            raise ValueError("No topic found in LLM response")
 
         # Validate that the returned topic is in the predefined list
        if selected_topic not in self.predefined_topics:
@@ -178,14 +191,32 @@ Examples:
                     logger.info(f"Found case-insensitive match: '{selected_topic}'")
                     break
             else:
-                # If still no match, log error and raise
-                logger.error(
-                    f"Could not match returned topic '{selected_topic}' to any predefined topic. "
-                    f"Available topics: {self.predefined_topics[:3]}..."
-                )
-                raise ValueError(
-                    f"Returned topic '{selected_topic}' is not in the predefined topics list"
-                )
+                # If still no match, try fuzzy matching by checking if the topic contains key words
+                # This is a fallback for when the LLM returns something close but not exact
+                best_match = None
+                best_match_score = 0
+                selected_words = set(selected_topic_lower.split())
+
+                for predefined_topic in self.predefined_topics:
+                    predefined_words = set(predefined_topic.lower().split())
+                    # Calculate word overlap
+                    overlap = len(selected_words & predefined_words)
+                    if overlap > best_match_score and overlap >= 2:  # At least 2 words must match
+                        best_match_score = overlap
+                        best_match = predefined_topic
+
+                if best_match:
+                    logger.info(f"Found fuzzy match: '{selected_topic}' -> '{best_match}'")
+                    selected_topic = best_match
+                else:
+                    # If still no match, log error and raise
+                    logger.error(
+                        f"Could not match returned topic '{selected_topic}' to any predefined topic. "
+                        f"Available topics: {self.predefined_topics[:3]}..."
+                    )
+                    raise ValueError(
+                        f"Returned topic '{selected_topic}' is not in the predefined topics list"
+                    )
 
         return selected_topic
 
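As a usage note on the new response handling: the JSON-first parsing with a plain-text fallback can be exercised on its own. The snippet below is illustrative and not part of the commit; `parse_topic_reply` is a hypothetical standalone wrapper around the same logic `extract_topic` now applies to the raw LLM reply.

# Illustrative only: mirrors the JSON-first parsing (with plain-text fallback)
# that extract_topic now performs on the model's reply.
import json


def parse_topic_reply(response_content: str) -> str:
    response_content = response_content.strip()
    try:
        # Happy path: the model followed the prompt and returned {"topic": "..."}
        selected_topic = json.loads(response_content).get("topic", "").strip()
    except json.JSONDecodeError:
        # Fallback: treat the reply as plain text, stripping surrounding quotes
        selected_topic = response_content
        if selected_topic.startswith('"') and selected_topic.endswith('"'):
            selected_topic = selected_topic[1:-1]
        elif selected_topic.startswith("'") and selected_topic.endswith("'"):
            selected_topic = selected_topic[1:-1]
    if not selected_topic:
        raise ValueError("No topic found in LLM response")
    return selected_topic


print(parse_topic_reply('{"topic": "We should legalize cannabis"}'))  # JSON reply
print(parse_topic_reply("'We should legalize cannabis'"))             # plain quoted reply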