Yassine Mhirsi commited on
Commit
a453c29
·
1 Parent(s): 94c2a9a

feat: Add topic similarity service using Google Generative AI embeddings, enabling improved topic matching and similarity analysis. Update topic extraction logic to utilize this service and enhance overall functionality.

Browse files
.gitignore CHANGED
@@ -10,4 +10,5 @@ __pycache__
10
  *.pyzw
11
  *.pyzwz
12
  *.pyzwzw
13
- *.pyzwzwz
 
 
10
  *.pyzw
11
  *.pyzwz
12
  *.pyzwzw
13
+ *.pyzwzwz
14
+ data/topic_embeddings_cache.json
config.py CHANGED
@@ -28,6 +28,9 @@ GENERATE_MODEL_ID = HUGGINGFACE_GENERATE_MODEL_ID
28
  # ============ GROQ MODELS ============
29
  GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
30
 
 
 
 
31
  # **Speech-to-Text**
32
  GROQ_STT_MODEL = "whisper-large-v3-turbo"
33
 
@@ -84,5 +87,6 @@ logger.info(f" HF Label Model : {HUGGINGFACE_LABEL_MODEL_ID}")
84
  logger.info(f" GROQ STT Model : {GROQ_STT_MODEL}")
85
  logger.info(f" GROQ TTS Model : {GROQ_TTS_MODEL}")
86
  logger.info(f" GROQ Chat Model : {GROQ_CHAT_MODEL}")
 
87
  logger.info(f" Supabase URL : {'✓ Configured' if SUPABASE_URL else '✗ Not configured'}")
88
  logger.info("="*60)
 
28
  # ============ GROQ MODELS ============
29
  GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
30
 
31
+ # ============ GOOGLE MODELS ============
32
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
33
+
34
  # **Speech-to-Text**
35
  GROQ_STT_MODEL = "whisper-large-v3-turbo"
36
 
 
87
  logger.info(f" GROQ STT Model : {GROQ_STT_MODEL}")
88
  logger.info(f" GROQ TTS Model : {GROQ_TTS_MODEL}")
89
  logger.info(f" GROQ Chat Model : {GROQ_CHAT_MODEL}")
90
+ logger.info(f" Google API Key : {'✓ Configured' if GOOGLE_API_KEY else '✗ Not configured'}")
91
  logger.info(f" Supabase URL : {'✓ Configured' if SUPABASE_URL else '✗ Not configured'}")
92
  logger.info("="*60)
data/topics.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "topics": [
3
+ "Assisted suicide should be a criminal offence",
4
+ "We should abolish intellectual property rights",
5
+ "Homeschooling should be banned",
6
+ "The vow of celibacy should be abandoned",
7
+ "We should legalize prostitution",
8
+ "We should ban private military companies",
9
+ "We should abolish capital punishment",
10
+ "Foster care brings more harm than good",
11
+ "Routine child vaccinations should be mandatory",
12
+ "We should abolish the three-strikes laws",
13
+ "We should subsidize student loans",
14
+ "We should end the use of economic sanctions",
15
+ "We should end mandatory retirement",
16
+ "We should close Guantanamo Bay detention camp",
17
+ "We should subsidize space exploration",
18
+ "We should abandon the use of school uniform",
19
+ "The use of public defenders should be mandatory",
20
+ "We should adopt an austerity regime",
21
+ "Social media platforms should be regulated by the government",
22
+ "We should ban human cloning",
23
+ "We should adopt atheism",
24
+ "We should introduce compulsory voting",
25
+ "We should adopt libertarianism",
26
+ "We should abolish the right to keep and bear arms",
27
+ "We should legalize sex selection",
28
+ "We should abandon marriage",
29
+ "Entrapment should be legalized",
30
+ "We should end affirmative action",
31
+ "We should prohibit women in combat",
32
+ "We should adopt a zero-tolerance policy in schools",
33
+ "We should subsidize vocational education",
34
+ "We should ban the use of child actors",
35
+ "We should legalize cannabis",
36
+ "We should ban cosmetic surgery",
37
+ "We should end racial profiling",
38
+ "We should prohibit flag burning",
39
+ "The USA is a good country to live in",
40
+ "We should ban algorithmic trading",
41
+ "We should fight for the abolition of nuclear weapons",
42
+ "We should fight urbanization",
43
+ "We should subsidize journalism"
44
+ ]
45
+ }
requirements.txt CHANGED
@@ -8,6 +8,7 @@ pydantic>=2.5.0
8
  requests>=2.31.0
