# api/services/model_service.py
"""
Model service for XSS detection - loads model from Hugging Face Hub
"""
import os
import re
import torch
from typing import Tuple, List
from transformers import RobertaTokenizer, RobertaForSequenceClassification

class ModelService:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Load the shared CodeBERT tokenizer
        self.tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

        # Load PHP model from Hugging Face Hub
        php_model_repo = os.getenv('PHP_MODEL_REPO', 'mekbus/codebert-xss-php')
        try:
            self.php_model = RobertaForSequenceClassification.from_pretrained(php_model_repo)
            self.php_model.to(self.device)
            self.php_model.eval()
            print(f"✅ PHP model loaded from {php_model_repo}")
        except Exception as e:
            print(f"⚠️ PHP model not found: {e}")
            self.php_model = None

        # Load JS model from Hugging Face Hub
        js_model_repo = os.getenv('JS_MODEL_REPO', 'mekbus/codebert-xss-js')
        try:
            self.js_model = RobertaForSequenceClassification.from_pretrained(js_model_repo)
            self.js_model.to(self.device)
            self.js_model.eval()
            print(f"✅ JS model loaded from {js_model_repo}")
        except Exception as e:
            print(f"⚠️ JS model not found: {e}")
            self.js_model = None
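    # The repos above are only defaults; since they are read with os.getenv(),
    # either model can be pointed at any compatible fine-tuned checkpoint on
    # the Hub before startup, e.g. (repo names below are hypothetical
    # placeholders, not real repos):
    #
    #   export PHP_MODEL_REPO=my-org/codebert-xss-php-v2
    #   export JS_MODEL_REPO=my-org/codebert-xss-js-v2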
    def extract_php_blocks(self, code: str) -> str:
        """Extract PHP code from mixed PHP/HTML and remove comments."""
        # Match <?php ... ?> and short <? ... ?> blocks, including a final
        # unterminated block that runs to the end of the file
        php_blocks = re.findall(r'<\?(?:php)?(.*?)(?:\?>|$)', code, re.DOTALL | re.IGNORECASE)
        if php_blocks:
            processed_blocks = []
            for block in php_blocks:
                block = block.strip()
                # Rewrite short echo tags (<?= $x ?>) as explicit echo statements
                if block.startswith('='):
                    block = 'echo ' + block[1:].strip() + ';'
                processed_blocks.append(block)
            php_code = '\n'.join(processed_blocks)
        else:
            # No PHP tags at all: treat the whole input as PHP
            php_code = code
        # Strip /* ... */, // ..., and # ... comments, then collapse blank lines
        php_code = re.sub(r'/\*.*?\*/', '', php_code, flags=re.DOTALL)
        php_code = re.sub(r'//.*$', '', php_code, flags=re.MULTILINE)
        php_code = re.sub(r'#.*$', '', php_code, flags=re.MULTILINE)
        php_code = re.sub(r'\n\s*\n+', '\n', php_code.strip())
        return php_code
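    # A minimal illustration (hypothetical input, traced by hand):
    #
    #   extract_php_blocks('<p>Hi</p><?= $_GET["name"] ?>')
    #
    # matches the short echo tag, rewrites it as an echo statement, and returns:
    #
    #   echo $_GET["name"];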
    def extract_js_code(self, code: str) -> str:
        """Extract and clean JavaScript code."""
        # Strip /* ... */ and // ... comments, then collapse blank lines.
        # Note: the // pattern also clips string literals containing '//',
        # such as URLs
        code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
        code = re.sub(r'//.*$', '', code, flags=re.MULTILINE)
        code = re.sub(r'\n\s*\n+', '\n', code.strip())
        return code
    def chunk_code(self, code: str, max_tokens: int = 400, overlap: int = 50) -> List[str]:
        """Split large code into overlapping line-based chunks.

        Note: despite the signature, chunking here is by line count; the
        max_tokens and overlap parameters are currently unused. predict_multi()
        below does its own token-based chunking and does not call this helper.
        """
        lines = code.split('\n')
        chunks = []
        max_lines = 50
        overlap_lines = 6
        i = 0
        while i < len(lines):
            chunk_lines = lines[i:i + max_lines]
            chunk = '\n'.join(chunk_lines)
            if chunk.strip():
                chunks.append(chunk)
            i += max_lines - overlap_lines
        return chunks if chunks else [code]
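    # Stride arithmetic for the defaults above: each window covers 50 lines and
    # advances by 50 - 6 = 44 lines, so consecutive chunks share 6 lines.
    # For a hypothetical 120-line file the windows are lines 1-50, 45-94, 89-120.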
    def predict_single(self, code: str, model) -> Tuple[float, float]:
        """Run one forward pass and return (safe_prob, vuln_prob)."""
        inputs = self.tokenizer(
            code,
            return_tensors='pt',
            truncation=True,
            max_length=512,
            padding='max_length'
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
        return probs[0][0].item(), probs[0][1].item()
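    # The (safe, vulnerable) ordering of the returned tuple mirrors how
    # predict_multi() below unpacks it; since softmax makes the two class
    # probabilities sum to 1, thresholding on vuln_prob alone is sufficient.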
    def predict(self, code: str, language: str) -> Tuple[bool, float, str]:
        """Single-verdict wrapper around predict_multi()."""
        result = self.predict_multi(code, language)
        if result['vulnerabilities']:
            # Report the most confident vulnerable chunk
            max_vuln = max(result['vulnerabilities'], key=lambda x: x['confidence'])
            return True, max_vuln['confidence'], "VULNERABLE"
        else:
            return False, result['max_confidence'], "SAFE"
    def predict_multi(self, code: str, language: str) -> dict:
        """Predict vulnerabilities, returning every flagged chunk; long inputs are split with token-based chunking."""
        if language == 'php':
            model = self.php_model
            code = self.extract_php_blocks(code)
        elif language in ['js', 'javascript']:
            model = self.js_model
            code = self.extract_js_code(code)
        else:
            raise ValueError(f"Unsupported language: {language}")

        if model is None:
            raise RuntimeError(f"{language.upper()} model not loaded")

        vulnerabilities = []
        max_vuln_prob = 0.0
        threshold = 0.5
        max_length = 512
        chunk_overlap = 50

        # Tokenize once to decide whether chunking is needed
        tokens = self.tokenizer.encode(code, add_special_tokens=False)

        if len(tokens) <= max_length - 2:  # -2 leaves room for [CLS] and [SEP]
            # Short enough for a single forward pass, no chunking needed
            safe_prob, vuln_prob = self.predict_single(code, model)
            max_vuln_prob = vuln_prob
            if vuln_prob >= threshold:
                vulnerabilities.append({
                    'chunk_id': 1,
                    'start_line': 1,
                    'end_line': len(code.split('\n')),
                    'confidence': vuln_prob
                })
        else:
            # Token-based sliding window with overlap
            chunk_size = max_length - 2
            stride = chunk_size - chunk_overlap
            chunks = []
            for i in range(0, len(tokens), stride):
                chunk_tokens = tokens[i:i + chunk_size]
                if len(chunk_tokens) < 50:  # Skip tiny final chunks
                    continue
                chunks.append(chunk_tokens)
            print(f"📄 Long {language.upper()} code ({len(tokens)} tokens) → {len(chunks)} chunks")

            # Approximate line ranges: token offsets don't map cleanly back to
            # source lines, so each chunk is assigned an even share of the lines
            lines = code.split('\n')
            total_lines = len(lines)
            lines_per_chunk = max(1, total_lines // len(chunks)) if chunks else total_lines

            for i, chunk_tokens in enumerate(chunks):
                # Decode the token window back to text for the model call
                chunk_text = self.tokenizer.decode(chunk_tokens)
                safe_prob, vuln_prob = self.predict_single(chunk_text, model)
                if vuln_prob > max_vuln_prob:
                    max_vuln_prob = vuln_prob
                if vuln_prob >= threshold:
                    start_line = i * lines_per_chunk + 1
                    end_line = min(start_line + lines_per_chunk - 1, total_lines)
                    vulnerabilities.append({
                        'chunk_id': i + 1,
                        'start_line': start_line,
                        'end_line': end_line,
                        'confidence': vuln_prob
                    })

        # Log chunk scores
        if vulnerabilities:
            scores = ', '.join(f"{v['confidence']:.1%}" for v in vulnerabilities)
            print(f"📊 Chunk scores: {scores}")
        else:
            print("📊 Chunk scores: all safe")
        print(f"📈 Max vulnerability score: {max_vuln_prob:.1%}")

        return {
            'is_vulnerable': len(vulnerabilities) > 0,
            'max_confidence': max_vuln_prob,
            'vulnerabilities': vulnerabilities
        }
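

# A minimal smoke test, runnable when this module is executed directly. The
# snippet is illustrative only: it assumes the default Hub repos in __init__
# are reachable, and the sample input is a hypothetical reflected-XSS shape
# (user input echoed without escaping), not a case from any test suite.
if __name__ == "__main__":
    service = ModelService()
    sample = '<?php echo $_GET["q"]; ?>'
    is_vuln, confidence, verdict = service.predict(sample, 'php')
    print(f"{verdict} ({confidence:.1%})")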