"""
Test MPS compatibility for each model individually.

Every model in ``MODELS`` is loaded on the Apple ``mps`` device and run on a
single test image; a per-model status and an overall summary are printed.
"""
import os
import sys
import glob
import torch
from yaml import safe_load
import functions.networks as networks
from functions.run_on_images_fn import run_on_images
import warnings
warnings.filterwarnings("ignore")

# Models exercised by this script, in the order they are tested.
MODELS = ['contrique', 'hyperiqa', 'tres', 'reiqa', 'arniqa']
# Glob patterns used to discover test images.
IMAGE_EXTENSIONS = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
# Device every model is loaded on.
DEVICE = "mps"
# Decision threshold: passed to run_on_images and used to label the score.
THRESHOLD = 0.5


def _find_test_images(images_dir):
    """Return a sorted, de-duplicated list of absolute image paths in *images_dir*.

    ``glob`` result order depends on the filesystem, so the list is sorted to
    make the choice of test image deterministic. A set is used so a file that
    matches more than one pattern (where glob matching is case-insensitive)
    is listed only once.
    """
    found = set()
    for pattern in IMAGE_EXTENSIONS:
        found.update(
            os.path.abspath(p) for p in glob.glob(os.path.join(images_dir, pattern))
        )
    return sorted(found)


def _mps_available():
    """Return True if this PyTorch build exposes an MPS backend that is available.

    Older PyTorch versions have no ``torch.backends.mps`` attribute at all,
    which would otherwise raise AttributeError.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    return mps_backend is not None and mps_backend.is_available()


def _load_config(model_name):
    """Load ``configs/<model_name>.yaml`` and override it for single-image inference.

    Returns the parsed config dict with checkpoint, loss and classifier settings
    forced to the GenImage / MarginContrastiveLoss_CrossEntropy evaluation setup.
    """
    # NOTE(review): this path is relative to the current working directory,
    # whereas test images are located relative to this script's directory.
    config_path = f"configs/{model_name}.yaml"
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = safe_load(config_file)

    # Override settings
    config["dataset"]["dataset_type"] = "GenImage"
    config["checkpoints"]["resume_dirname"] = "GenImage/extensive/MarginContrastiveLoss_CrossEntropy"
    config["checkpoints"]["resume_filename"] = "best_model.ckpt"
    config["checkpoints"]["checkpoint_dirname"] = "extensive/MarginContrastiveLoss_CrossEntropy"
    config["checkpoints"]["checkpoint_filename"] = "best_model.ckpt"
    config["train_settings"]["train"] = False
    config["train_loss_fn"]["name"] = "CrossEntropy"
    config["val_loss_fn"]["name"] = "CrossEntropy"
    config["classifier"]["hidden_layers"] = [1024]
    return config


def _test_model(model_name, test_image):
    """Run *model_name* on *test_image* using ``DEVICE`` and return (score, prediction).

    ``score`` is the first prediction returned by run_on_images, converted to
    float; ``prediction`` is "AI-Generated" when it exceeds ``THRESHOLD``,
    otherwise "Real".

    Raises:
        RuntimeError: if run_on_images returns no predictions.
        Any exception raised while loading the config, the model or running
        inference is propagated to the caller, which records it as a failure.
    """
    config = _load_config(model_name)

    print(f" Loading model on {DEVICE}...")
    feature_extractor = networks.get_model(model_name=model_name, device=DEVICE)
    classifier = networks.Classifier_Arch2(
        input_dim=config["classifier"]["input_dim"],
        hidden_layers=config["classifier"]["hidden_layers"],
    )

    preprocess_settings = {
        "model_name": model_name,
        "selected_transforms_name": "test",
        "probability": -1,
        "gaussian_blur_range": None,
        "jpeg_compression_qfs": None,
        "input_image_dimensions": (224, 224),
        "resize": None,
    }

    print(" Running inference...")
    # The single test image is fed as a "real" sample; no fake samples are used.
    _metrics, _threshold, y_pred, _y_true = run_on_images(
        feature_extractor=feature_extractor,
        classifier=classifier,
        config=config,
        test_real_images_paths=[test_image],
        test_fake_images_paths=[],
        preprocess_settings=preprocess_settings,
        best_threshold=THRESHOLD,
        verbose=False,
    )

    if len(y_pred) == 0:
        # Fail with a clear message instead of a TypeError from formatting None.
        raise RuntimeError("run_on_images returned no predictions")
    # float() accepts numpy scalars and 0-d tensors, so ':.4f' formatting works.
    score = float(y_pred[0])
    prediction = "AI-Generated" if score > THRESHOLD else "Real"
    return score, prediction


def _print_summary(results):
    """Print each model's status plus the overall MPS compatibility tally.

    *results* maps model name to a dict with keys "status" ("SUCCESS" or
    "FAILED"), "score", "prediction" and "error".
    """
    print("\n" + "=" * 80)
    print("MPS COMPATIBILITY SUMMARY")
    print("=" * 80)

    successful = []
    failed = []
    for model_name, result in results.items():
        status_icon = "✓" if result["status"] == "SUCCESS" else "✗"
        print(f"{status_icon} {model_name.upper():<12} - {result['status']}")
        if result["status"] == "SUCCESS":
            successful.append(model_name)
            print(f" Score: {result['score']:.4f} → {result['prediction']}")
        else:
            failed.append(model_name)
            # Print first line of error to keep the summary compact.
            error_line = result['error'].split('\n')[0]
            print(f" Error: {error_line[:70]}")

    print("\n" + "=" * 80)
    print(f"Summary: {len(successful)} successful, {len(failed)} failed")
    print(f"MPS-compatible models: {', '.join(m.upper() for m in successful) if successful else 'None'}")
    # NOTE(review): a failure may also come from a missing config/checkpoint,
    # not only from MPS incompatibility; check the error lines above.
    print(f"CPU-only models: {', '.join(m.upper() for m in failed) if failed else 'None'}")
    print("=" * 80)


def main():
    """Check MPS availability, run every model on one test image, print a summary.

    Exits with status 1 when no test image is found or MPS is unavailable.
    """
    # Get test images
    test_images_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "new_images_to_test"
    )
    test_images = _find_test_images(test_images_dir)
    if not test_images:
        print("No test images found!")
        sys.exit(1)

    # Test one image for each model
    test_image = test_images[0]
    print(f"Testing with image: {os.path.basename(test_image)}\n")

    # Check MPS availability
    if not _mps_available():
        print("MPS not available on this system!")
        sys.exit(1)
    print(f"MPS is available. Built: {torch.backends.mps.is_built()}\n")
    print("=" * 80)

    results = {}
    for model_name in MODELS:
        print(f"\nTesting model: {model_name.upper()}")
        print("-" * 80)
        try:
            score, prediction = _test_model(model_name, test_image)
            print(f" ✓ SUCCESS - Score: {score:.4f} → {prediction}")
            results[model_name] = {
                "status": "SUCCESS", "score": score,
                "prediction": prediction, "error": None,
            }
        except Exception as e:
            # Top-level per-model boundary: record the failure and move on.
            # The exception type is kept because str(e) can be empty or cryptic.
            error_msg = f"{type(e).__name__}: {e}"
            print(f" ✗ FAILED - {error_msg[:100]}")
            results[model_name] = {
                "status": "FAILED", "score": None,
                "prediction": None, "error": error_msg,
            }

    _print_summary(results)


if __name__ == '__main__':
    main()