Jan Biermeyer committed
Commit 34fc1eb · 1 Parent(s): a905164

cpu optimization

rag/model_loader.py CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 """
- SUPRA Enhanced Model Loader for M2 Max
- Optimized model loading with MPS acceleration and Streamlit caching
 """

 import torch
@@ -28,41 +28,43 @@ except ImportError:
 logger.warning("⚠️ PEFT not available. LoRA adapter loading will be disabled.")

 def setup_m2_max_optimizations():
- """Configure optimizations for M2 Max."""
- logger.info("🍎 Setting up M2 Max optimizations for model loading...")

- # M2 Max specific environment variables
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

- # Disable bitsandbytes for M2 Max (not needed with MPS)
- os.environ["DISABLE_BITSANDBYTES"] = "1"
-
 # Set up Hugging Face token from HUGGINGFACE_TOKEN
 if os.environ.get("HUGGINGFACE_TOKEN") and not os.environ.get("HF_TOKEN"):
 os.environ["HF_TOKEN"] = os.environ["HUGGINGFACE_TOKEN"]
 logger.info("🔑 Using HUGGINGFACE_TOKEN for Hugging Face authentication")

- # Memory management
 if torch.backends.mps.is_available():
- logger.info("✅ MPS (Metal Performance Shaders) available")
 device = "mps"
 else:
- logger.info("⚠️ MPS not available, using CPU")
 device = "cpu"
-
- # Optimize PyTorch for M2 Max
- torch.backends.mps.is_built()

 logger.info(f"🔧 Using device: {device}")
 return device

 @st.cache_resource
 def load_enhanced_model_m2max() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
- """Load the enhanced SUPRA model optimized for M2 Max with caching."""
- logger.info("📥 Loading enhanced SUPRA model for M2 Max...")

- # Setup M2 Max optimizations
 device = setup_m2_max_optimizations()

 # Model paths - try local lora/ folder first (for deployment), then outputs directory
@@ -111,23 +113,23 @@ def load_enhanced_model_m2max() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
 base_model_name = adapter_config.get("base_model_name_or_path")
 logger.info(f"📖 Base model from adapter config: {base_model_name}")

- # Use non-quantized version for M2 Max (MPS), quantized for CUDA
- # Check if we're on MPS (M2 Max) or CUDA
 is_mps = torch.backends.mps.is_available()

 if base_model_name and "llama" in base_model_name.lower():
 if is_mps:
- # M2 Max: Use non-quantized model (no bitsandbytes needed)
 base_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 else:
- # CUDA: Use quantized Unsloth version
 base_model_name = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
 elif base_model_name and "mistral" in base_model_name.lower():
 if is_mps:
- # M2 Max: Use non-quantized model
 base_model_name = "mistralai/Mistral-7B-Instruct-v0.3"
 else:
- # CUDA: Use quantized Unsloth version
 base_model_name = "unsloth/Mistral-7B-Instruct-v0.3-bnb-4bit"
 except Exception as e:
 logger.warning(f"⚠️ Could not read adapter config: {e}")
@@ -137,6 +139,7 @@ def load_enhanced_model_m2max() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
 if is_mps:
 base_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 else:
 base_model_name = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"

 # Fallback to old checkpoint structure
@@ -163,9 +166,9 @@ def load_enhanced_model_m2max() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
 if base_model_name is None:
 is_mps = torch.backends.mps.is_available()
 if is_mps:
- base_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct" # M2 Max: non-quantized
 else:
- base_model_name = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" # CUDA: quantized

 if use_local:
 logger.info(f"📚 Loading base model: {base_model_name}")
@@ -196,21 +199,72 @@ def load_enhanced_model_m2max() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:

 logger.info("✅ Tokenizer loaded successfully")

- # Load base model with M2 Max optimizations
- logger.info("🤖 Loading base model with M2 Max optimizations...")
 # Use /workspace/.cache if WORKSPACE is set, otherwise use .cache relative to current dir
 cache_dir = os.getenv("HF_HOME") or os.getenv("TRANSFORMERS_CACHE") or "/workspace/.cache/huggingface" if os.getenv("WORKSPACE") else ".cache/huggingface"
 offload_dir = os.getenv("WORKSPACE", "") + "/.cache/offload" if os.getenv("WORKSPACE") else ".cache/offload"
 base_model = AutoModelForCausalLM.from_pretrained(
 base_model_name,
- cache_dir=cache_dir,
- torch_dtype=torch.float16, # Use float16 for memory efficiency
- device_map="auto", # Let transformers handle device placement
- offload_folder=offload_dir, # Allow CPU offload when needed
- trust_remote_code=True,
- low_cpu_mem_usage=True, # Optimize for M2 Max memory
- load_in_8bit=False, # Disable 8-bit quantization (not needed for M2 Max)
- load_in_4bit=False # Disable 4-bit quantization (not needed for M2 Max)
 )

 logger.info("✅ Base model loaded successfully")
@@ -249,21 +303,64 @@ def load_enhanced_model_m2max() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:

 logger.info("✅ Tokenizer loaded successfully")

- # Load base model (no LoRA adapter)
- logger.info("🤖 Loading base model with M2 Max optimizations (no fine-tuning)...")
 # Use /workspace/.cache if WORKSPACE is set, otherwise use .cache relative to current dir
 cache_dir = os.getenv("HF_HOME") or os.getenv("TRANSFORMERS_CACHE") or "/workspace/.cache/huggingface" if os.getenv("WORKSPACE") else ".cache/huggingface"
 offload_dir = os.getenv("WORKSPACE", "") + "/.cache/offload" if os.getenv("WORKSPACE") else ".cache/offload"
 model = AutoModelForCausalLM.from_pretrained(
 base_model_name,
- cache_dir=cache_dir,
- torch_dtype=torch.float16,
- device_map="auto",
- offload_folder=offload_dir,
- trust_remote_code=True,
- low_cpu_mem_usage=True,
- load_in_8bit=False,
- load_in_4bit=False
 )

 logger.info("✅ Base model loaded successfully (no fine-tuning)")
@@ -287,17 +384,40 @@ def load_enhanced_model_m2max() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
 if tokenizer.pad_token is None:
 tokenizer.pad_token = tokenizer.eos_token

- # Load model
 model = AutoModelForCausalLM.from_pretrained(
 base_model_name,
- cache_dir=cache_dir,
- torch_dtype=torch.float16,
- device_map="auto",
- offload_folder=offload_dir,
- trust_remote_code=True,
- low_cpu_mem_usage=True,
- load_in_8bit=False, # Disable 8-bit quantization (not needed for M2 Max)
- load_in_4bit=False # Disable 4-bit quantization (not needed for M2 Max)
 )

 logger.info("✅ Model loaded from Hugging Face successfully")
@@ -338,6 +458,7 @@ def get_model_info() -> dict:

 # Determine base model based on device
 is_mps = torch.backends.mps.is_available()
 if tiny_models and tiny_models[0].exists() or small_models and small_models[0].exists() or prod_models and prod_models[0].exists():
 base_model = "meta-llama/Meta-Llama-3.1-8B-Instruct" if is_mps else "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
 else:
@@ -366,7 +487,7 @@ def generate_response_optimized(
 temperature: float = 0.7, # Adjusted for better quality
 top_p: float = 0.9
 ) -> str:
- """Generate response with M2 Max optimizations and full-sentence stopping."""
 try:
 # Import inference utilities
 from .inference_utils import create_stopping_criteria, ensure_supra_close
@@ -411,18 +532,40 @@ def generate_response_optimized(
 padding=False
 )

- # Move to same device as model
- device = next(model.parameters()).device
- inputs = {k: v.to(device) for k, v in inputs.items()}

 # Create stopping criteria for full-sentence stopping
 stopping_criteria = create_stopping_criteria(tokenizer)

 # Generate response with full-sentence stopping
 with torch.no_grad():
 outputs = model.generate(
 **inputs,
- max_new_tokens=max_new_tokens,
 temperature=temperature,
 top_p=top_p,
 do_sample=True,
 
 #!/usr/bin/env python3
 """
+ SUPRA Enhanced Model Loader
+ Optimized model loading with CPU/MPS/CUDA support and Streamlit caching
 """

 import torch

 logger.warning("⚠️ PEFT not available. LoRA adapter loading will be disabled.")

 def setup_m2_max_optimizations():
+ """Configure optimizations for CPU/MPS/CUDA."""
+ logger.info("🔧 Setting up device optimizations for model loading...")

+ # Environment variables
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

 # Set up Hugging Face token from HUGGINGFACE_TOKEN
 if os.environ.get("HUGGINGFACE_TOKEN") and not os.environ.get("HF_TOKEN"):
 os.environ["HF_TOKEN"] = os.environ["HUGGINGFACE_TOKEN"]
 logger.info("🔑 Using HUGGINGFACE_TOKEN for Hugging Face authentication")

+ # Detect device: MPS > CUDA > CPU
 if torch.backends.mps.is_available():
+ logger.info("✅ MPS (Metal Performance Shaders) available - using MPS")
 device = "mps"
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+ os.environ["DISABLE_BITSANDBYTES"] = "1" # Disable for MPS
+ torch.backends.mps.is_built()
+ elif torch.cuda.is_available():
+ logger.info("✅ CUDA available - using GPU")
+ device = "cuda"
+ os.environ.pop("DISABLE_BITSANDBYTES", None) # Enable bitsandbytes for CUDA
 else:
+ logger.info("💻 CPU detected - enabling CPU optimizations")
 device = "cpu"
+ os.environ.pop("DISABLE_BITSANDBYTES", None) # Enable bitsandbytes for CPU
+ os.environ.pop("PYTORCH_ENABLE_MPS_FALLBACK", None)

 logger.info(f"🔧 Using device: {device}")
 return device

 @st.cache_resource
 def load_enhanced_model_m2max() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
+ """Load the enhanced SUPRA model with device-specific optimizations (CPU/MPS/CUDA) with caching."""
+ logger.info("📥 Loading enhanced SUPRA model with device optimizations...")

+ # Setup device optimizations
 device = setup_m2_max_optimizations()

 # Model paths - try local lora/ folder first (for deployment), then outputs directory

 base_model_name = adapter_config.get("base_model_name_or_path")
 logger.info(f"📖 Base model from adapter config: {base_model_name}")

+ # Select model version based on device: non-quantized for MPS, quantized for CPU/CUDA
 is_mps = torch.backends.mps.is_available()
+ is_cpu = not is_mps and not torch.cuda.is_available()

 if base_model_name and "llama" in base_model_name.lower():
 if is_mps:
+ # MPS: Use non-quantized model (no bitsandbytes needed)
 base_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 else:
+ # CPU/CUDA: Use quantized Unsloth version
 base_model_name = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
 elif base_model_name and "mistral" in base_model_name.lower():
 if is_mps:
+ # MPS: Use non-quantized model
 base_model_name = "mistralai/Mistral-7B-Instruct-v0.3"
 else:
+ # CPU/CUDA: Use quantized Unsloth version
 base_model_name = "unsloth/Mistral-7B-Instruct-v0.3-bnb-4bit"
 except Exception as e:
 logger.warning(f"⚠️ Could not read adapter config: {e}")

 if is_mps:
 base_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 else:
+ # CPU/CUDA: Use quantized version
 base_model_name = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"

 # Fallback to old checkpoint structure

 if base_model_name is None:
 is_mps = torch.backends.mps.is_available()
 if is_mps:
+ base_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct" # MPS: non-quantized
 else:
+ base_model_name = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" # CPU/CUDA: quantized

 if use_local:
 logger.info(f"📚 Loading base model: {base_model_name}")

 logger.info("✅ Tokenizer loaded successfully")

+ # Load base model with device-specific optimizations
+ logger.info("🤖 Loading base model with device optimizations...")
 # Use /workspace/.cache if WORKSPACE is set, otherwise use .cache relative to current dir
 cache_dir = os.getenv("HF_HOME") or os.getenv("TRANSFORMERS_CACHE") or "/workspace/.cache/huggingface" if os.getenv("WORKSPACE") else ".cache/huggingface"
 offload_dir = os.getenv("WORKSPACE", "") + "/.cache/offload" if os.getenv("WORKSPACE") else ".cache/offload"
+
+ # Detect device type for optimization
+ is_cpu = device == "cpu"
+ is_mps = device == "mps"
+ is_cuda = device == "cuda"
+
+ # Configure quantization for CPU
+ quantization_config = None
+ if is_cpu:
+ try:
+ from transformers import BitsAndBytesConfig
+ quantization_config = BitsAndBytesConfig(
+ load_in_8bit=True,
+ llm_int8_enable_fp32_cpu_offload=True
+ )
+ logger.info("💻 Using 8-bit quantization for CPU")
+ except ImportError:
+ logger.warning("⚠️ bitsandbytes not available, loading without quantization")
+
+ # Set dtype and quantization settings based on device
+ if is_cpu:
+ torch_dtype = torch.float32 # CPU: use float32
+ # If quantization_config is provided, don't also pass load_in_8bit
+ load_in_8bit = False if quantization_config else False
+ load_in_4bit = False
+ elif is_mps:
+ torch_dtype = torch.float16 # MPS: use float16
+ load_in_8bit = False
+ load_in_4bit = False
+ else: # CUDA
+ torch_dtype = torch.float16 # CUDA: use float16
+ load_in_8bit = False # CUDA can use 4-bit if needed
+ load_in_4bit = False
+
+ # Build model loading kwargs
+ model_kwargs = {
+ "cache_dir": cache_dir,
+ "torch_dtype": torch_dtype,
+ "trust_remote_code": True,
+ "low_cpu_mem_usage": True,
+ }
+
+ # Add device-specific settings
+ if is_cpu:
+ if quantization_config:
+ model_kwargs["quantization_config"] = quantization_config
+ # For CPU, don't use device_map (model stays on CPU)
+ model_kwargs["offload_folder"] = offload_dir
+ else:
+ model_kwargs["device_map"] = "auto"
+ if not is_mps: # For CUDA, we can add offload if needed
+ model_kwargs["offload_folder"] = offload_dir
+
+ # Add quantization flags only if quantization_config is None
+ if not quantization_config:
+ model_kwargs["load_in_8bit"] = load_in_8bit
+ model_kwargs["load_in_4bit"] = load_in_4bit
+
 base_model = AutoModelForCausalLM.from_pretrained(
 base_model_name,
+ **model_kwargs
 )

 logger.info("✅ Base model loaded successfully")

 logger.info("✅ Tokenizer loaded successfully")

+ # Load base model (no LoRA adapter) with device-specific optimizations
+ logger.info("🤖 Loading base model with device optimizations (no fine-tuning)...")
 # Use /workspace/.cache if WORKSPACE is set, otherwise use .cache relative to current dir
 cache_dir = os.getenv("HF_HOME") or os.getenv("TRANSFORMERS_CACHE") or "/workspace/.cache/huggingface" if os.getenv("WORKSPACE") else ".cache/huggingface"
 offload_dir = os.getenv("WORKSPACE", "") + "/.cache/offload" if os.getenv("WORKSPACE") else ".cache/offload"
+
+ # Detect device type for optimization
+ is_cpu = device == "cpu"
+ is_mps = device == "mps"
+
+ # Configure quantization for CPU
+ quantization_config = None
+ if is_cpu:
+ try:
+ from transformers import BitsAndBytesConfig
+ quantization_config = BitsAndBytesConfig(
+ load_in_8bit=True,
+ llm_int8_enable_fp32_cpu_offload=True
+ )
+ logger.info("💻 Using 8-bit quantization for CPU")
+ except ImportError:
+ logger.warning("⚠️ bitsandbytes not available, loading without quantization")
+
+ # Set dtype and quantization settings based on device
+ if is_cpu:
+ torch_dtype = torch.float32
+ load_in_8bit = False if quantization_config else False
+ load_in_4bit = False
+ else:
+ torch_dtype = torch.float16
+ load_in_8bit = False
+ load_in_4bit = False
+
+ # Build model loading kwargs
+ model_kwargs = {
+ "cache_dir": cache_dir,
+ "torch_dtype": torch_dtype,
+ "trust_remote_code": True,
+ "low_cpu_mem_usage": True,
+ }
+
+ # Add device-specific settings
+ if is_cpu:
+ if quantization_config:
+ model_kwargs["quantization_config"] = quantization_config
+ model_kwargs["offload_folder"] = offload_dir
+ else:
+ model_kwargs["device_map"] = "auto"
+ model_kwargs["offload_folder"] = offload_dir
+
+ # Add quantization flags only if quantization_config is None
+ if not quantization_config:
+ model_kwargs["load_in_8bit"] = load_in_8bit
+ model_kwargs["load_in_4bit"] = load_in_4bit
+
 model = AutoModelForCausalLM.from_pretrained(
 base_model_name,
+ **model_kwargs
 )

 logger.info("✅ Base model loaded successfully (no fine-tuning)")

 if tokenizer.pad_token is None:
 tokenizer.pad_token = tokenizer.eos_token

+ # Load model with device-specific optimizations (fallback code - usually not used)
+ is_cpu = device == "cpu"
+ quantization_config = None
+ if is_cpu:
+ try:
+ from transformers import BitsAndBytesConfig
+ quantization_config = BitsAndBytesConfig(
+ load_in_8bit=True,
+ llm_int8_enable_fp32_cpu_offload=True
+ )
+ except ImportError:
+ pass
+
+ # Build model loading kwargs
+ model_kwargs = {
+ "cache_dir": cache_dir,
+ "torch_dtype": torch.float32 if is_cpu else torch.float16,
+ "trust_remote_code": True,
+ "low_cpu_mem_usage": True,
+ }
+
+ if is_cpu:
+ if quantization_config:
+ model_kwargs["quantization_config"] = quantization_config
+ model_kwargs["offload_folder"] = offload_dir
+ else:
+ model_kwargs["device_map"] = "auto"
+ model_kwargs["offload_folder"] = offload_dir
+ model_kwargs["load_in_8bit"] = False
+ model_kwargs["load_in_4bit"] = False
+
 model = AutoModelForCausalLM.from_pretrained(
 base_model_name,
+ **model_kwargs
 )

 logger.info("✅ Model loaded from Hugging Face successfully")

 # Determine base model based on device
 is_mps = torch.backends.mps.is_available()
+ is_cpu = not is_mps and not torch.cuda.is_available()
 if tiny_models and tiny_models[0].exists() or small_models and small_models[0].exists() or prod_models and prod_models[0].exists():
 base_model = "meta-llama/Meta-Llama-3.1-8B-Instruct" if is_mps else "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
 else:

 temperature: float = 0.7, # Adjusted for better quality
 top_p: float = 0.9
 ) -> str:
+ """Generate response with device-specific optimizations and full-sentence stopping."""
 try:
 # Import inference utilities
 from .inference_utils import create_stopping_criteria, ensure_supra_close

 padding=False
 )

+ # Move to same device as model (handle quantized models on CPU)
+ try:
+ device = next(model.parameters()).device
+ inputs = {k: v.to(device) for k, v in inputs.items()}
+ except (StopIteration, AttributeError):
+ # Quantized models on CPU might not have .device on parameters
+ # Check if model has a device attribute or default to CPU
+ if hasattr(model, 'device'):
+ device = model.device
+ else:
+ device = torch.device('cpu')
+ inputs = {k: v.to(device) for k, v in inputs.items()}

 # Create stopping criteria for full-sentence stopping
 stopping_criteria = create_stopping_criteria(tokenizer)

+ # Reduce max_new_tokens for CPU to optimize performance
+ try:
+ model_device = next(model.parameters()).device if hasattr(model, 'parameters') else None
+ is_cpu_device = model_device is None or str(model_device) == 'cpu'
+ except (StopIteration, AttributeError):
+ is_cpu_device = True
+
+ # Adjust max_new_tokens for CPU (reduce for faster inference)
+ effective_max_tokens = max_new_tokens
+ if is_cpu_device and max_new_tokens > 512:
+ effective_max_tokens = 512
+ logger.info(f"💻 CPU detected: reducing max_new_tokens from {max_new_tokens} to {effective_max_tokens} for faster inference")
+
 # Generate response with full-sentence stopping
 with torch.no_grad():
 outputs = model.generate(
 **inputs,
+ max_new_tokens=effective_max_tokens,
 temperature=temperature,
 top_p=top_p,
 do_sample=True,
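
For context, a minimal usage sketch of the new loader path (not part of the commit): load_enhanced_model_m2max() and generate_response_optimized() are the functions touched above, but the leading positional arguments of generate_response_optimized (model, tokenizer, prompt) are assumed here, since the diff only shows its trailing keyword parameters.

# Hypothetical wiring inside a Streamlit page; argument order is assumed.
import streamlit as st
from rag.model_loader import load_enhanced_model_m2max, generate_response_optimized

model, tokenizer = load_enhanced_model_m2max()   # cached by @st.cache_resource

prompt = st.text_input("Ask SUPRA")
if prompt:
    # On CPU, the generate path above caps max_new_tokens at 512 internally.
    answer = generate_response_optimized(model, tokenizer, prompt, max_new_tokens=1024)
    st.write(answer)
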
rag/{rag_m2max.py β†’ rag.py} RENAMED
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 """
- SUPRA RAG System with M2 Max Optimizations
- Optimized for Apple Silicon with efficient memory management
 """

 import json
@@ -18,7 +18,7 @@ import logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

- class SupraRAGM2Max:
 def __init__(self, rag_data_path: str = None):
 # Default RAG data path (for HF Spaces deployment)
 if rag_data_path is None:
@@ -37,17 +37,19 @@ class SupraRAGM2Max:
 rag_data_path = "data/processed/rag_seeds/rag_seeds.jsonl"
 self.rag_data_path = Path(rag_data_path)

- # M2 Max optimizations
- self._setup_m2_max_optimizations()

- # Initialize ChromaDB with M2 Max optimizations
 self.client = chromadb.Client()
 self.collection_name = "supra_knowledge"

- # Use efficient embedding model for M2 Max
 self.embedding_model = SentenceTransformer(
 'all-MiniLM-L6-v2',
- device='cpu' # Force CPU for M2 Max compatibility
 )

 # Initialize or load collection
@@ -71,29 +73,30 @@ class SupraRAGM2Max:
 self.collection = self.client.create_collection(self.collection_name)
 self._load_rag_documents()

- def _setup_m2_max_optimizations(self):
- """Configure optimizations for M2 Max."""
- logger.info("🍎 Setting up M2 Max optimizations...")

- # M2 Max specific environment variables
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

- # Memory management
 if torch.backends.mps.is_available():
- logger.info("✅ MPS (Metal Performance Shaders) available")
 self.device = "mps"
 else:
- logger.info("⚠️ MPS not available, using CPU")
 self.device = "cpu"

- # Optimize PyTorch for M2 Max
- torch.backends.mps.is_built()
-
 logger.info(f"🔧 Using device: {self.device}")

 def _load_rag_documents(self):
- """Load RAG documents from JSONL file with M2 Max optimizations."""
 if not self.rag_data_path.exists():
 logger.warning("⚠️ RAG data file not found")
 if st:
@@ -112,7 +115,7 @@ class SupraRAGM2Max:
 try:
 doc = json.loads(line)
 if 'content' in doc and 'id' in doc:
- # Truncate content for M2 Max memory efficiency
 content = doc['content']
 if len(content) > 2000: # Limit content length
 content = content[:2000] + "..."
@@ -131,8 +134,8 @@ class SupraRAGM2Max:
 logger.warning(f"⚠️ Skipping line {line_num}: JSON decode error - {e}")

 if documents:
- # Add to ChromaDB with batch processing for M2 Max
- batch_size = 50 # Smaller batches for M2 Max
 for i in range(0, len(documents), batch_size):
 batch_docs = documents[i:i+batch_size]
 batch_metadatas = metadatas[i:i+batch_size]
@@ -265,13 +268,13 @@ class SupraRAGM2Max:
 st.error(f"Error generating response: {e}")
 return f"I apologize, but I encountered an error while generating a response: {e}"

- # Global RAG instance with M2 Max optimizations
 @st.cache_resource
 def get_supra_rag_m2max():
- """Get cached SUPRA RAG instance optimized for M2 Max."""
 return SupraRAGM2Max()

 # Backward compatibility
 def get_supra_rag():
- """Backward compatible function that returns M2 Max optimized RAG."""
 return get_supra_rag_m2max()
 
 #!/usr/bin/env python3
 """
+ SUPRA RAG System with CPU/MPS/CUDA Optimizations
+ Optimized for CPU (HF Spaces), MPS (Apple Silicon), and CUDA with efficient memory management
 """

 import json

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+ class SupraRAG:
 def __init__(self, rag_data_path: str = None):
 # Default RAG data path (for HF Spaces deployment)
 if rag_data_path is None:

 rag_data_path = "data/processed/rag_seeds/rag_seeds.jsonl"
 self.rag_data_path = Path(rag_data_path)

+ # Device-specific optimizations
+ self._setup_device_optimizations()

+ # Initialize ChromaDB with device optimizations
 self.client = chromadb.Client()
 self.collection_name = "supra_knowledge"

+ # Use efficient embedding model (CPU for HF Spaces free tier)
+ # CPU is optimal for sentence-transformers on CPU-only deployments
+ embedding_device = 'cpu' if self.device == 'cpu' else self.device
 self.embedding_model = SentenceTransformer(
 'all-MiniLM-L6-v2',
+ device=embedding_device
 )

 # Initialize or load collection

 self.collection = self.client.create_collection(self.collection_name)
 self._load_rag_documents()

+ def _setup_device_optimizations(self):
+ """Configure optimizations for CPU/MPS/CUDA."""
+ logger.info("🔧 Setting up device optimizations...")

+ # Environment variables
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

+ # Detect device: MPS > CUDA > CPU
 if torch.backends.mps.is_available():
+ logger.info("✅ MPS (Metal Performance Shaders) available - using MPS")
 self.device = "mps"
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+ torch.backends.mps.is_built()
+ elif torch.cuda.is_available():
+ logger.info("✅ CUDA available - using GPU")
+ self.device = "cuda"
 else:
+ logger.info("💻 CPU detected - using CPU optimizations")
 self.device = "cpu"

 logger.info(f"🔧 Using device: {self.device}")

 def _load_rag_documents(self):
+ """Load RAG documents from JSONL file with device optimizations."""
 if not self.rag_data_path.exists():
 logger.warning("⚠️ RAG data file not found")
 if st:

 try:
 doc = json.loads(line)
 if 'content' in doc and 'id' in doc:
+ # Truncate content for memory efficiency
 content = doc['content']
 if len(content) > 2000: # Limit content length
 content = content[:2000] + "..."

 logger.warning(f"⚠️ Skipping line {line_num}: JSON decode error - {e}")

 if documents:
+ # Add to ChromaDB with batch processing
+ batch_size = 50 # Smaller batches for memory efficiency
 for i in range(0, len(documents), batch_size):
 batch_docs = documents[i:i+batch_size]
 batch_metadatas = metadatas[i:i+batch_size]

 st.error(f"Error generating response: {e}")
 return f"I apologize, but I encountered an error while generating a response: {e}"

+ # Global RAG instance with device-specific optimizations
 @st.cache_resource
 def get_supra_rag_m2max():
+ """Get cached SUPRA RAG instance optimized for CPU/MPS/CUDA."""
 return SupraRAGM2Max()

 # Backward compatibility
 def get_supra_rag():
+ """Backward compatible function that returns device-optimized RAG."""
 return get_supra_rag_m2max()
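
Similarly, a small retrieval sketch against the renamed SupraRAG class (not part of the commit). The class's own query helper is not shown in this diff, so the sketch encodes the query with the instance's SentenceTransformer and hits the Chroma collection directly; both attributes appear in the code above.

# Hypothetical retrieval check; uses only attributes visible in the diff.
from rag.rag import get_supra_rag

rag = get_supra_rag()                      # cached, device-aware instance
query_emb = rag.embedding_model.encode(["What is SUPRA?"]).tolist()
hits = rag.collection.query(query_embeddings=query_emb, n_results=3)
for doc in hits["documents"][0]:
    print(doc[:120], "...")
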
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 # SUPRA-Nexus RAG UI Dependencies
- # For Hugging Face Spaces Deployment

 # Streamlit UI Framework
 streamlit>=1.28.0
@@ -15,6 +15,10 @@ torch>=2.0.0
 # PEFT for LoRA loading
 peft>=0.6.0

 # NLP utilities
 nltk>=3.8.0
 
 # SUPRA-Nexus RAG UI Dependencies
+ # For Hugging Face Spaces Deployment (CPU Optimized)

 # Streamlit UI Framework
 streamlit>=1.28.0

 # PEFT for LoRA loading
 peft>=0.6.0

+ # CPU Optimizations
+ accelerate>=0.30.0 # For CPU inference optimization
+ bitsandbytes>=0.43.0 # For 8-bit quantization (CPU compatible)
+
 # NLP utilities
 nltk>=3.8.0
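
Finally, a rough smoke test for the new CPU-oriented dependencies (an assumption-laden sketch, not part of the commit): it only confirms the packages import and that the 8-bit config the loader builds can be constructed; whether 8-bit CPU loading actually runs still depends on the installed bitsandbytes build.

# Hypothetical dependency smoke test for the CPU path.
import torch
import accelerate
import bitsandbytes  # 8-bit support on CPU depends on the installed build
from transformers import BitsAndBytesConfig

print("torch", torch.__version__, "| cuda:", torch.cuda.is_available(),
      "| mps:", torch.backends.mps.is_available())
print("accelerate", accelerate.__version__, "| bitsandbytes", bitsandbytes.__version__)

# Mirrors the config model_loader.py builds when it falls back to CPU.
cfg = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
print(cfg)
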