File size: 7,802 Bytes
356ff01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9993825
356ff01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1502ab1
356ff01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1502ab1
 
356ff01
1502ab1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356ff01
 
1502ab1
 
 
 
 
 
 
 
356ff01
 
1502ab1
356ff01
1502ab1
 
356ff01
 
 
 
 
 
1502ab1
ba9a127
 
 
 
 
 
1502ab1
356ff01
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""
Model service for XSS detection - loads model from Hugging Face Hub
"""
import os
import re
import torch
from typing import Tuple, List
from transformers import RobertaTokenizer, RobertaForSequenceClassification


class ModelService:
    """XSS classifier backed by two CodeBERT sequence-classification models.

    One model scores PHP, the other JavaScript; both share the
    ``microsoft/codebert-base`` tokenizer. Output class 0 is read as "safe"
    and class 1 as "vulnerable".
    """

    # Vulnerability probability at or above which a snippet/chunk is flagged.
    _THRESHOLD = 0.5
    # Maximum sequence length fed to the model (including special tokens).
    _MAX_LENGTH = 512
    # Tokens shared by consecutive windows when long code is split.
    _CHUNK_OVERLAP = 50

    # PHP regions: "<?php ... ?>", "<? ... ?>" and "<?= ... ?>"; an unclosed
    # final region runs to the end of the input.
    _PHP_BLOCK_RE = re.compile(r'<\?(?:php)?(.*?)(?:\?>|$)', re.DOTALL | re.IGNORECASE)

    # Text that must survive comment stripping: a backslash escape or a quoted
    # string. '...' and "..." end at the closing quote or at end of line (so a
    # stray apostrophe cannot hide the rest of the file); `...` may span lines.
    # re.sub scans left to right, so a string is consumed whole before any
    # "//" or "#" inside it could be mistaken for a comment. JS regex literals
    # are not recognised; a quote inside one only protects the rest of its line.
    _KEEP = (
        r'\\.'
        r'|"(?:\\.|[^"\\\n])*"?'
        r"|'(?:\\.|[^'\\\n])*'?"
        r'|`(?:\\.|[^`\\])*`?'
    )
    # JavaScript: drop /* block */ and // line comments.
    _JS_COMMENT_RE = re.compile(rf'(?P<keep>{_KEEP})|/\*.*?\*/|//[^\n]*', re.DOTALL)
    # PHP: additionally drop # line comments.
    _PHP_COMMENT_RE = re.compile(
        rf'(?P<keep>{_KEEP})|/\*.*?\*/|//[^\n]*|#[^\n]*', re.DOTALL
    )

    def __init__(self):
        """Pick a device, load the tokenizer and (best-effort) both models.

        Model repos come from $PHP_MODEL_REPO / $JS_MODEL_REPO. A model that
        fails to load is stored as None; predict_multi then raises for that
        language instead of the whole service failing to start.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        self.tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

        self.php_model = self._load_model('PHP_MODEL_REPO', 'mekbus/codebert-xss-php', 'PHP')
        self.js_model = self._load_model('JS_MODEL_REPO', 'mekbus/codebert-xss-js', 'JS')

    def _load_model(self, env_var: str, default_repo: str, label: str):
        """Load a classifier from the repo named by *env_var*; None on failure."""
        repo = os.getenv(env_var, default_repo)
        try:
            model = RobertaForSequenceClassification.from_pretrained(repo)
            model.to(self.device)
            model.eval()
        except Exception as e:  # deliberate best-effort: report and continue
            print(f"⚠️  {label} model not found: {e}")
            return None
        print(f"✅ {label} model loaded from {repo}")
        return model

    @staticmethod
    def _strip_comments(code: str, pattern) -> str:
        """Remove comments matched by *pattern*, then collapse blank lines.

        *pattern* must have a named group ``keep`` for text to preserve
        (strings/escapes); every other match is deleted.
        """
        without_comments = pattern.sub(lambda m: m.group('keep') or '', code)
        return re.sub(r'\n\s*\n+', '\n', without_comments.strip())

    def extract_php_blocks(self, code: str) -> str:
        """Return the PHP code in *code* with comments and blank lines removed.

        PHP regions are joined with newlines and a short echo tag
        ``<?= expr ?>`` becomes ``echo expr;``. If no region is found the
        whole input is cleaned as-is. ``//``, ``/* */`` and ``#`` comments are
        removed, but comment markers inside quoted strings are kept.
        """
        php_blocks = self._PHP_BLOCK_RE.findall(code)
        if php_blocks:
            processed_blocks = []
            for block in php_blocks:
                block = block.strip()
                if block.startswith('='):
                    block = 'echo ' + block[1:].strip() + ';'
                processed_blocks.append(block)
            code = '\n'.join(processed_blocks)
        return self._strip_comments(code, self._PHP_COMMENT_RE)

    def extract_js_code(self, code: str) -> str:
        """Return *code* with JS comments (outside strings) and blank lines removed."""
        return self._strip_comments(code, self._JS_COMMENT_RE)

    def chunk_code(self, code: str, max_tokens: int = 400, overlap: int = 50,
                   *, max_lines: int = 50, overlap_lines: int = 6) -> List[str]:
        """Split *code* into overlapping windows of whole lines.

        Windows hold ``max_lines`` lines and consecutive windows share
        ``overlap_lines`` lines; all-whitespace windows are dropped. If nothing
        remains, ``[code]`` is returned. ``max_tokens`` and ``overlap`` are not
        used and are kept only for backward compatibility with existing callers.

        Raises:
            ValueError: if the window would not advance
                (``max_lines < 1`` or ``overlap_lines`` outside ``[0, max_lines)``).
        """
        if max_lines < 1 or not 0 <= overlap_lines < max_lines:
            raise ValueError("need max_lines >= 1 and 0 <= overlap_lines < max_lines")
        lines = code.split('\n')
        step = max_lines - overlap_lines
        windows = ('\n'.join(lines[i:i + max_lines]) for i in range(0, len(lines), step))
        chunks = [w for w in windows if w.strip()]
        return chunks or [code]

    def predict_single(self, code: str, model) -> Tuple[float, float]:
        """Score one snippet and return ``(safe_prob, vuln_prob)``.

        Input longer than the model's maximum length is truncated.
        """
        inputs = self.tokenizer(
            code,
            return_tensors='pt',
            truncation=True,
            max_length=self._MAX_LENGTH,
            padding='max_length',
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            probs = torch.softmax(model(**inputs).logits, dim=1)
        return probs[0][0].item(), probs[0][1].item()

    def predict(self, code: str, language: str) -> Tuple[bool, float, str]:
        """Return ``(is_vulnerable, confidence, label)`` for *code*.

        When vulnerable, confidence is the best flagged chunk's probability;
        otherwise it is the highest vulnerability probability observed.
        """
        result = self.predict_multi(code, language)
        if result['vulnerabilities']:
            top = max(result['vulnerabilities'], key=lambda v: v['confidence'])
            return True, top['confidence'], "VULNERABLE"
        return False, result['max_confidence'], "SAFE"

    def _decode(self, token_ids: List[int]) -> str:
        """Decode token ids back to the exact source text.

        The tokenizer's English-oriented space clean-up (" ." -> "." etc.) is
        disabled explicitly: it would rewrite code such as ``$a . $b`` and its
        default has varied between transformers releases.
        """
        return self.tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)

    @staticmethod
    def _chunk_starts(n_tokens: int, chunk_size: int, overlap: int) -> List[int]:
        """Start offsets of overlapping windows covering ``n_tokens`` tokens.

        A window is emitted only if it holds at least one token its
        predecessor (which covers up to ``start + overlap``) did not, so
        redundant tail windows are never scored.
        """
        stride = chunk_size - overlap
        return list(range(0, max(n_tokens - overlap, 1), stride))

    @staticmethod
    def _chunk_line_span(chunk_text: str, newlines_before: int,
                         total_lines: int) -> Tuple[int, int]:
        """1-based ``(start_line, end_line)`` of the non-blank text of a chunk.

        ``newlines_before`` is the number of newlines preceding the chunk.
        Results are clamped to ``[1, total_lines]`` with start <= end.
        """
        body = chunk_text.lstrip()
        leading_newlines = chunk_text[:len(chunk_text) - len(body)].count('\n')
        start_line = min(newlines_before + leading_newlines + 1, total_lines)
        end_line = min(start_line + body.rstrip().count('\n'), total_lines)
        return start_line, end_line

    def predict_multi(self, code: str, language: str) -> dict:
        """Scan *code* for XSS and report every window the model flags.

        ``language`` is 'php', 'js' or 'javascript'. The code is first cleaned
        (PHP regions extracted, comments removed, blank lines collapsed). Code
        that fits in one model window is scored whole; longer code is split
        into overlapping token windows that are scored independently.

        Returns:
            dict with ``is_vulnerable`` (any window >= threshold),
            ``max_confidence`` (highest vulnerability probability seen) and
            ``vulnerabilities`` (one dict per flagged window: ``chunk_id``,
            ``start_line``, ``end_line``, ``confidence``). Line numbers refer
            to the *cleaned* code, not the raw input.

        Raises:
            ValueError: unsupported language.
            RuntimeError: the model for that language failed to load.
        """
        if language == 'php':
            model = self.php_model
            code = self.extract_php_blocks(code)
        elif language in ('js', 'javascript'):
            model = self.js_model
            code = self.extract_js_code(code)
        else:
            raise ValueError(f"Unsupported language: {language}")

        if model is None:
            raise RuntimeError(f"{language.upper()} model not loaded")

        chunk_size = self._MAX_LENGTH - 2  # leave room for the two special tokens
        tokens = self.tokenizer.encode(code, add_special_tokens=False)
        total_lines = code.count('\n') + 1
        vulnerabilities = []

        if len(tokens) <= chunk_size:
            _, max_vuln_prob = self.predict_single(code, model)
            if max_vuln_prob >= self._THRESHOLD:
                vulnerabilities.append({
                    'chunk_id': 1,
                    'start_line': 1,
                    'end_line': total_lines,
                    'confidence': max_vuln_prob,
                })
        else:
            starts = self._chunk_starts(len(tokens), chunk_size, self._CHUNK_OVERLAP)
            print(f"📄 Long {language.upper()} code ({len(tokens)} tokens) → {len(starts)} chunks")

            max_vuln_prob = 0.0
            newlines_before = 0  # newlines in tokens[:start] for the current window
            prev_start = 0
            for chunk_id, start in enumerate(starts, 1):
                # Only decode the not-yet-counted span so total work stays linear.
                newlines_before += self._decode(tokens[prev_start:start]).count('\n')
                prev_start = start

                chunk_text = self._decode(tokens[start:start + chunk_size])
                _, vuln_prob = self.predict_single(chunk_text, model)
                max_vuln_prob = max(max_vuln_prob, vuln_prob)

                if vuln_prob >= self._THRESHOLD:
                    start_line, end_line = self._chunk_line_span(
                        chunk_text, newlines_before, total_lines
                    )
                    vulnerabilities.append({
                        'chunk_id': chunk_id,
                        'start_line': start_line,
                        'end_line': end_line,
                        'confidence': vuln_prob,
                    })

            if vulnerabilities:
                scores = [f"{v['confidence']:.1%}" for v in vulnerabilities]
                print(f"📊 Chunk scores: {scores}")
            else:
                print("📊 Chunk scores: all safe")
            print(f"📈 Max vulnerability score: {max_vuln_prob:.1%}")

        return {
            'is_vulnerable': len(vulnerabilities) > 0,
            'max_confidence': max_vuln_prob,
            'vulnerabilities': vulnerabilities,
        }