9
  groq>=0.9.0
10
  supabase>=2.0.0
 
11
 
12
  # LangChain
13
  langchain>=0.1.0
@@ -24,6 +25,7 @@ torch>=2.0.1
24
 
25
  # Autres dépendances
26
  numpy>=1.26.4
 
27
 
28
  mcp>=1.0.0
29
  # Note: fastapi-mcp peut ne pas exister officiellement,
 
8
  requests>=2.31.0
9
  groq>=0.9.0
10
  supabase>=2.0.0
11
+ google-genai>=0.2.0
12
 
13
  # LangChain
14
  langchain>=0.1.0
 
25
 
26
  # Autres dépendances
27
  numpy>=1.26.4
28
+ scikit-learn>=1.3.0
29
 
30
  mcp>=1.0.0
31
  # Note: fastapi-mcp peut ne pas exister officiellement,
services/analysis_service.py CHANGED
@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
5
  from datetime import datetime
6
 
7
  from services.database_service import database_service
8
- from services.topic_service import topic_service
9
  from services.stance_model_manager import stance_model_manager
10
 
11
  logger = logging.getLogger(__name__)
@@ -42,9 +42,9 @@ class AnalysisService:
42
 
43
  logger.info(f"Starting analysis for {len(arguments)} arguments for user {user_id}")
44
 
45
- # Step 1: Extract topics for all arguments
46
- logger.info("Step 1: Extracting topics...")
47
- topics = topic_service.batch_extract_topics(arguments)
48
 
49
  if len(topics) != len(arguments):
50
  raise RuntimeError(f"Topic extraction returned {len(topics)} topics but expected {len(arguments)}")
 
5
  from datetime import datetime
6
 
7
  from services.database_service import database_service
8
+ from services.topic_similarity_service import topic_similarity_service
9
  from services.stance_model_manager import stance_model_manager
10
 
11
  logger = logging.getLogger(__name__)
 
42
 
43
  logger.info(f"Starting analysis for {len(arguments)} arguments for user {user_id}")
44
 
45
+ # Step 1: Find most similar topics for all arguments
46
+ logger.info("Step 1: Finding most similar topics...")
47
+ topics = topic_similarity_service.batch_find_similar_topics(arguments)
48
 
49
  if len(topics) != len(arguments):
50
  raise RuntimeError(f"Topic extraction returned {len(topics)} topics but expected {len(arguments)}")
services/topic_service.py CHANGED
@@ -1,73 +1,32 @@
1
  """Service for topic extraction from text using LangChain Groq"""
2
 
3
  import logging
4
- import json
5
  from typing import Optional, List
6
  from langchain_core.messages import HumanMessage, SystemMessage
7
  from langchain_groq import ChatGroq
 
8
  from langsmith import traceable
9
 
