ruslanmv committed
Commit 53fae25 · 1 Parent(s): 5446ae0

Cache HF fix

app/core/inference/client.py CHANGED
@@ -1,6 +1,6 @@
 # app/core/inference/client.py
 import os, json, time, logging
-from typing import Dict, List, Optional, Iterator, Tuple, Any
+from typing import Dict, List, Optional, Iterator, Tuple
 
 import requests
 
@@ -34,15 +34,8 @@ class RouterRequestsClient:
     Simple requests-only client for HF Router Chat Completions.
     Supports non-streaming (returns str) and streaming (yields token strings).
     """
-    def __init__(
-        self,
-        model: str,
-        fallback: Optional[str] = None,
-        provider: Optional[str] = None,
-        max_retries: int = 2,
-        connect_timeout: float = 10.0,
-        read_timeout: float = 60.0,
-    ):
+    def __init__(self, model: str, fallback: Optional[str] = None, provider: Optional[str] = None,
+                 max_retries: int = 2, connect_timeout: float = 10.0, read_timeout: float = 60.0):
         self.model = model
         self.fallback = fallback if fallback != model else None
         self.provider = provider
@@ -58,19 +51,22 @@ class RouterRequestsClient:
         max_tokens: int,
         temperature: float,
         stop: Optional[List[str]] = None,
-        extra: Optional[Dict[str, Any]] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
     ) -> str:
-        payload: Dict[str, Any] = {
+        payload = {
             "model": _model_with_provider(self.model, self.provider),
             "messages": _mk_messages(system_prompt, user_text),
-            "temperature": float(temperature),
+            "temperature": float(max(0.0, temperature)),
             "max_tokens": int(max_tokens),
             "stream": False,
         }
         if stop:
             payload["stop"] = stop
-        if extra:
-            payload.update(extra)
+        if frequency_penalty is not None:
+            payload["frequency_penalty"] = float(frequency_penalty)
+        if presence_penalty is not None:
+            payload["presence_penalty"] = float(presence_penalty)
 
         text, ok = self._try_once(payload)
         if ok:
@@ -113,19 +109,22 @@ class RouterRequestsClient:
         max_tokens: int,
         temperature: float,
         stop: Optional[List[str]] = None,
-        extra: Optional[Dict[str, Any]] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
     ) -> Iterator[str]:
-        payload: Dict[str, Any] = {
+        payload = {
             "model": _model_with_provider(self.model, self.provider),
             "messages": _mk_messages(system_prompt, user_text),
-            "temperature": float(temperature),
+            "temperature": float(max(0.0, temperature)),
             "max_tokens": int(max_tokens),
             "stream": True,
        }
         if stop:
             payload["stop"] = stop
-        if extra:
-            payload.update(extra)
+        if frequency_penalty is not None:
+            payload["frequency_penalty"] = float(frequency_penalty)
+        if presence_penalty is not None:
+            payload["presence_penalty"] = float(presence_penalty)
 
         # primary
         ok = False
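
For orientation, a minimal sketch of how the updated signature is exercised. The model id is illustrative, and the non-streaming method name (`chat`) is an assumption based on the `chat_stream` counterpart in this file; what the diff does establish is that `frequency_penalty` and `presence_penalty` are now explicit keyword arguments replacing the old catch-all `extra` dict:

    from app.core.inference.client import RouterRequestsClient

    # Model id and method name are illustrative assumptions, not from this diff.
    client = RouterRequestsClient(model="meta-llama/Llama-3.1-8B-Instruct")
    text = client.chat(
        "You are a concise assistant.",   # system_prompt
        "What is MatrixHub?",             # user_text
        max_tokens=256,
        temperature=0.2,
        stop=["\nQuestion:"],
        frequency_penalty=0.2,   # explicit knob (was extra={"frequency_penalty": 0.2})
        presence_penalty=0.0,
    )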
app/core/rag/retriever.py CHANGED
@@ -1,6 +1,6 @@
 # app/core/rag/retriever.py
 from __future__ import annotations
-import json, logging
+import json, logging, os
 from pathlib import Path
 from typing import List, Dict, Optional
 import numpy as np
@@ -17,17 +17,24 @@ class Retriever:
         self.top_k = top_k
         if not self.kb_path.exists():
             raise FileNotFoundError(f"KB file not found: {self.kb_path} (jsonl with {{text,source}})")
-        self.model = SentenceTransformer(model_name)
+
+        # Use a project-local cache to avoid '/.cache' permission issues
+        cache_dir = Path(os.getenv("HF_HOME", "./.cache"))
+        cache_dir.mkdir(parents=True, exist_ok=True)
+
+        self.model = SentenceTransformer(model_name, cache_folder=str(cache_dir))
+
         self.docs: List[Dict[str, str]] = []
         with self.kb_path.open("r", encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
-                if not line: continue
+                if not line:
+                    continue
                 self.docs.append(json.loads(line))
         texts = [d["text"] for d in self.docs]
         emb = self.model.encode(texts, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False)
         self.dim = int(emb.shape[1])
-        self.index = faiss.IndexFlatIP(self.dim)  # cosine via normalized vectors = dot product
+        self.index = faiss.IndexFlatIP(self.dim)
         self.index.add(emb.astype("float32"))
 
     def retrieve(self, query: str, k: Optional[int] = None) -> List[Dict]:
@@ -36,7 +43,8 @@ class Retriever:
         D, I = self.index.search(vec.astype("float32"), k)
         out: List[Dict] = []
         for idx, score in zip(I[0], D[0]):
-            if int(idx) < 0: continue
+            if int(idx) < 0:
+                continue
             d = self.docs[int(idx)]
             out.append({"text": d["text"], "source": d.get("source", f"kb:{idx}"), "score": float(score)})
         return out
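
This hunk is the substance of the commit: `sentence_transformers` resolves its download cache through `huggingface_hub`, which defaults to `~/.cache/huggingface`, and that path is not writable when the process has no usable home directory (a common failure mode in containers such as Spaces). A minimal sketch of the same idea applied process-wide, with an assumed `/app/.cache` path that must be set before any Hugging Face import:

    import os
    os.environ.setdefault("HF_HOME", "/app/.cache")  # writable dir; path is an assumption

    from sentence_transformers import SentenceTransformer  # honors HF_HOME via huggingface_hub

    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

Passing `cache_folder=...` explicitly, as the diff does, achieves the same effect scoped to a single model load.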
app/services/chat_service.py CHANGED
@@ -14,27 +14,30 @@ from ..core.rag.retriever import Retriever
 
 logger = logging.getLogger(__name__)
 
-# --- Optional cross-encoder reranker (graceful fallback) ---
 try:
-    from sentence_transformers import CrossEncoder  # type: ignore
-except Exception:  # pragma: no cover
+    from sentence_transformers import CrossEncoder  # optional
+except Exception:
     CrossEncoder = None  # type: ignore
 
+# Tighter, grounding-first instruction + anti-question/label rules
 SYSTEM_PROMPT = (
     "You are MATRIX-AI, a concise assistant for the Matrix EcoSystem.\n"
-    "Answer the user's question directly in 2–4 short sentences.\n"
-    "Do NOT restate the question. Do NOT use labels like 'Question:' or 'Answer:'.\n"
-    "Use the provided CONTEXT if present; if the answer is not supported by it, say you don't know.\n"
-    "Do not ask follow-up questions unless the user explicitly asks you to."
+    "Use the provided CONTEXT strictly when present. If the answer is not supported by the context, say you don't know.\n"
+    "Reply in 2–4 short sentences. Do NOT include labels like 'Question:' or 'Answer:' in your output.\n"
+    "Do NOT ask me questions unless I explicitly asked you to. Do NOT repeat yourself.\n"
 )
 
+# Hard stops if the model tries to start a new question/role header
+STOP_SEQS: List[str] = [
+    "\nQuestion:", "Question:", "\nQ:", "Q:",
+    "\nUser:", "User:", "\nAssistant:", "Assistant:"
+]
+
 # Thread-safe singleton retriever
 _retriever_instance: Optional[Retriever] = None
 _retriever_lock = threading.Lock()
 
-
 def get_retriever(settings: Settings) -> Optional[Retriever]:
-    """Initialize and return a single Retriever instance (double-checked locking)."""
     global _retriever_instance
     if _retriever_instance is not None:
         return _retriever_instance
@@ -55,19 +58,15 @@ def get_retriever(settings: Settings) -> Optional[Retriever]:
         _retriever_instance = None
     return _retriever_instance
 
-
-# ----------------------------
-# Anti-repetition + de-label helpers
-# ----------------------------
+# ---------- anti-repetition / anti-label helpers ----------
 _SENT_SPLIT = re.compile(r'(?<=[\.\!\?])\s+')
 _NORM = re.compile(r'[^a-z0-9\s]+')
-_QA_LINE_RE = re.compile(r'^\s*(question|q|user)\s*:\s*', re.I)
-_ANSWER_PREFIX_RE = re.compile(r'^\s*(answer|a)\s*:\s*', re.I)
 
 def _norm_sentence(s: str) -> str:
     s = s.lower().strip()
     s = _NORM.sub(' ', s)
-    return re.sub(r'\s+', ' ', s)
+    s = re.sub(r'\s+', ' ', s)
+    return s
 
 def _jaccard(a: str, b: str) -> float:
     ta = set(a.split())
@@ -76,32 +75,7 @@ def _jaccard(a: str, b: str) -> float:
         return 0.0
     return len(ta & tb) / max(1, len(ta | tb))
 
-def _strip_qa_meta(text: str) -> str:
-    """Drop lines like 'Question: ...' and leading 'Answer:' labels."""
-    lines = text.splitlines()
-    out: List[str] = []
-    for i, l in enumerate(lines):
-        if i == 0:
-            l = _ANSWER_PREFIX_RE.sub('', l).strip()
-        if _QA_LINE_RE.match(l):
-            continue
-        out.append(l)
-    return "\n".join(out).strip()
-
-def _remove_query_echo(text: str, query: str, sim_threshold: float = 0.9) -> str:
-    """Remove sentences that are near-duplicates of the original query."""
-    qn = _norm_sentence(query)
-    parts = _SENT_SPLIT.split(re.sub(r'\s+', ' ', text).strip()) or [text]
-    kept: List[str] = []
-    for s in parts:
-        sn = _norm_sentence(s)
-        if _jaccard(qn, sn) >= sim_threshold:
-            continue
-        kept.append(s.strip())
-    return ' '.join(kept).strip()
-
 def _squash_repetition(text: str, max_sentences: int = 4, sim_threshold: float = 0.88) -> str:
-    """Remove near-duplicate sentences while keeping order and cap total sentences."""
     t = re.sub(r'\s+', ' ', text).strip()
     if not t:
         return t
@@ -120,16 +94,19 @@ def _squash_repetition(text: str, max_sentences: int = 4, sim_threshold: float =
             break
     return ' '.join(out).strip()
 
-def _clean_answer(text: str, query: str) -> str:
-    t = _strip_qa_meta(text)
-    t = _remove_query_echo(t, query)
-    t = _squash_repetition(t, max_sentences=4, sim_threshold=0.88)
-    return t
+# Strip common label patterns
+_LABEL_PREFIX = re.compile(r'^\s*(?:Answer:|A:)\s*', re.IGNORECASE)
+_LABEL_INLINE_Q = re.compile(r'\s*(?:Question:|Q:)\s*$', re.IGNORECASE)
 
-# ----------------------------
-# RAG helpers (query expansion, ranking, snippets)
-# ----------------------------
+def _strip_labels(text: str) -> str:
+    s = _LABEL_PREFIX.sub('', text)
+    # If the model tries to end with "Question:" remove that tail prompt
+    s = _LABEL_INLINE_Q.sub('', s)
+    # also remove mid-text accidental "Answer:" fragments
+    s = re.sub(r'\b(?:Answer:|A:)\s*', '', s, flags=re.IGNORECASE)
+    return s.strip()
+
+# ---------- RAG utilities (ranking & snippets) ----------
 _ALIAS_TABLE: Dict[str, List[str]] = {
     "matrixhub": ["matrix hub", "hub api", "catalog", "registry", "cas"],
     "mcp": ["model context protocol", "manifest", "server manifest", "admin api"],
@@ -184,9 +161,7 @@ def _best_paragraphs(text: str, query: str, max_chars: int = 700) -> str:
             break
     return "\n".join(picked)
 
-def _cross_encoder_scores(
-    model: Optional["CrossEncoder"], query: str, docs: List[Dict], max_pairs: int = 50
-) -> Optional[List[float]]:
+def _cross_encoder_scores(model: Optional["CrossEncoder"], query: str, docs: List[Dict], max_pairs: int = 50) -> Optional[List[float]]:
     if not model:
         return None
     try:
@@ -196,9 +171,7 @@ def _cross_encoder_scores(
         logger.warning("Cross-encoder scoring failed; continuing without it (%s)", e)
         return None
 
-def _rerank_docs(
-    docs: List[Dict], query: str, k_final: int, reranker: Optional["CrossEncoder"] = None
-) -> List[Dict]:
+def _rerank_docs(docs: List[Dict], query: str, k_final: int, reranker: Optional["CrossEncoder"] = None) -> List[Dict]:
     if not docs:
         return []
     vec_scores = [float(d.get("score", 0.0)) for d in docs]
@@ -226,6 +199,7 @@ def _rerank_docs(
         if ce_norm is not None:
             score = 0.80 * score + 0.20 * ce_norm[i]
         merged.append((score, d))
+
     merged.sort(key=lambda x: x[0], reverse=True)
     return [d for _s, d in merged[:k_final]]
 
@@ -242,7 +216,6 @@ def _build_context_from_docs(docs: List[Dict], query: str, max_blocks: int = 4)
     prelude = "CONTEXT (use only these facts; if missing, say you don't know):"
     return prelude + "\n\n" + "\n\n".join(blocks), sources
 
-
 # ----------------------------
 # Service
 # ----------------------------
@@ -268,10 +241,6 @@ class ChatService:
         except Exception as e:
             logger.warning("Reranker disabled: %s", e)
 
-        # default inference knobs to reduce repetition
-        self._stop = ["\nQuestion:", "\nUser:", "\nQ:", "\nAnswer:", "\nA:"]
-        self._extra = {"frequency_penalty": 0.2, "presence_penalty": 0.0}
-
     # ---------- RAG core ----------
     def _retrieve_best(self, query: str) -> Tuple[str, List[str]]:
         if not self.retriever:
@@ -292,17 +261,13 @@ class ChatService:
     def _augment(self, query: str) -> Tuple[str, List[str]]:
         ctx, sources = self._retrieve_best(query)
         if ctx:
-            # No Q:/A: labels — just a clear directive + the raw question
             user_msg = (
                 f"{ctx}\n\n"
-                "Using only the context above, respond concisely (2–4 sentences) to this query.\n"
+                "Based only on the context above, answer succinctly in 2–4 sentences.\n"
                 f"{query}"
             )
         else:
-            user_msg = (
-                "Respond concisely (2–4 sentences). Do not restate the question or add labels.\n"
-                f"{query}"
-            )
+            user_msg = f"Answer succinctly in 2–4 sentences. Do not repeat yourself.\n{query}"
         return user_msg, sources
 
     # ---------- Non-stream ----------
@@ -313,36 +278,35 @@ class ChatService:
             user_msg,
             max_tokens=self.settings.model.max_new_tokens,
             temperature=self.settings.model.temperature,
-            stop=self._stop,
-            extra=self._extra,
+            stop=STOP_SEQS,
+            frequency_penalty=0.2,  # mild anti-repeat
+            presence_penalty=0.0,
         )
-        text = _clean_answer(text, query)
+        text = _strip_labels(_squash_repetition(text, max_sentences=4, sim_threshold=0.88))
         return text, sources
 
     # ---------- Stream ----------
     def stream_answer(self, query: str):
-        """
-        Stream while cleaning: suppress Q/A labels and near-duplicate lines as they appear.
-        """
         user_msg, _ = self._augment(query)
         raw = self.client.chat_stream(
             SYSTEM_PROMPT,
             user_msg,
             max_tokens=self.settings.model.max_new_tokens,
             temperature=self.settings.model.temperature,
-            stop=self._stop,
-            extra=self._extra,
+            stop=STOP_SEQS,
+            frequency_penalty=0.2,
+            presence_penalty=0.0,
        )
 
-        buf = ""      # collected raw
-        emitted = ""  # cleaned we already sent
+        buf = ""
+        emitted = ""
         for token in raw:
             if not token:
                 continue
             buf += token
-            cleaned = _clean_answer(buf, query)
+            cleaned = _squash_repetition(buf, max_sentences=4, sim_threshold=0.88)
+            cleaned = _strip_labels(cleaned)
             if len(cleaned) < len(emitted):
-                # parser got stricter; resync
                 emitted = cleaned
                 continue
             delta = cleaned[len(emitted):]
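
The streaming loop above re-cleans the entire buffer on every token and emits only the suffix beyond what was already sent; when cleaning shrinks the text (for example, a trailing "Question:" gets stripped), it resyncs `emitted` rather than emitting. A standalone sketch of that emit-the-delta pattern, with a toy cleaner standing in for `_squash_repetition` plus `_strip_labels`:

    import re

    def stream_cleaned(tokens, clean):
        """Yield only the newly visible cleaned text as raw tokens arrive."""
        buf, emitted = "", ""
        for token in tokens:
            buf += token
            cleaned = clean(buf)
            if len(cleaned) < len(emitted):
                emitted = cleaned  # cleaner got stricter; resync and emit nothing
                continue
            delta = cleaned[len(emitted):]
            if delta:
                emitted = cleaned
                yield delta

    clean = lambda s: re.sub(r'\s*(?:Question:|Q:)\s*$', '', s)  # toy label stripper
    print("".join(stream_cleaned(["Hello.", " Question:"], clean)))  # -> "Hello."

One caveat the sketch makes visible: text already yielded cannot be recalled, so a label split across many small tokens can still leak its prefix downstream.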