malek-messaoudii committed on
Commit 870d2ba · Parent: 91ae7d9

feat: Add GROQ_TOPIC_MODEL configuration and enhance TopicService to use it for model selection during initialization, with fallback models for robustness.

Files changed (2)
  1. config.py +3 -0
  2. services/topic_service.py +41 -21
config.py CHANGED

@@ -42,6 +42,9 @@ GROQ_TTS_FORMAT = "wav"
 # **Chat Model**
 GROQ_CHAT_MODEL = "llama3-70b-8192"
 
+# **Topic Extraction Model**
+GROQ_TOPIC_MODEL = "llama-3.1-70b-versatile"  # Alternative: "llama3-70b-8192" or "llama-3.1-8b-instant"
+
 # ============ SUPABASE ============
 SUPABASE_URL = os.getenv("SUPABASE_URL", "")
 SUPABASE_KEY = os.getenv("SUPABASE_KEY", "")
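Since the adjacent SUPABASE_* settings are read from the environment via os.getenv, the new constant could be made overridable per deployment in the same way. A minimal sketch of that variant, assuming the `import os` already present at the top of config.py; this is an illustration, not what the commit does:

    import os

    # Hypothetical variant: environment override, defaulting to the model
    # this commit hardcodes (the commit itself uses a plain string constant).
    GROQ_TOPIC_MODEL = os.getenv("GROQ_TOPIC_MODEL", "llama-3.1-70b-versatile")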
services/topic_service.py CHANGED

@@ -7,7 +7,7 @@ from langchain_groq import ChatGroq
 from pydantic import BaseModel, Field
 from langsmith import traceable
 
-from config import GROQ_API_KEY
+from config import GROQ_API_KEY, GROQ_TOPIC_MODEL
 
 logger = logging.getLogger(__name__)
 
@@ -22,7 +22,15 @@ class TopicService:
 
     def __init__(self):
         self.llm = None
-        self.model_name = "openai/gpt-oss-safeguard-120b"  # another model: meta-llama/llama-4-scout-17b-16e-instruct
+        # Use valid Groq model - default from config, fallback to common models
+        self.model_name = GROQ_TOPIC_MODEL if GROQ_TOPIC_MODEL else "llama-3.1-70b-versatile"
+        # Fallback models to try if primary fails
+        self.fallback_models = [
+            "llama-3.1-70b-versatile",
+            "llama3-70b-8192",
+            "llama-3.1-8b-instant",
+            "mixtral-8x7b-32768"
+        ]
         self.initialized = False
 
     def initialize(self, model_name: Optional[str] = None):
@@ -37,25 +45,37 @@
         if model_name:
             self.model_name = model_name
 
-        try:
-            logger.info(f"Initializing topic extraction service with model: {self.model_name}")
-
-            llm = ChatGroq(
-                model=self.model_name,
-                api_key=GROQ_API_KEY,
-                temperature=0.0,
-                max_tokens=512,
-            )
-
-            # Bind structured output directly to the model
-            self.llm = llm.with_structured_output(TopicOutput)
-            self.initialized = True
-
-            logger.info("✓ Topic extraction service initialized successfully")
-
-        except Exception as e:
-            logger.error(f"Error initializing topic service: {str(e)}")
-            raise RuntimeError(f"Failed to initialize topic service: {str(e)}")
+        # Try primary model first, then fallbacks
+        models_to_try = [self.model_name] + [m for m in self.fallback_models if m != self.model_name]
+
+        last_error = None
+        for model_to_try in models_to_try:
+            try:
+                logger.info(f"Initializing topic extraction service with model: {model_to_try}")
+
+                llm = ChatGroq(
+                    model=model_to_try,
+                    api_key=GROQ_API_KEY,
+                    temperature=0.0,
+                    max_tokens=512,
+                )
+
+                # Bind structured output directly to the model
+                self.llm = llm.with_structured_output(TopicOutput)
+                self.model_name = model_to_try  # Update to successful model
+                self.initialized = True
+
+                logger.info(f"✓ Topic extraction service initialized successfully with model: {model_to_try}")
+                return
+
+            except Exception as e:
+                last_error = e
+                logger.warning(f"Failed to initialize with model {model_to_try}: {str(e)}")
+                continue
+
+        # If all models failed
+        logger.error(f"Error initializing topic service with all models: {last_error}")
+        raise RuntimeError(f"Failed to initialize topic service with any model. Last error: {str(last_error)}")
 
     @traceable(name="extract_topic")
     def extract_topic(self, text: str) -> str:
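For context, a usage sketch of the new initialization path, assuming the service is constructed directly; the sample text and variable names are illustrative, not from the commit:

    # Sketch only: exercises the fallback chain introduced above.
    from services.topic_service import TopicService

    service = TopicService()
    try:
        # Tries GROQ_TOPIC_MODEL first, then each fallback model in order
        service.initialize()
        # model_name now reflects whichever candidate actually initialized
        topic = service.extract_topic("We discussed gradient descent and learning-rate schedules.")
        print(service.model_name, topic)
    except RuntimeError as err:
        # Raised only after every candidate model has failed to initialize
        print(f"Topic service unavailable: {err}")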