10
  from config import GROQ_API_KEY
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
- # Predefined topics list
15
- PREDEFINED_TOPICS = [
16
- "Assisted suicide should be a criminal offence",
17
- "We should abolish intellectual property rights",
18
- "Homeschooling should be banned",
19
- "The vow of celibacy should be abandoned",
20
- "We should legalize prostitution",
21
- "We should ban private military companies",
22
- "We should abolish capital punishment",
23
- "Foster care brings more harm than good",
24
- "Routine child vaccinations should be mandatory",
25
- "We should abolish the three-strikes laws",
26
- "We should subsidize student loans",
27
- "We should end the use of economic sanctions",
28
- "We should end mandatory retirement",
29
- "We should close Guantanamo Bay detention camp",
30
- "We should subsidize space exploration",
31
- "We should abandon the use of school uniform",
32
- "The use of public defenders should be mandatory",
33
- "We should adopt an austerity regime",
34
- "Social media platforms should be regulated by the government",
35
- "We should ban human cloning",
36
- "We should adopt atheism",
37
- "We should introduce compulsory voting",
38
- "We should adopt libertarianism",
39
- "We should abolish the right to keep and bear arms",
40
- "We should legalize sex selection",
41
- "We should abandon marriage",
42
- "Entrapment should be legalized",
43
- "We should end affirmative action",
44
- "We should prohibit women in combat",
45
- "We should adopt a zero-tolerance policy in schools",
46
- "We should subsidize vocational education",
47
- "We should ban the use of child actors",
48
- "We should legalize cannabis",
49
- "We should ban cosmetic surgery",
50
- "We should end racial profiling",
51
- "We should prohibit flag burning",
52
- "The USA is a good country to live in",
53
- "We should ban algorithmic trading",
54
- "We should fight for the abolition of nuclear weapons",
55
- "We should fight urbanization",
56
- "We should subsidize journalism",
57
- ]
58
 
59
 
60
  class TopicService:
61
- """Service for extracting topics from text arguments by matching to predefined topics"""
62
 
63
  def __init__(self):
64
  self.llm = None
65
- self.model_name = "openai/gpt-oss-safeguard-20b" # Default model
66
  self.initialized = False
67
- self.predefined_topics = PREDEFINED_TOPICS
68
 
69
  def initialize(self, model_name: Optional[str] = None):
70
- """Initialize the Groq LLM"""
71
  if self.initialized:
72
  logger.info("Topic service already initialized")
73
  return
@@ -81,13 +40,15 @@ class TopicService:
81
  try:
82
  logger.info(f"Initializing topic extraction service with model: {self.model_name}")
83
 
84
- self.llm = ChatGroq(
85
  model=self.model_name,
86
  api_key=GROQ_API_KEY,
87
  temperature=0.0,
88
  max_tokens=512,
89
  )
90
 
 
 
91
  self.initialized = True
92
 
93
  logger.info("✓ Topic extraction service initialized successfully")
@@ -96,46 +57,16 @@ class TopicService:
96
  logger.error(f"Error initializing topic service: {str(e)}")
97
  raise RuntimeError(f"Failed to initialize topic service: {str(e)}")
98
 
99
- def _get_system_message(self) -> str:
100
- """Generate system message with predefined topics list"""
101
- topics_list = "\n".join([f"{i+1}. {topic}" for i, topic in enumerate(self.predefined_topics)])
102
-
103
- return f"""You are a topic classification model. Your task is to select the MOST SIMILAR topic from the predefined list below that best matches the user's input text.
104
-
105
- IMPORTANT: You MUST return EXACTLY one of the predefined topics below. Do not create new topics or modify the wording.
106
-
107
- Return your response as a JSON object with a single "topic" field containing the exact topic text from the list.
108
-
109
- Predefined Topics:
110
- {topics_list}
111
-
112
- Instructions:
113
- 1. Analyze the user's input text carefully
114
- 2. Identify the main theme, subject, or argument being discussed
115
- 3. Find the topic from the predefined list that is MOST SIMILAR to the input text
116
- 4. Return a JSON object with the EXACT topic text as it appears in the list above
117
-
118
- Examples:
119
- - Input: "I think we need to make assisted suicide illegal and punishable by law."
120
- Output: {{"topic": "Assisted suicide should be a criminal offence"}}
121
-
122
- - Input: "Student debt is crushing young people. The government should help pay for college."
123
- Output: {{"topic": "We should subsidize student loans"}}
124
-
125
- - Input: "Marijuana should be legal for adults to use recreationally."
126
- Output: {{"topic": "We should legalize cannabis"}}
127
- """
128
-
129
  @traceable(name="extract_topic")
130
  def extract_topic(self, text: str) -> str:
131
  """
132
- Extract a topic from the given text/argument by matching to predefined topics
133
 
134
  Args:
135
  text: The input text/argument to extract topic from
136
 
137
  Returns:
138
- The extracted topic string (must be one of the predefined topics)
139
  """
