"""
Test all available models on the same image
"""
import os
import sys
if __name__ == '__main__':
    # Available models - test all 5 IQA-based models
    models = ['contrique', 'hyperiqa', 'tres', 'reiqa', 'arniqa']
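    # (These identifiers appear to correspond to the CONTRIQUE, HyperIQA, TReS,
    # Re-IQA and ARNIQA feature extractors.)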
    # Test images directory
    test_images_dir = "new_images_to_test"

    # Get all images from the directory
    import glob
    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
    test_images = []
    for ext in image_extensions:
        test_images.extend(glob.glob(os.path.join(test_images_dir, ext)))

    if not test_images:
        print(f"Error: No images found in {test_images_dir}/")
        sys.exit(1)

    print(f"Found {len(test_images)} image(s) in {test_images_dir}/")
    print("=" * 80)

    # Import libraries once
    sys.path.insert(0, '.')
    from yaml import safe_load
    from functions.loss_optimizers_metrics import *
    from functions.run_on_images_fn import run_on_images
    import functions.utils as utils
    import functions.networks as networks
    import defaults

    import warnings
    warnings.filterwarnings("ignore")
    all_results = {}

    # Test each model
    for model_idx, model_name in enumerate(models, 1):
        print(f"\n{'='*80}")
        print(f"[{model_idx}/{len(models)}] Testing model: {model_name.upper()}")
        print("="*80)
        try:
            config_path = f"configs/{model_name}.yaml"
            with open(config_path, "r") as config_file:
                config = safe_load(config_file)

            # Override settings
            config["dataset"]["dataset_type"] = "GenImage"
            config["checkpoints"]["resume_dirname"] = "GenImage/extensive/MarginContrastiveLoss_CrossEntropy"
            config["checkpoints"]["resume_filename"] = "best_model.ckpt"
            config["checkpoints"]["checkpoint_dirname"] = "extensive/MarginContrastiveLoss_CrossEntropy"
            config["checkpoints"]["checkpoint_filename"] = "best_model.ckpt"

            # Training settings (for testing)
            config["train_settings"]["train"] = False
            config["train_loss_fn"]["name"] = "CrossEntropy"
            config["val_loss_fn"]["name"] = "CrossEntropy"

            # Model setup - use CPU (MPS has compatibility issues)
            device = "cpu"
            feature_extractor = networks.get_model(model_name=model_name, device=device)

            # Classifier
            config["classifier"]["hidden_layers"] = [1024]
            classifier = networks.Classifier_Arch2(
                input_dim=config["classifier"]["input_dim"],
                hidden_layers=config["classifier"]["hidden_layers"]
            )
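            # NOTE: no weights are loaded explicitly here; run_on_images
            # presumably restores the trained classifier from the resume
            # checkpoint configured above.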

            # Preprocessing settings
            preprocess_settings = {
                "model_name": model_name,
                "selected_transforms_name": "test",
                "probability": -1,
                "gaussian_blur_range": None,
                "jpeg_compression_qfs": None,
                "input_image_dimensions": (224, 224),
                "resize": None
            }
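            # probability=-1 with the blur/JPEG fields left as None appears to
            # disable the random degradations, so only the deterministic "test"
            # transforms at 224x224 are applied.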
print(f"✓ {model_name.upper()} model loaded successfully\n")
results = []
# Test each image with this model
for idx, test_image in enumerate(test_images, 1):
image_name = os.path.basename(test_image)
print(f" [{idx}/{len(test_images)}] Testing: {image_name}")
# Test images
test_real_images_paths = [test_image]
test_fake_images_paths = []
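                # Each image is scored on its own: it is passed as the only
                # "real" path with an empty fake list, so only y_pred is used
                # below; the returned metrics and threshold are ignored.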
                try:
                    test_set_metrics, best_threshold, y_pred, y_true = run_on_images(
                        feature_extractor=feature_extractor,
                        classifier=classifier,
                        config=config,
                        test_real_images_paths=test_real_images_paths,
                        test_fake_images_paths=test_fake_images_paths,
                        preprocess_settings=preprocess_settings,
                        best_threshold=0.5,
                        verbose=False
                    )
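                    # y_pred[0] is the score for this image; values above the
                    # 0.5 threshold are treated as AI-generated, and the
                    # distance from 0.5 is rescaled to a 0-100% confidence.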
                    score = y_pred[0] if len(y_pred) > 0 else None
                    if score is not None:
                        # compare with 'is not None' so a valid score of 0.0 is
                        # not mistaken for a missing prediction
                        prediction = "AI-Generated" if score > 0.5 else "Real"
                        confidence = abs(score - 0.5) * 200
                        print(f" ✓ Score: {score:.4f} → {prediction} ({confidence:.1f}% confidence)")
                    else:
                        prediction = "Error"
                        confidence = 0
                        print(" ✗ No score returned")
                    results.append({
                        'image': image_name,
                        'score': score,
                        'prediction': prediction,
                        'confidence': confidence
                    })
                except Exception as e:
                    print(f" ✗ Error: {e}")
                    results.append({
                        'image': image_name,
                        'score': None,
                        'prediction': 'Error',
                        'confidence': 0
                    })

            all_results[model_name] = results
        except Exception as e:
            print(f"✗ Failed to load {model_name.upper()} model: {e}")
            all_results[model_name] = None

    # Final Summary
    print("\n" + "="*80)
    print("FINAL SUMMARY - ALL MODELS")
    print("="*80)

    for model_name, results in all_results.items():
        if results is None:
            print(f"\n{model_name.upper()}: Failed to load")
            continue

        print(f"\n{model_name.upper()}:")
        print("-"*80)
        print(f"{'Image':<50} {'Score':<10} {'Prediction':<15} {'Confidence':<12}")
        print("-"*80)
        for r in results:
            score_str = f"{r['score']:.4f}" if r['score'] is not None else "N/A"
            conf_str = f"{r['confidence']:.1f}%" if r['score'] is not None else "N/A"
            img_name = r['image'][:47] + "..." if len(r['image']) > 50 else r['image']
            print(f"{img_name:<50} {score_str:<10} {r['prediction']:<15} {conf_str:<12}")

        # Aggregate statistics over images that produced a score (errors excluded)
        valid_predictions = [r for r in results if r['score'] is not None]
        if valid_predictions:
            avg_score = sum(r['score'] for r in valid_predictions) / len(valid_predictions)
            ai_count = sum(1 for r in valid_predictions if r['score'] > 0.5)
            real_count = len(valid_predictions) - ai_count
            avg_confidence = sum(r['confidence'] for r in valid_predictions) / len(valid_predictions)
            print("-"*80)
            print(f"Average Score: {avg_score:.4f} | AI: {ai_count} | Real: {real_count} | Avg Confidence: {avg_confidence:.1f}%")

    print("\n" + "="*80)