Spaces: Running on Zero
| """ | |
| Test MPS compatibility for each model individually | |
| """ | |
| import os | |
| import sys | |
| import glob | |
| import torch | |
| from yaml import safe_load | |
| import functions.networks as networks | |
| from functions.run_on_images_fn import run_on_images | |
| import warnings | |
warnings.filterwarnings("ignore")
# NOTE(review): this silences *all* warnings. That may include PyTorch's
# MPS-to-CPU fallback warnings (when PYTORCH_ENABLE_MPS_FALLBACK=1 is set),
# which are relevant to an MPS compatibility check. Confirm this is intended.

# Image file extensions picked up from the test folder (compared case-insensitively).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Model names passed to networks.get_model; each needs configs/<name>.yaml.
MODELS = ("contrique", "hyperiqa", "tres", "reiqa", "arniqa")

# Used both as run_on_images' threshold and to label the returned score.
DECISION_THRESHOLD = 0.5


def _find_test_images(images_dir):
    """Return sorted absolute paths of image files directly inside *images_dir*.

    Extensions are compared case-insensitively (``.JPG``/``.Jpg`` both match).
    ``glob``'s ``*`` does not match dot-files, so hidden files are skipped.
    Results are sorted because glob's order depends on the filesystem.
    A missing directory yields an empty list.
    """
    candidates = glob.glob(os.path.join(images_dir, "*"))
    return sorted(
        os.path.abspath(path)
        for path in candidates
        if os.path.splitext(path)[1].lower() in IMAGE_EXTENSIONS
    )


def _load_config(model_name):
    """Load ``configs/<model_name>.yaml`` and apply the inference-only overrides.

    The path is relative to the current working directory (unlike the test
    images, which are located next to this script).
    """
    config_path = f"configs/{model_name}.yaml"
    # Context manager so the file handle is closed even if parsing fails.
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = safe_load(config_file)

    config["dataset"]["dataset_type"] = "GenImage"
    config["checkpoints"]["resume_dirname"] = "GenImage/extensive/MarginContrastiveLoss_CrossEntropy"
    config["checkpoints"]["resume_filename"] = "best_model.ckpt"
    config["checkpoints"]["checkpoint_dirname"] = "extensive/MarginContrastiveLoss_CrossEntropy"
    config["checkpoints"]["checkpoint_filename"] = "best_model.ckpt"
    config["train_settings"]["train"] = False
    config["train_loss_fn"]["name"] = "CrossEntropy"
    config["val_loss_fn"]["name"] = "CrossEntropy"
    config["classifier"]["hidden_layers"] = [1024]
    return config


def _run_single_model(model_name, image_path, device="mps"):
    """Load *model_name* on *device*, score *image_path*, and return ``(score, label)``.

    Propagates any exception raised while loading the config/model or running
    inference, and raises RuntimeError if run_on_images yields no predictions.
    """
    config = _load_config(model_name)

    print(f" Loading model on {device}...")
    feature_extractor = networks.get_model(model_name=model_name, device=device)
    classifier = networks.Classifier_Arch2(
        input_dim=config["classifier"]["input_dim"],
        hidden_layers=config["classifier"]["hidden_layers"],
    )

    preprocess_settings = {
        "model_name": model_name,
        "selected_transforms_name": "test",
        "probability": -1,
        "gaussian_blur_range": None,
        "jpeg_compression_qfs": None,
        "input_image_dimensions": (224, 224),
        "resize": None,
    }

    print(" Running inference...")
    # The image goes in the real-images slot with no fake images; only its
    # prediction is used below.
    _, _, y_pred, _ = run_on_images(
        feature_extractor=feature_extractor,
        classifier=classifier,
        config=config,
        test_real_images_paths=[image_path],
        test_fake_images_paths=[],
        preprocess_settings=preprocess_settings,
        best_threshold=DECISION_THRESHOLD,
        verbose=False,
    )
    if len(y_pred) == 0:
        raise RuntimeError("run_on_images returned no predictions")

    # NOTE(review): assumes y_pred[0] is the image's "AI-generated" score, so
    # above-threshold means fake. Confirm against run_on_images.
    score = float(y_pred[0])
    prediction = "AI-Generated" if score > DECISION_THRESHOLD else "Real"
    return score, prediction


def _print_summary(results):
    """Print a per-model status report followed by the success/failure breakdown."""
    print("\n" + "=" * 80)
    print("MPS COMPATIBILITY SUMMARY")
    print("=" * 80)

    successful = []
    failed = []
    for model_name, result in results.items():
        passed = result["status"] == "SUCCESS"
        status_icon = "✓" if passed else "✗"
        print(f"{status_icon} {model_name.upper():<12} - {result['status']}")
        if passed:
            successful.append(model_name)
            print(f" Score: {result['score']:.4f} → {result['prediction']}")
        else:
            failed.append(model_name)
            # Only the first line of the error, to keep the report compact.
            error_line = result["error"].split("\n")[0]
            print(f" Error: {error_line[:70]}")

    mps_ok = ", ".join(m.upper() for m in successful) or "None"
    not_ok = ", ".join(m.upper() for m in failed) or "None"
    print("\n" + "=" * 80)
    print(f"Summary: {len(successful)} successful, {len(failed)} failed")
    print(f"MPS-compatible models: {mps_ok}")
    # NOTE: any failure lands here, including non-MPS causes such as a
    # missing config or checkpoint.
    print(f"CPU-only models: {not_ok}")
    print("=" * 80)


def main():
    """Score one test image with every model on MPS and print a compatibility report.

    Exits with status 1 when no test image is found or MPS is unavailable.
    """
    test_images_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "new_images_to_test"
    )
    test_images = _find_test_images(test_images_dir)
    if not test_images:
        print("No test images found!")
        sys.exit(1)

    # A single image is enough to exercise each model's device code path.
    test_image = test_images[0]
    print(f"Testing with image: {os.path.basename(test_image)}\n")

    if not torch.backends.mps.is_available():
        print("MPS not available on this system!")
        sys.exit(1)
    print(f"MPS is available. Built: {torch.backends.mps.is_built()}\n")
    print("=" * 80)

    results = {}
    for model_name in MODELS:
        print(f"\nTesting model: {model_name.upper()}")
        print("-" * 80)
        try:
            score, prediction = _run_single_model(model_name, test_image)
        except Exception as exc:  # per-model boundary: record the failure, keep going
            # Prefix the type: str() alone is bare for e.g. KeyError('x') -> "'x'".
            error_msg = f"{type(exc).__name__}: {exc}"
            print(f" ✗ FAILED - {error_msg[:100]}")
            results[model_name] = {"status": "FAILED", "score": None, "prediction": None, "error": error_msg}
        else:
            print(f" ✓ SUCCESS - Score: {score:.4f} → {prediction}")
            results[model_name] = {"status": "SUCCESS", "score": score, "prediction": prediction, "error": None}

    _print_summary(results)


if __name__ == "__main__":
    main()