140
  if not self.initialized:
141
  self.initialize()
@@ -147,7 +78,16 @@ Examples:
147
  if len(text) == 0:
148
  raise ValueError("Text cannot be empty")
149
 
150
- system_message = self._get_system_message()
 
 
 
 
 
 
 
 
 
151
 
152
  try:
153
  result = self.llm.invoke(
@@ -157,68 +97,7 @@ Examples:
157
  ]
158
  )
159
 
160
- # Extract content from the response
161
- response_content = result.content.strip()
162
-
163
- # Try to parse as JSON first
164
- try:
165
- parsed_response = json.loads(response_content)
166
- selected_topic = parsed_response.get("topic", "").strip()
167
- except json.JSONDecodeError:
168
- # If not JSON, try to extract topic from plain text
169
- # Look for the topic in the response text
170
- selected_topic = response_content.strip()
171
- # Remove quotes if present
172
- if selected_topic.startswith('"') and selected_topic.endswith('"'):
173
- selected_topic = selected_topic[1:-1]
174
- elif selected_topic.startswith("'") and selected_topic.endswith("'"):
175
- selected_topic = selected_topic[1:-1]
176
-
177
- if not selected_topic:
178
- raise ValueError("No topic found in LLM response")
179
-
180
- # Validate that the returned topic is in the predefined list
181
- if selected_topic not in self.predefined_topics:
182
- logger.warning(
183
- f"LLM returned topic not in predefined list: '{selected_topic}'. "
184
- f"Attempting to find closest match..."
185
- )
186
- # Try to find the closest match (case-insensitive)
187
- selected_topic_lower = selected_topic.lower()
188
- for predefined_topic in self.predefined_topics:
189
- if predefined_topic.lower() == selected_topic_lower:
190
- selected_topic = predefined_topic
191
- logger.info(f"Found case-insensitive match: '{selected_topic}'")
192
- break
193
- else:
194
- # If still no match, try fuzzy matching by checking if the topic contains key words
195
- # This is a fallback for when the LLM returns something close but not exact
196
- best_match = None
197
- best_match_score = 0
198
- selected_words = set(selected_topic_lower.split())
199
-
200
- for predefined_topic in self.predefined_topics:
201
- predefined_words = set(predefined_topic.lower().split())
202
- # Calculate word overlap
203
- overlap = len(selected_words & predefined_words)
204
- if overlap > best_match_score and overlap >= 2: # At least 2 words must match
205
- best_match_score = overlap
206
- best_match = predefined_topic
207
-
208
- if best_match:
209
- logger.info(f"Found fuzzy match: '{selected_topic}' -> '{best_match}'")
210
- selected_topic = best_match
211
- else:
212
- # If still no match, log error and raise
213
- logger.error(
214
- f"Could not match returned topic '{selected_topic}' to any predefined topic. "
215
- f"Available topics: {self.predefined_topics[:3]}..."
216
- )
217
- raise ValueError(
218
- f"Returned topic '{selected_topic}' is not in the predefined topics list"
219
- )
220
-
221
- return selected_topic
222
 
223
  except Exception as e:
224
  logger.error(f"Error extracting topic: {str(e)}")
 
1
  """Service for topic extraction from text using LangChain Groq"""
2
 
3
  import logging
 
4
  from typing import Optional, List
5
  from langchain_core.messages import HumanMessage, SystemMessage
6
  from langchain_groq import ChatGroq
7
+ from pydantic import BaseModel, Field
8
  from langsmith import traceable
9
 
10
  from config import GROQ_API_KEY
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
+
15
+ class TopicOutput(BaseModel):
16
+ """Pydantic schema for topic extraction output"""
17
+ topic: str = Field(..., description="A specific, detailed topic description")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  class TopicService:
21
+ """Service for extracting topics from text arguments"""
22
 
23
  def __init__(self):
24
  self.llm = None
25
+ self.model_name = "openai/gpt-oss-safeguard-120b" # another model meta-llama/llama-4-scout-17b-16e-instruct
26
  self.initialized = False
 
