# File size: 7,463 Bytes
# 223ef32
from dotenv import load_dotenv
load_dotenv()

from src.agents.correlation_agent.correlation_logic import CorrelationAgent
from src.agents.correlation_agent.types import LogInput, MitreInput
from src.agents.mitre_retriever_agent.mitre_example_input import create_sample_log_input, create_elaborate_mockup_incident
from src.agents.mitre_retriever_agent.mitre_agent import MitreAgent
from src.agents.correlation_agent.input_converters import convert_mitre_agent_input, convert_mitre_analysis_output

def test_sample_correlation():
    """Run the correlation pipeline end-to-end on the basic sample incident.

    Steps:
      1. Build the sample incident with ``create_sample_log_input()`` and
         convert it to a ``LogInput`` via ``convert_mitre_agent_input``.
      2. Run a live ``MitreAgent`` (provider ``"openai"``, model ``"gpt-4o"``,
         ``max_iterations=3``) on the raw incident.
      3. Convert the MITRE analysis into a ``MitreInput`` with
         ``convert_mitre_analysis_output``.
      4. Feed both inputs to ``CorrelationAgent.process`` and print the result.

    Returns:
        The object returned by ``CorrelationAgent.process``. ``main()`` reads
        its ``threat_level``, ``confidence``, ``correlation_score`` and
        ``matched_techniques`` for the final summary.

    Note:
        This makes a real LLM call. It presumably needs OpenAI credentials
        from the environment or ``.env``; verify against ``MitreAgent``.
        The ``test_`` prefix means pytest may collect it, and pytest may warn
        because it returns a value. Run it through ``main()`` instead.
    """
    print("=" * 60)
    print("TESTING BASIC CORRELATION WITH MITRE AGENT")
    print("=" * 60)

    # The raw MITRE-agent input feeds the MITRE agent; the LogInput view of
    # the same incident feeds the correlation agent.
    mitre_agent_input = create_sample_log_input()
    log_input = convert_mitre_agent_input(mitre_agent_input)

    print("\nSAMPLE INPUT CREATED:")
    print(f"- Analysis ID: {log_input.analysis_id}")
    print(f"- Severity: {log_input.severity}")
    print(f"- Affected Systems: {log_input.affected_systems}")
    print(f"- Anomalies: {len(log_input.anomalies)} detected")
    print(f"- Processes: {log_input.processes}")

    print("\nRUNNING MITRE AGENT ANALYSIS...")
    mitre_agent = MitreAgent(
        llm_provider="openai",
        model_name="gpt-4o",
        max_iterations=3,
    )

    mitre_analysis_result = mitre_agent.analyze_threat(mitre_agent_input)
    print("✓ MITRE analysis completed")
    print(f"  - Techniques found: {len(mitre_analysis_result.get('technique_details', []))}")
    print(f"  - Coverage score: {mitre_analysis_result.get('coverage_score', 0):.3f}")
    print(f"  - Confidence: {mitre_analysis_result.get('confidence', 0):.3f}")

    mitre_input = convert_mitre_analysis_output(mitre_analysis_result, mitre_agent_input)

    print("\n✓ MITRE INPUT CONVERTED:")
    print(f"  - Top {min(5, len(mitre_input.techniques))} techniques:")
    for i, tech in enumerate(mitre_input.techniques[:5], 1):
        print(f"    {i}. {tech['attack_id']}: {tech['name']} (Score: {tech['relevance_score']:.3f})")

    print("\nRUNNING CORRELATION ANALYSIS...")
    correlation_agent = CorrelationAgent()
    result = correlation_agent.process(log_input, mitre_input)

    print(f"\n{'=' * 60}")
    print("CORRELATION RESULTS:")
    print(f"{'=' * 60}")
    print(f"ID: {result.correlation_id}")
    print(f"Score: {result.correlation_score:.3f}")
    print(f"Threat Level: {result.threat_level.value.upper()}")
    print(f"Confidence: {result.confidence.value.upper()}")
    print(f"Timestamp: {result.timestamp}")

    print(f"\nMATCHED TECHNIQUES ({len(result.matched_techniques)}):")
    for i, tech in enumerate(result.matched_techniques, 1):
        # Evidence is clipped to 100 characters; "..." only marks real truncation.
        ellipsis = "..." if len(tech.evidence) > 100 else ""
        print(f"{i}. {tech.technique_id} - Confidence: {tech.match_confidence:.3f}")
        print(f"   Evidence: {tech.evidence[:100]}{ellipsis}")

    print("\nREASONING:")
    print(f"{result.reasoning}")

    return result

