# detectNanoBananaImage2 / test_mps_compatibility.py
# gh-rgupta
# Add Git LFS configuration and update test files for Mac CPU compatibility
# 94421ed
"""
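Check, model by model, whether inference runs on PyTorch's MPS (Apple Metal) backend.

One image from ``new_images_to_test/`` is pushed through each model on the
"mps" device; the script then lists which models succeeded and which failed.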
"""
import os
import sys
import glob
import torch
from yaml import safe_load
import functions.networks as networks
from functions.run_on_images_fn import run_on_images
import warnings
warnings.filterwarnings("ignore")
# Directory containing this script. Sample images and model configs are
# resolved relative to it, so the check does not depend on the working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Models exercised by the compatibility check, in the order they are run.
MODELS = ['contrique', 'hyperiqa', 'tres', 'reiqa', 'arniqa']

# Glob patterns for candidate test images (lower- and upper-case variants).
IMAGE_EXTENSIONS = ('*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG')

# Classifier scores strictly above this value are reported as "AI-Generated".
DEFAULT_THRESHOLD = 0.5


def find_test_images(directory, extensions=IMAGE_EXTENSIONS):
    """Return sorted, de-duplicated absolute paths of images in *directory*.

    ``glob`` returns matches in arbitrary order, so the result is sorted to
    make the choice of the first image deterministic. Paths are also
    de-duplicated, because on platforms where glob matching is
    case-insensitive (e.g. Windows) a file matches both ``*.jpg`` and ``*.JPG``.

    Args:
        directory: Directory to scan (non-recursively).
        extensions: Glob patterns to match inside *directory*.

    Returns:
        A sorted list of absolute paths. The list is empty if nothing matches
        or the directory does not exist.
    """
    found = set()
    for pattern in extensions:
        found.update(
            os.path.abspath(path) for path in glob.glob(os.path.join(directory, pattern))
        )
    return sorted(found)


def load_model_config(model_name, config_dir=None):
    """Load ``<config_dir>/<model_name>.yaml`` and apply the evaluation overrides.

    Args:
        model_name: Model key; also used as the config file stem.
        config_dir: Directory holding the YAML configs. Defaults to
            ``configs/`` next to this script.

    Returns:
        The parsed config dict, with these changes:
        - dataset/checkpoint settings point at the GenImage
          ``MarginContrastiveLoss_CrossEntropy`` ``best_model.ckpt``;
        - training is disabled;
        - both loss functions are ``CrossEntropy``;
        - ``classifier.hidden_layers`` is ``[1024]``.

    Raises:
        OSError: If the config file cannot be opened.
        KeyError: If an overridden section is missing from the YAML.
    """
    if config_dir is None:
        config_dir = os.path.join(SCRIPT_DIR, "configs")
    config_path = os.path.join(config_dir, f"{model_name}.yaml")
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = safe_load(config_file)

    config["dataset"]["dataset_type"] = "GenImage"
    config["checkpoints"]["resume_dirname"] = "GenImage/extensive/MarginContrastiveLoss_CrossEntropy"
    config["checkpoints"]["resume_filename"] = "best_model.ckpt"
    config["checkpoints"]["checkpoint_dirname"] = "extensive/MarginContrastiveLoss_CrossEntropy"
    config["checkpoints"]["checkpoint_filename"] = "best_model.ckpt"
    config["train_settings"]["train"] = False
    config["train_loss_fn"]["name"] = "CrossEntropy"
    config["val_loss_fn"]["name"] = "CrossEntropy"
    config["classifier"]["hidden_layers"] = [1024]
    return config


def classify_score(score, threshold=DEFAULT_THRESHOLD):
    """Return "AI-Generated" if *score* exceeds *threshold*, else "Real"."""
    return "AI-Generated" if score > threshold else "Real"


def format_score(score, prediction):
    """Render a score and its label, e.g. ``"0.2500 → Real"``."""
    return f"{score:.4f} → {prediction}"


def first_error_line(error_msg, limit=70):
    """Return the first line of *error_msg*, truncated to *limit* characters."""
    return error_msg.partition('\n')[0][:limit]


def run_model_check(model_name, image_path, device="mps", threshold=DEFAULT_THRESHOLD):
    """Run one image through *model_name* on *device* and report the outcome.

    Any exception raised while loading the config or model, or while running
    inference, is caught and recorded rather than propagated. A failure is
    exactly the information this script collects: which models cannot run
    on MPS.

    Args:
        model_name: Model key passed to ``networks.get_model``.
        image_path: Image fed to the model (as the only "real" sample).
        device: Torch device string to load the model on.
        threshold: Decision threshold for both ``run_on_images`` and the
            reported label.

    Returns:
        A dict with these keys:
        - ``status``: "SUCCESS" or "FAILED";
        - ``score``: the first prediction, or None;
        - ``prediction``: the label, or None;
        - ``error``: "ExcType: message", or None.
    """
    try:
        config = load_model_config(model_name)

        print(f" Loading model on {device}...")
        feature_extractor = networks.get_model(model_name=model_name, device=device)
        classifier = networks.Classifier_Arch2(
            input_dim=config["classifier"]["input_dim"],
            hidden_layers=config["classifier"]["hidden_layers"],
        )
        preprocess_settings = {
            "model_name": model_name,
            "selected_transforms_name": "test",
            "probability": -1,
            "gaussian_blur_range": None,
            "jpeg_compression_qfs": None,
            "input_image_dimensions": (224, 224),
            "resize": None,
        }

        print(" Running inference...")
        # Only the per-image predictions are needed; the metrics, the returned
        # threshold and the ground-truth labels are ignored.
        _metrics, _best_threshold, y_pred, _y_true = run_on_images(
            feature_extractor=feature_extractor,
            classifier=classifier,
            config=config,
            test_real_images_paths=[image_path],
            test_fake_images_paths=[],
            preprocess_settings=preprocess_settings,
            best_threshold=threshold,
            verbose=False,
        )
        if len(y_pred) == 0:
            # Fail explicitly instead of trying to format a None score.
            raise RuntimeError("run_on_images returned no predictions")
        score = y_pred[0]
        prediction = classify_score(score, threshold)
        print(f" ✓ SUCCESS - Score: {format_score(score, prediction)}")
        return {"status": "SUCCESS", "score": score, "prediction": prediction, "error": None}
    except Exception as e:  # Per-model boundary: record every failure and move on.
        error_msg = f"{type(e).__name__}: {e}"
        print(f" ✗ FAILED - {error_msg[:100]}")
        return {"status": "FAILED", "score": None, "prediction": None, "error": error_msg}


def print_summary(results):
    """Print the per-model summary table and the overall tallies.

    Args:
        results: Mapping of model name to the dict returned by
            ``run_model_check``.
    """
    print("\n" + "=" * 80)
    print("MPS COMPATIBILITY SUMMARY")
    print("=" * 80)

    successful = [name for name, res in results.items() if res["status"] == "SUCCESS"]
    failed = [name for name, res in results.items() if res["status"] != "SUCCESS"]
    for model_name, result in results.items():
        ok = result["status"] == "SUCCESS"
        print(f"{'✓' if ok else '✗'} {model_name.upper():<12} - {result['status']}")
        if ok:
            print(f" Score: {format_score(result['score'], result['prediction'])}")
        else:
            print(f" Error: {first_error_line(result['error'])}")

    print("\n" + "=" * 80)
    print(f"Summary: {len(successful)} successful, {len(failed)} failed")
    print(f"MPS-compatible models: {', '.join(m.upper() for m in successful) or 'None'}")
    print(f"CPU-only models: {', '.join(m.upper() for m in failed) or 'None'}")
    print("=" * 80)


def main():
    """Run every model in ``MODELS`` on MPS against one sample image.

    Exits with status 1 if no test image is found or MPS is unavailable.
    """
    test_images = find_test_images(os.path.join(SCRIPT_DIR, "new_images_to_test"))
    if not test_images:
        print("No test images found!")
        sys.exit(1)

    # One image is enough to check whether each model runs on MPS.
    test_image = test_images[0]
    print(f"Testing with image: {os.path.basename(test_image)}\n")

    if not torch.backends.mps.is_available():
        print("MPS not available on this system!")
        sys.exit(1)
    print(f"MPS is available. Built: {torch.backends.mps.is_built()}\n")
    print("=" * 80)

    results = {}
    for model_name in MODELS:
        print(f"\nTesting model: {model_name.upper()}")
        print("-" * 80)
        results[model_name] = run_model_check(model_name, test_image)

    print_summary(results)


if __name__ == '__main__':
    main()