27
 
28
  def initialize(self, model_name: Optional[str] = None):
29
+ """Initialize the Groq LLM with structured output"""
30
  if self.initialized:
31
  logger.info("Topic service already initialized")
32
  return
 
40
  try:
41
  logger.info(f"Initializing topic extraction service with model: {self.model_name}")
42
 
43
+ llm = ChatGroq(
44
  model=self.model_name,
45
  api_key=GROQ_API_KEY,
46
  temperature=0.0,
47
  max_tokens=512,
48
  )
49
 
50
+ # Bind structured output directly to the model
51
+ self.llm = llm.with_structured_output(TopicOutput)
52
  self.initialized = True
53
 
54
  logger.info("✓ Topic extraction service initialized successfully")
 
57
  logger.error(f"Error initializing topic service: {str(e)}")
58
  raise RuntimeError(f"Failed to initialize topic service: {str(e)}")
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  @traceable(name="extract_topic")
61
  def extract_topic(self, text: str) -> str:
62
  """
63
+ Extract a topic from the given text/argument
64
 
65
  Args:
66
  text: The input text/argument to extract topic from
67
 
68
  Returns:
69
+ The extracted topic string
70
  """
71
  if not self.initialized:
72
  self.initialize()
 
78
  if len(text) == 0:
79
  raise ValueError("Text cannot be empty")
80
 
81
+ system_message = """You are an information extraction model.
82
+ Extract a topic from the user text. The topic should be a single sentence that captures the main idea of the text in simple english.
83
+
84
+ Examples:
85
+ - Text: "Governments should subsidize electric cars to encourage adoption."
86
+ Output: topic="government subsidies for electric vehicle adoption"
87
+
88
+ - Text: "Raising the minimum wage will hurt small businesses and cost jobs."
89
+ Output: topic="raising the minimum wage and its economic impact on small businesses"
90
+ """
91
 
92
  try:
93
  result = self.llm.invoke(
 
97
  ]
98
  )
99
 
100
+ return result.topic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  except Exception as e:
103
  logger.error(f"Error extracting topic: {str(e)}")
