""" Model service for XSS detection - loads model from Hugging Face Hub """ import os import re import torch from typing import Tuple, List from transformers import RobertaTokenizer, RobertaForSequenceClassification class ModelService: def __init__(self): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {self.device}") # Load tokenizer self.tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base") # Load PHP model from HuggingFace Hub php_model_repo = os.getenv('PHP_MODEL_REPO', 'mekbus/codebert-xss-php') try: self.php_model = RobertaForSequenceClassification.from_pretrained(php_model_repo) self.php_model.to(self.device) self.php_model.eval() print(f"✅ PHP model loaded from {php_model_repo}") except Exception as e: print(f"⚠️ PHP model not found: {e}") self.php_model = None # Load JS model from HuggingFace Hub js_model_repo = os.getenv('JS_MODEL_REPO', 'mekbus/codebert-xss-js') try: self.js_model = RobertaForSequenceClassification.from_pretrained(js_model_repo) self.js_model.to(self.device) self.js_model.eval() print(f"✅ JS model loaded from {js_model_repo}") except Exception as e: print(f"⚠️ JS model not found: {e}") self.js_model = None def extract_php_blocks(self, code: str) -> str: """Extract PHP code from mixed PHP/HTML and remove comments""" php_blocks = re.findall(r'<\?(?:php)?(.*?)(?:\?>|$)', code, re.DOTALL | re.IGNORECASE) if php_blocks: processed_blocks = [] for block in php_blocks: block = block.strip() if block.startswith('='): block = 'echo ' + block[1:].strip() + ';' processed_blocks.append(block) php_code = '\n'.join(processed_blocks) else: php_code = code # Remove comments php_code = re.sub(r'/\*.*?\*/', '', php_code, flags=re.DOTALL) php_code = re.sub(r'//.*$', '', php_code, flags=re.MULTILINE) php_code = re.sub(r'#.*$', '', php_code, flags=re.MULTILINE) php_code = re.sub(r'\n\s*\n+', '\n', php_code.strip()) return php_code def extract_js_code(self, code: str) -> str: """Extract and clean JavaScript code""" code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL) code = re.sub(r'//.*$', '', code, flags=re.MULTILINE) code = re.sub(r'\n\s*\n+', '\n', code.strip()) return code def chunk_code(self, code: str, max_tokens: int = 400, overlap: int = 50) -> List[str]: """Split large code into overlapping chunks""" lines = code.split('\n') chunks = [] max_lines = 50 overlap_lines = 6 i = 0 while i < len(lines): chunk_lines = lines[i:i + max_lines] chunk = '\n'.join(chunk_lines) if chunk.strip(): chunks.append(chunk) i += max_lines - overlap_lines return chunks if chunks else [code] def predict_single(self, code: str, model) -> Tuple[float, float]: """Make a single prediction""" inputs = self.tokenizer( code, return_tensors='pt', truncation=True, max_length=512, padding='max_length' ) inputs = {k: v.to(self.device) for k, v in inputs.items()} with torch.no_grad(): outputs = model(**inputs) probs = torch.softmax(outputs.logits, dim=1) return probs[0][0].item(), probs[0][1].item() def predict(self, code: str, language: str) -> Tuple[bool, float, str]: """Predict if code is vulnerable""" result = self.predict_multi(code, language) if result['vulnerabilities']: max_vuln = max(result['vulnerabilities'], key=lambda x: x['confidence']) return True, max_vuln['confidence'], "VULNERABLE" else: return False, result['max_confidence'], "SAFE" def predict_multi(self, code: str, language: str) -> dict: """Predict vulnerabilities - returns multiple if found using token-based chunking""" if language == 'php': model = self.php_model code = 
self.extract_php_blocks(code) elif language in ['js', 'javascript']: model = self.js_model code = self.extract_js_code(code) else: raise ValueError(f"Unsupported language: {language}") if model is None: raise RuntimeError(f"{language.upper()} model not loaded") vulnerabilities = [] max_vuln_prob = 0.0 threshold = 0.5 max_length = 512 chunk_overlap = 50 # Tokenize to check length (token-based chunking like test script) tokens = self.tokenizer.encode(code, add_special_tokens=False) # If short enough, process normally (no chunking needed) if len(tokens) <= max_length - 2: # -2 for [CLS] and [SEP] safe_prob, vuln_prob = self.predict_single(code, model) max_vuln_prob = vuln_prob if vuln_prob >= threshold: vulnerabilities.append({ 'chunk_id': 1, 'start_line': 1, 'end_line': len(code.split('\n')), 'confidence': vuln_prob }) else: # Token-based chunking with overlap chunk_size = max_length - 2 stride = chunk_size - chunk_overlap chunks = [] for i in range(0, len(tokens), stride): chunk_tokens = tokens[i:i + chunk_size] if len(chunk_tokens) < 50: # Skip tiny final chunks continue chunks.append(chunk_tokens) print(f"📄 Long {language.upper()} code ({len(tokens)} tokens) → {len(chunks)} chunks") lines = code.split('\n') total_lines = len(lines) lines_per_chunk = max(1, total_lines // len(chunks)) if chunks else total_lines for i, chunk_tokens in enumerate(chunks): # Decode chunk back to text chunk_text = self.tokenizer.decode(chunk_tokens) safe_prob, vuln_prob = self.predict_single(chunk_text, model) if vuln_prob > max_vuln_prob: max_vuln_prob = vuln_prob if vuln_prob >= threshold: start_line = i * lines_per_chunk + 1 end_line = min(start_line + lines_per_chunk - 1, total_lines) vulnerabilities.append({ 'chunk_id': i + 1, 'start_line': start_line, 'end_line': end_line, 'confidence': vuln_prob }) # Log chunk scores if vulnerabilities: scores = [f"{v['confidence']:.1%}" for v in vulnerabilities] print(f"📊 Chunk scores: {scores}") else: print("📊 Chunk scores: all safe") print(f"📈 Max vulnerability score: {max_vuln_prob:.1%}") return { 'is_vulnerable': len(vulnerabilities) > 0, 'max_confidence': max_vuln_prob, 'vulnerabilities': vulnerabilities }
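

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the service API).
# It assumes the Hugging Face Hub repos configured above are reachable and
# that the models download successfully; the sample PHP snippet is a
# hypothetical reflected-XSS pattern chosen for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    service = ModelService()

    # Hypothetical vulnerable-looking PHP: unescaped request data echoed to output
    sample_php = "<?php echo $_GET['name']; ?>"

    # Simple boolean verdict with the strongest chunk's confidence
    is_vulnerable, confidence, label = service.predict(sample_php, 'php')
    print(f"Prediction: {label} (confidence: {confidence:.1%})")

    # Detailed per-chunk results (useful for long files that get chunked)
    details = service.predict_multi(sample_php, 'php')
    print(f"Chunks flagged: {len(details['vulnerabilities'])}")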