def test_elaborate_correlation():
    """Run the correlation pipeline end-to-end on the elaborate mock incident.

    Mirrors ``test_sample_correlation`` but uses
    ``create_elaborate_mockup_incident()``. It prints counts rather than full
    lists for the larger input, and shows at most five techniques at each
    stage.

    Returns:
        The object returned by ``CorrelationAgent.process``. ``main()`` reads
        its ``threat_level``, ``confidence``, ``correlation_score`` and
        ``matched_techniques`` for the final summary.

    Note:
        This makes a real LLM call through ``MitreAgent`` (provider
        ``"openai"``, model ``"gpt-4o"``). Because of the ``test_`` prefix,
        pytest may collect this function; run it through ``main()`` instead.
    """
    print("\n" + "=" * 60)
    print("TESTING CORRELATION - ELABORATE INCIDENT")
    print("=" * 60)

    # The raw MITRE-agent input feeds the MITRE agent; the LogInput view of
    # the same incident feeds the correlation agent.
    mitre_agent_input = create_elaborate_mockup_incident()
    log_input = convert_mitre_agent_input(mitre_agent_input)

    print("\nELABORATE INCIDENT INPUT:")
    print(f"- Analysis ID: {log_input.analysis_id}")
    print(f"- Severity: {log_input.severity}")
    print(f"- Affected Systems: {len(log_input.affected_systems)} systems")
    print(f"- Anomalies: {len(log_input.anomalies)} detected")
    print(f"- Processes: {len(log_input.processes)} processes")

    print("\nRUNNING MITRE AGENT ANALYSIS FOR ELABORATE INCIDENT...")
    mitre_agent = MitreAgent(
        llm_provider="openai",
        model_name="gpt-4o",
        max_iterations=3,
    )

    mitre_analysis_result = mitre_agent.analyze_threat(mitre_agent_input)
    print("✓ MITRE analysis completed")
    print(f"  - Techniques found: {len(mitre_analysis_result.get('technique_details', []))}")
    print(f"  - Coverage score: {mitre_analysis_result.get('coverage_score', 0):.3f}")
    print(f"  - Confidence: {mitre_analysis_result.get('confidence', 0):.3f}")

    mitre_input = convert_mitre_analysis_output(mitre_analysis_result, mitre_agent_input)

    print("\n✓ TOP TECHNIQUES FROM MITRE ANALYSIS:")
    for i, tech in enumerate(mitre_input.techniques[:5], 1):
        name = tech['name']
        # Only append "..." when the name was actually cut at 50 characters
        # (previously it was appended unconditionally).
        name_ellipsis = "..." if len(name) > 50 else ""
        print(f"  {i}. {tech['attack_id']}: {name[:50]}{name_ellipsis} (Score: {tech['relevance_score']:.3f})")

    print("\nRUNNING ELABORATE CORRELATION ANALYSIS...")
    correlation_agent = CorrelationAgent()
    result = correlation_agent.process(log_input, mitre_input)

    print(f"\n{'=' * 60}")
    print("ELABORATE CORRELATION RESULTS:")
    print(f"{'=' * 60}")
    print(f"ID: {result.correlation_id}")
    print(f"Score: {result.correlation_score:.3f}")
    print(f"Threat Level: {result.threat_level.value.upper()}")
    print(f"Confidence: {result.confidence.value.upper()}")
    print(f"Matched Techniques: {len(result.matched_techniques)}")

    print("\nTOP CORRELATED TECHNIQUES:")
    for i, tech in enumerate(result.matched_techniques[:5], 1):
        # Evidence is clipped to 80 characters; "..." only marks real truncation.
        evidence_ellipsis = "..." if len(tech.evidence) > 80 else ""
        print(f"{i}. {tech.technique_id} - Confidence: {tech.match_confidence:.3f}")
        print(f"   Evidence: {tech.evidence[:80]}{evidence_ellipsis}")

    print("\nREASONING:")
    print(f"{result.reasoning}")

    return result

def main():
    """Run both correlation scenarios and print a combined summary.

    Calls ``test_sample_correlation()`` and then ``test_elaborate_correlation()``.
    Both make live MITRE-agent LLM calls. It then prints the threat level,
    confidence, correlation score and matched-technique count of each result.

    Raises:
        Any exception raised while running or summarising the scenarios. A
        one-line failure message is printed first, then the exception is
        re-raised unchanged. When run as a script, the interpreter prints the
        traceback once and exits with a non-zero status.
    """
    print("╔" + "=" * 58 + "╗")
    print("║" + " " * 10 + "CORRELATION AGENT TEST SUITE" + " " * 20 + "║")
    print("╚" + "=" * 58 + "╝")
    print()

    try:
        result1 = test_sample_correlation()
        result2 = test_elaborate_correlation()

        print("\n" + "=" * 60)
        print("✓ ALL TESTS COMPLETED SUCCESSFULLY")
        print("=" * 60)

        print("\nTEST SUMMARY:")
        summaries = (
            ("Sample Input Test", result1),
            ("Elaborate Incident Test", result2),
        )
        for index, (label, result) in enumerate(summaries, 1):
            print(f"\n{index}. {label}:")
            print(f"   - Threat Level: {result.threat_level.value.upper()}")
            print(f"   - Confidence: {result.confidence.value.upper()}")
            print(f"   - Correlation Score: {result.correlation_score:.3f}")
            print(f"   - Matched Techniques: {len(result.matched_techniques)}")

        print("\n" + "=" * 60)

    except Exception as e:
        # Report the failure, then re-raise. No explicit traceback.print_exc()
        # here: the re-raised exception already carries its traceback, and
        # printing it as well would show it twice.
        print(f"\n❌ TEST FAILED: {e}")
        raise


if __name__ == "__main__":
    # Run the suite only when executed directly, not on import.
    main()