File size: 6,167 Bytes
94421ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Test all available models on the same image
"""
import os
import sys

if __name__ == '__main__':
    # Available models - test all 5 IQA-based models
    models = ['contrique', 'hyperiqa', 'tres', 'reiqa', 'arniqa']

    # Test images directory
    test_images_dir = "new_images_to_test"

    # Get all images from the directory (both lowercase and uppercase extensions)
    import glob
    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
    test_images = []
    for ext in image_extensions:
        test_images.extend(glob.glob(os.path.join(test_images_dir, ext)))

    if not test_images:
        print(f"Error: No images found in {test_images_dir}/")
        sys.exit(1)

    print(f"Found {len(test_images)} image(s) in {test_images_dir}/")
    print("=" * 80)

    # Import heavyweight project libraries once, after the cheap sanity checks
    sys.path.insert(0, '.')
    from yaml import safe_load
    from functions.loss_optimizers_metrics import *
    from functions.run_on_images_fn import run_on_images
    import functions.utils as utils
    import functions.networks as networks
    import defaults
    import warnings
    warnings.filterwarnings("ignore")

    # model_name -> list of per-image result dicts, or None if the model
    # failed to load at all.
    all_results = {}

    # Test each model.
    # BUG FIX: the loop body below was originally NOT indented under the
    # `for`, which made the whole file fail to parse (IndentationError).
    for model_idx, model_name in enumerate(models, 1):
        print(f"\n{'='*80}")
        print(f"[{model_idx}/{len(models)}] Testing model: {model_name.upper()}")
        print("="*80)

        try:
            config_path = f"configs/{model_name}.yaml"
            # BUG FIX: use a context manager so the config file handle is
            # closed promptly instead of leaking until GC.
            with open(config_path, "r") as config_file:
                config = safe_load(config_file)

            # Override settings
            config["dataset"]["dataset_type"] = "GenImage"
            config["checkpoints"]["resume_dirname"] = "GenImage/extensive/MarginContrastiveLoss_CrossEntropy"
            config["checkpoints"]["resume_filename"] = "best_model.ckpt"
            config["checkpoints"]["checkpoint_dirname"] = "extensive/MarginContrastiveLoss_CrossEntropy"
            config["checkpoints"]["checkpoint_filename"] = "best_model.ckpt"

            # Training settings (for testing)
            config["train_settings"]["train"] = False
            config["train_loss_fn"]["name"] = "CrossEntropy"
            config["val_loss_fn"]["name"] = "CrossEntropy"

            # Model setup - use CPU (MPS has compatibility issues)
            device = "cpu"
            feature_extractor = networks.get_model(model_name=model_name, device=device)

            # Classifier
            config["classifier"]["hidden_layers"] = [1024]
            classifier = networks.Classifier_Arch2(
                input_dim=config["classifier"]["input_dim"],
                hidden_layers=config["classifier"]["hidden_layers"]
            )

            # Preprocessing settings
            preprocess_settings = {
                "model_name": model_name,
                "selected_transforms_name": "test",
                "probability": -1,
                "gaussian_blur_range": None,
                "jpeg_compression_qfs": None,
                "input_image_dimensions": (224, 224),
                "resize": None
            }

            print(f"✓ {model_name.upper()} model loaded successfully\n")

            results = []

            # Test each image with this model
            for idx, test_image in enumerate(test_images, 1):
                image_name = os.path.basename(test_image)
                print(f"  [{idx}/{len(test_images)}] Testing: {image_name}")

                # Run the model on a single "real" image; no fake set needed.
                test_real_images_paths = [test_image]
                test_fake_images_paths = []

                try:
                    test_set_metrics, best_threshold, y_pred, y_true = run_on_images(
                        feature_extractor=feature_extractor,
                        classifier=classifier,
                        config=config,
                        test_real_images_paths=test_real_images_paths,
                        test_fake_images_paths=test_fake_images_paths,
                        preprocess_settings=preprocess_settings,
                        best_threshold=0.5,
                        verbose=False
                    )

                    score = y_pred[0] if len(y_pred) > 0 else None

                    # BUG FIX: compare against None explicitly. The old
                    # truthiness test treated a valid score of exactly 0.0 as
                    # "missing", zeroing its confidence; it also let the
                    # f-string below crash (TypeError) when score was None.
                    if score is not None:
                        prediction = "AI-Generated" if score > 0.5 else "Real"
                        # Distance from the 0.5 decision boundary, scaled to 0-100%.
                        confidence = abs(score - 0.5) * 200
                        # BUG FIX: add a separator — the original printed
                        # "0.7321AI-Generated" with score and label fused.
                        print(f"    ✓ Score: {score:.4f} → {prediction} ({confidence:.1f}% confidence)")
                    else:
                        prediction = "N/A"
                        confidence = 0
                        print("    ✗ No prediction returned")

                    results.append({
                        'image': image_name,
                        'score': score,
                        'prediction': prediction,
                        'confidence': confidence
                    })

                except Exception as e:
                    print(f"    ✗ Error: {e}")
                    results.append({
                        'image': image_name,
                        'score': None,
                        'prediction': 'Error',
                        'confidence': 0
                    })

            all_results[model_name] = results

        except Exception as e:
            print(f"✗ Failed to load {model_name.upper()} model: {e}")
            all_results[model_name] = None
# Final Summary: one table per model, then per-model aggregate statistics.
print("\n" + "=" * 80)
print("FINAL SUMMARY - ALL MODELS")
print("=" * 80)

for model_name, results in all_results.items():
    # A None entry means the model never loaded — nothing to tabulate.
    if results is None:
        print(f"\n{model_name.upper()}: Failed to load")
        continue

    print(f"\n{model_name.upper()}:")
    print("-" * 80)
    print(f"{'Image':<50} {'Score':<10} {'Prediction':<15} {'Confidence':<12}")
    print("-" * 80)

    for entry in results:
        has_score = entry['score'] is not None
        score_str = f"{entry['score']:.4f}" if has_score else "N/A"
        conf_str = f"{entry['confidence']:.1f}%" if has_score else "N/A"
        # Truncate long file names so the columns stay aligned.
        display_name = entry['image']
        if len(display_name) > 50:
            display_name = display_name[:47] + "..."
        print(f"{display_name:<50} {score_str:<10} {entry['prediction']:<15} {conf_str:<12}")

    # Aggregate statistics over entries that actually produced a score.
    scored = [entry for entry in results if entry['score'] is not None]
    if scored:
        mean_score = sum(entry['score'] for entry in scored) / len(scored)
        ai_total = sum(1 for entry in scored if entry['score'] > 0.5)
        real_total = len(scored) - ai_total
        mean_conf = sum(entry['confidence'] for entry in scored) / len(scored)

        print("-" * 80)
        print(f"Average Score: {mean_score:.4f} | AI: {ai_total} | Real: {real_total} | Avg Confidence: {mean_conf:.1f}%")

print("\n" + "=" * 80)