File size: 5,146 Bytes
94421ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Test MPS compatibility for each model individually
"""
import os
import sys
import glob
import torch
from yaml import safe_load
import functions.networks as networks
from functions.run_on_images_fn import run_on_images
import warnings
warnings.filterwarnings("ignore")

# Glob patterns for the image files accepted as test input.
IMAGE_PATTERNS = ('*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG')

# Model names to try; each one needs a matching configs/<name>.yaml file.
MODELS = ('contrique', 'hyperiqa', 'tres', 'reiqa', 'arniqa')

# Scores strictly above this value are reported as "AI-Generated".
DECISION_THRESHOLD = 0.5


def collect_test_images(images_dir, patterns=IMAGE_PATTERNS):
    """Return the absolute paths of the images found directly in *images_dir*.

    The result is de-duplicated and sorted so that the image picked by the
    caller (the first one) does not depend on the order in which ``glob``
    happens to list the directory, and so that a file matched by more than
    one pattern is only reported once.

    Args:
        images_dir: Directory to search (not searched recursively).
        patterns: Glob patterns, relative to *images_dir*.

    Returns:
        A sorted list of absolute paths; empty if nothing matched.
    """
    found = set()
    for pattern in patterns:
        matches = glob.glob(os.path.join(images_dir, pattern))
        found.update(os.path.abspath(match) for match in matches)
    return sorted(found)


def label_for_score(score, threshold=DECISION_THRESHOLD):
    """Return "AI-Generated" if *score* is strictly above *threshold*, else "Real"."""
    return "AI-Generated" if score > threshold else "Real"


def load_config(model_name):
    """Load ``configs/<model_name>.yaml`` and apply the overrides used by this test.

    The path is relative to the current working directory.

    Returns:
        The parsed config dict with the dataset, checkpoint, loss and
        classifier settings overridden.
    """
    config_path = f"configs/{model_name}.yaml"
    # Context manager so the file handle is closed even if parsing fails.
    with open(config_path, "r") as config_file:
        config = safe_load(config_file)

    # Override settings
    config["dataset"]["dataset_type"] = "GenImage"
    config["checkpoints"]["resume_dirname"] = "GenImage/extensive/MarginContrastiveLoss_CrossEntropy"
    config["checkpoints"]["resume_filename"] = "best_model.ckpt"
    config["checkpoints"]["checkpoint_dirname"] = "extensive/MarginContrastiveLoss_CrossEntropy"
    config["checkpoints"]["checkpoint_filename"] = "best_model.ckpt"
    config["train_settings"]["train"] = False
    config["train_loss_fn"]["name"] = "CrossEntropy"
    config["val_loss_fn"]["name"] = "CrossEntropy"
    config["classifier"]["hidden_layers"] = [1024]
    return config


def check_model(model_name, image_path, device="mps"):
    """Run one model on one image on *device* and report the outcome.

    Any exception raised while loading the config, building the model or
    running inference is caught and recorded, so that one failing model
    does not prevent the remaining models from being tried.

    Args:
        model_name: Name passed to ``networks.get_model`` and used to
            locate the YAML config.
        image_path: Path of the image to run inference on.
        device: Device string handed to ``networks.get_model``.

    Returns:
        A dict with the keys ``status`` ("SUCCESS" or "FAILED"), ``score``,
        ``prediction`` and ``error`` (the latter is None on success; the
        former two are None on failure).
    """
    print(f"\nTesting model: {model_name.upper()}")
    print("-"*80)

    try:
        config = load_config(model_name)

        print(f"  Loading model on {device}...")
        feature_extractor = networks.get_model(model_name=model_name, device=device)

        classifier = networks.Classifier_Arch2(
            input_dim=config["classifier"]["input_dim"],
            hidden_layers=config["classifier"]["hidden_layers"]
        )

        preprocess_settings = {
            "model_name": model_name,
            "selected_transforms_name": "test",
            "probability": -1,
            "gaussian_blur_range": None,
            "jpeg_compression_qfs": None,
            "input_image_dimensions": (224, 224),
            "resize": None
        }

        print("  Running inference...")

        # The single test image is passed in the "real" list; no fake images.
        _metrics, _threshold, y_pred, _y_true = run_on_images(
            feature_extractor=feature_extractor,
            classifier=classifier,
            config=config,
            test_real_images_paths=[image_path],
            test_fake_images_paths=[],
            preprocess_settings=preprocess_settings,
            best_threshold=DECISION_THRESHOLD,
            verbose=False
        )

        # Fail with a clear message instead of letting the score formatting
        # below raise an opaque TypeError on None.
        if len(y_pred) == 0:
            raise RuntimeError("run_on_images returned no predictions")

        score = y_pred[0]
        prediction = label_for_score(score)

        print(f"  ✓ SUCCESS - Score: {score:.4f} ({prediction})")
        return {"status": "SUCCESS", "score": score, "prediction": prediction, "error": None}

    except Exception as e:
        # Deliberately broad: this script's purpose is to find out which
        # models fail on the device, whatever the reason.
        error_msg = str(e)
        print(f"  ✗ FAILED - {error_msg[:100]}")
        return {"status": "FAILED", "score": None, "prediction": None, "error": error_msg}


def split_by_status(results):
    """Split *results* into ``(successful, failed)`` lists of model names.

    Order within each list follows the iteration order of *results*.
    """
    successful = [name for name, result in results.items() if result["status"] == "SUCCESS"]
    failed = [name for name, result in results.items() if result["status"] != "SUCCESS"]
    return successful, failed


def print_summary(results):
    """Print the per-model outcome followed by the overall totals.

    Args:
        results: Mapping of model name to the dict returned by ``check_model``.
    """
    print("\n" + "="*80)
    print("MPS COMPATIBILITY SUMMARY")
    print("="*80)

    for model_name, result in results.items():
        succeeded = result["status"] == "SUCCESS"
        status_icon = "✓" if succeeded else "✗"
        print(f"{status_icon} {model_name.upper():<12} - {result['status']}")
        if succeeded:
            print(f"   Score: {result['score']:.4f} ({result['prediction']})")
        else:
            # Print first line of error
            error_line = result['error'].split('\n')[0]
            print(f"   Error: {error_line[:70]}")

    successful, failed = split_by_status(results)

    print("\n" + "="*80)
    print(f"Summary: {len(successful)} successful, {len(failed)} failed")
    print(f"MPS-compatible models: {', '.join([m.upper() for m in successful]) if successful else 'None'}")
    print(f"CPU-only models: {', '.join([m.upper() for m in failed]) if failed else 'None'}")
    print("="*80)


def main():
    """Try every model in ``MODELS`` on one test image using the MPS device.

    Exits with status 1 if no test image is found next to this script
    (in ``new_images_to_test``) or if MPS is not available.
    """
    # Get test images
    test_images_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "new_images_to_test")
    test_images = collect_test_images(test_images_dir)

    if not test_images:
        print("No test images found!")
        sys.exit(1)

    # Test one image for each model
    test_image = test_images[0]
    print(f"Testing with image: {os.path.basename(test_image)}\n")

    # Check MPS availability
    if not torch.backends.mps.is_available():
        print("MPS not available on this system!")
        sys.exit(1)

    print(f"MPS is available. Built: {torch.backends.mps.is_built()}\n")
    print("="*80)

    results = {model_name: check_model(model_name, test_image) for model_name in MODELS}

    # The summary runs inside main() so that importing this module no longer
    # executes it against an undefined `results`.
    print_summary(results)


if __name__ == '__main__':
    main()