services/topic_similarity_service.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Service for finding similar topics using Google Generative AI embeddings"""
2
+
3
+ import logging
4
+ import json
5
+ import hashlib
6
+ from pathlib import Path
7
+ from typing import Optional, List, Dict
8
+ from datetime import datetime
9
+ import numpy as np
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+ from google import genai
12
+ from google.genai import types
13
+
14
+ from config import GOOGLE_API_KEY, PROJECT_ROOT
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Paths for topics and cache files
19
+ TOPICS_FILE = PROJECT_ROOT / "data" / "topics.json"
20
+ EMBEDDINGS_CACHE_FILE = PROJECT_ROOT / "data" / "topic_embeddings_cache.json"
21
+
22
+
23
+ class TopicSimilarityService:
24
+ """Service for finding the most similar topic from a predefined list using embeddings"""
25
+
26
+ def __init__(self):
27
+ self.client = None
28
+ self.topics = []
29
+ self.topic_embeddings = None
30
+ self.initialized = False
31
+ self.model_name = "models/text-embedding-004"
32
+
33
+ def initialize(self):
34
+ """Initialize the Google Generative AI client and load topic embeddings"""
35
+ if self.initialized:
36
+ logger.info("Topic similarity service already initialized")
37
+ return
38
+
39
+ if not GOOGLE_API_KEY:
40
+ raise ValueError("GOOGLE_API_KEY not found in environment variables")
41
+
42
+ try:
43
+ logger.info("Initializing topic similarity service with Google Generative AI")
44
+
45
+ # Create Google Generative AI client
46
+ self.client = genai.Client(api_key=GOOGLE_API_KEY)
47
+
48
+ # Load topics
49
+ self.topics = self._load_topics()
50
+ logger.info(f"Loaded {len(self.topics)} topics from {TOPICS_FILE}")
51
+
52
+ # Load or generate topic embeddings
53
+ self.topic_embeddings = self._get_topic_embeddings()
54
+ logger.info(f"Loaded {len(self.topic_embeddings)} topic embeddings")
55
+
56
+ self.initialized = True
57
+ logger.info("✓ Topic similarity service initialized successfully")
58
+
59
+ except Exception as e:
60
+ logger.error(f"Error initializing topic similarity service: {str(e)}")
61
+ raise RuntimeError(f"Failed to initialize topic similarity service: {str(e)}")
62
+
63
+ def _load_topics(self) -> List[str]:
64
+ """Load topics from topics.json file"""
65
+ if not TOPICS_FILE.exists():
66
+ raise FileNotFoundError(f"Topics file not found: {TOPICS_FILE}")
67
+
68
+ try:
69
+ with open(TOPICS_FILE, 'r', encoding='utf-8') as f:
70
+ data = json.load(f)
71
+ return data.get("topics", [])
72
+ except (json.JSONDecodeError, KeyError) as e:
73
+ raise ValueError(f"Error loading topics from {TOPICS_FILE}: {str(e)}")
74
+
75
+ def _get_topics_hash(self, topics: List[str]) -> str:
76
+ """Generate a hash of the topics list to verify cache validity"""
77
+ topics_str = json.dumps(topics, sort_keys=True)
78
+ return hashlib.md5(topics_str.encode('utf-8')).hexdigest()
79
+
80
+ def _load_cached_embeddings(self) -> Optional[np.ndarray]:
81
+ """Load cached topic embeddings if they exist and are valid"""
82
+ if not EMBEDDINGS_CACHE_FILE.exists():
83
+ return None
84
+
85
+ try:
86
+ with open(EMBEDDINGS_CACHE_FILE, 'r', encoding='utf-8') as f:
87
+ cache_data = json.load(f)
88
+
89
+ # Verify cache is valid by checking topics hash
90
+ current_hash = self._get_topics_hash(self.topics)
91
+
92
+ if cache_data.get("topics_hash") == current_hash:
93
+ # Convert list embeddings back to numpy arrays
94
+ embeddings = [np.array(emb) for emb in cache_data.get("embeddings", [])]
95
+ logger.info(f"Loaded {len(embeddings)} topic embeddings from cache")
96
+ return np.array(embeddings)
97
+ else:
98
+ # Topics have changed, cache is invalid
99
+ logger.info("Topics have changed, cache is invalid")
100
+ return None
101
+ except (json.JSONDecodeError, KeyError, ValueError) as e:
102
+ logger.warning(f"Could not load cached embeddings: {e}")
103
+ return None
104
+
105
+ def _save_cached_embeddings(self, embeddings: np.ndarray):
106
+ """Save topic embeddings to cache file"""
107
+ topics_hash = self._get_topics_hash(self.topics)
108
+
109
+ # Convert numpy arrays to lists for JSON serialization
110
+ embeddings_list = [emb.tolist() for emb in embeddings]
111
+
112
+ cache_data = {
113
+ "topics_hash": topics_hash,
114
+ "embeddings": embeddings_list,
115
+ "model": self.model_name,
116
+ "cached_at": datetime.now().isoformat()
117
+ }
118
+
119
+ try:
120
+ # Ensure directory exists
121
+ EMBEDDINGS_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
122
+
123
+ with open(EMBEDDINGS_CACHE_FILE, 'w', encoding='utf-8') as f:
124
+ json.dump(cache_data, f, indent=2)
125
+ logger.info(f"Cached {len(embeddings)} topic embeddings to {EMBEDDINGS_CACHE_FILE}")
126
+ except Exception as e:
127
+ logger.warning(f"Could not save cached embeddings: {e}")
128
+
129
+ def _get_topic_embeddings(self) -> np.ndarray:
130
+ """
131
+ Get topic embeddings, loading from cache if available, otherwise generating and caching them
132
+
133
+ Returns:
134
+ numpy.ndarray: Array of topic embeddings
135
+ """
136
+ # Try to load from cache first
137
+ cached_embeddings = self._load_cached_embeddings()
138
+ if cached_embeddings is not None:
139
+ return cached_embeddings
140
+
141
+ # Cache miss or invalid - generate embeddings
142
+ logger.info(f"Generating embeddings for {len(self.topics)} topics (this may take a moment)...")
143
+
144
+ try:
145
+ embedding_response = self.client.models.embed_content(
146
+ model=self.model_name,
147
+ contents=self.topics,
148
+ config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
149
+ )
150
+
151
+ if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
152
+ raise RuntimeError("Embedding API did not return embeddings")
153
+
154
+ embeddings = [np.array(e.values) for e in embedding_response.embeddings]
155
+ embeddings_array = np.array(embeddings)
156
+
157
+ # Save to cache for future use
158
+ self._save_cached_embeddings(embeddings_array)
159
+
160
+ return embeddings_array
161
+
162
+ except Exception as e:
163
+ logger.error(f"Error generating topic embeddings: {str(e)}")
164
+ raise RuntimeError(f"Failed to generate topic embeddings: {str(e)}")
165
+
166
+ def find_most_similar_topic(self, input_text: str) -> Dict[str, any]:
167
+ """
168
+ Compare a single input text to all topics and return the highest cosine similarity
169
+
170
+ Args:
171
+ input_text: The text to compare against topics
172
+
173
+ Returns:
174
+ dict: Contains 'topic', 'similarity', and 'index' of the most similar topic
175
+ """
176
+ if not self.initialized:
177
+ self.initialize()
178
+
179
+ if not input_text or not isinstance(input_text, str):
180
+ raise ValueError("Input text must be a non-empty string")
181
+
182
+ input_text = input_text.strip()
183
+ if len(input_text) == 0:
184
+ raise ValueError("Input text cannot be empty")
185
+
186
+ if not self.topics:
187
+ raise ValueError("No topics found in topics.json")
188
+
189
+ try:
190
+ # Embed the input text
191
+ embedding_response = self.client.models.embed_content(
192
+ model=self.model_name,
193
+ contents=[input_text],
194
+ config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
195
+ )
196
+
197
+ if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
198
+ raise RuntimeError("Embedding API did not return embeddings")
199
+
200
+ # Extract input embedding
201
+ input_embedding = np.array(embedding_response.embeddings[0].values).reshape(1, -1)
202
+
203
+ # Calculate cosine similarity between input and each topic
204
+ similarities = cosine_similarity(input_embedding, self.topic_embeddings)[0]
205
+
206
+ # Find the highest similarity
207
+ max_index = np.argmax(similarities)
208
+ max_similarity = similarities[max_index]
209
+ most_similar_topic = self.topics[max_index]
210
+
211
+ return {
212
+ "topic": most_similar_topic,
213
+ "similarity": float(max_similarity),
214
+ "index": int(max_index)
215
+ }
216
+
217
+ except Exception as e:
218
+ logger.error(f"Error finding similar topic: {str(e)}")
219
+ raise RuntimeError(f"Failed to find similar topic: {str(e)}")
220
+
221
+ def batch_find_similar_topics(self, input_texts: List[str]) -> List[str]:
222
+ """
223
+ Find the most similar topic for each input text
224
+
225
+ Args:
226
+ input_texts: List of input texts to compare against topics
227
+
228
+ Returns:
229
+ List of most similar topics (one per input text)
230
+ """
231
+ if not self.initialized:
232
+ self.initialize()
233
+
234
+ if not input_texts or not isinstance(input_texts, list):
235
+ raise ValueError("Input texts must be a non-empty list")
236
+
237
+ results = []
238
+ for text in input_texts:
239
+ try:
240
+ result = self.find_most_similar_topic(text)
241
+ results.append(result["topic"])
242
+ except Exception as e:
243
+ logger.error(f"Error finding similar topic for text '{text[:50]}...': {str(e)}")
244
+ results.append(None) # Or raise, depending on desired behavior
245
+
246
+ return results
247
+
248
+
249
+ # Initialize singleton instance
250
+ topic_similarity_service = TopicSimilarityService()