# llm_clients/finetuned_guard.py
from typing import Generator, Any, Dict, Optional
import json

from .base import LlmClient


class FinetunedGuardClient(LlmClient):
    """LLM client for finetuned model for safe/unsafe classification using zazaman/fmb."""

    def __init__(self, config_dict: Dict[str, Any], system_prompt: str,
                 shared_components: Optional[Dict[str, Any]] = None):
        super().__init__(config_dict, system_prompt)

        # If shared components are provided, use them instead of loading our own
        if shared_components:
            print("  🔗 FinetunedGuardClient: Using shared model components")
            self.model = shared_components["model"]
            self.tokenizer = shared_components["tokenizer"]
            self.classifier = shared_components["classifier"]
            self.transformers_available = True
            return

        # Fallback: load our own model (this should rarely happen now)
        print("  ⚠️ FinetunedGuardClient: Loading independent model (shared components not available)")
        try:
            from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
            import torch

            # Disable torch compilation globally
            torch._dynamo.config.suppress_errors = True
            torch._dynamo.config.disable = True
            self.transformers_available = True
        except ImportError:
            raise ImportError(
                "transformers library is required for FinetunedGuardClient. "
                "Install it with: pip install transformers torch"
            )
        except AttributeError:
            # If torch._dynamo doesn't exist in older versions, that's fine
            self.transformers_available = True

        # Get model name from config or use default
        model_name = config_dict.get("model_name", "zazaman/fmb")
        print(f"🔄 Loading finetuned model: {model_name}")

        try:
            # Disable torch compile optimizations for lightweight CPU-only devices
            import os
            os.environ["TORCH_COMPILE_DISABLE"] = "1"
            os.environ["TORCHDYNAMO_DISABLE"] = "1"
            # Disable TensorFlow oneDNN warnings
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_name,
                torch_dtype=torch.float32,  # Use float32 for CPU
                device_map=None  # Disable automatic device mapping
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)

            # Explicitly disable compilation on the model
            if hasattr(self.model, '_compiler_config'):
                self.model._compiler_config = None

            # Use CPU device for lightweight operation
            device = "cpu"
            self.model = self.model.to(device)

            self.classifier = pipeline(
                "text-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                device=device,
                framework="pt",
                torch_dtype=torch.float32
            )

            print("✅ Finetuned Guard Client initialized successfully.")
            print(f"   Model: {model_name}")
            print(f"   Device: {device}")
        except Exception as e:
            raise RuntimeError(f"Failed to load finetuned model {model_name}: {e}")

    def generate_content(self, prompt: str) -> str:
        """
        Classifies the prompt as safe or unsafe using the finetuned model.
        Returns a JSON response compatible with the existing AI detection system.
        """
        try:
            # Classify the prompt
            result = self.classifier(prompt)[0]

            # Extract the prediction and confidence
            predicted_label = result['label']
            confidence_score = result['score']

            # Determine safety based on the model's prediction
            # Assuming 'SAFE' and 'UNSAFE' are the labels from the fine-tuned model
            is_safe = predicted_label.upper() == 'SAFE'

            # Create response in the expected format
            response_data = {
                "safety_status": "safe" if is_safe else "unsafe",
                "attack_type": "none" if is_safe else "prompt_injection",
                "confidence": confidence_score,
                "is_safe": is_safe,
                "model_used": "zazaman/fmb",
                "reason": f"Model predicted '{predicted_label}' with {confidence_score:.2%} confidence"
            }
            return json.dumps(response_data)
        except Exception as e:
            # Return error response in JSON format
            error_response = {
                "safety_status": "error",
                "attack_type": "unknown",
                "confidence": 0.0,
                "is_safe": False,
                "model_used": "zazaman/fmb",
                "reason": f"Classification error: {str(e)}"
            }
            return json.dumps(error_response)

    def generate_content_stream(self, prompt: str) -> Generator[str, None, None]:
        """
        Streaming is not applicable for classification tasks.
        Returns the classification result as a single chunk.
        """
        yield self.generate_content(prompt)

    def _generate_content_impl(self, prompt: str) -> str:
        """Implementation for base class compatibility."""
        return self.generate_content(prompt)

    def _generate_content_stream_impl(self, prompt: str) -> Generator[str, None, None]:
        """Implementation for base class compatibility."""
        return self.generate_content_stream(prompt)