gh-rgupta
Add demo web application for fake vs real image detection
6bfb0e3
"""
Inference module for AI-Generated Image Detection using ARNIQA model.
Simplified for demo purposes: runs on CPU by default and has no training dependencies.
"""
import os
import sys
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
# Put the project root (one level above this file's directory) at the front of
# sys.path so the project-local `features`, `functions` and `defaults` imports
# below resolve when this file is run from anywhere.
parent_dir = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
sys.path.insert(0, parent_dir)
import features.ARNIQA as ARNIQA
import functions.networks as networks
import defaults
def load_model(device='cpu'):
    """
    Load the ARNIQA feature extractor and the trained real/fake classifier.

    Args:
        device (str): Device to load the models on ('cpu' or 'cuda').

    Returns:
        tuple: (feature_extractor, classifier) - both switched to eval mode
        and ready for inference.

    Raises:
        FileNotFoundError: If the classifier checkpoint file does not exist.
    """
    print("Loading ARNIQA model...")

    # ARNIQA feature extractor; device placement is passed to its constructor.
    feature_extractor = ARNIQA.Compute_ARNIQA(device=device)
    feature_extractor.eval()

    # Classifier head that consumes the flattened ARNIQA features.
    classifier = networks.Classifier_Arch2(
        input_dim=4096,  # ARNIQA produces 4096-dim features
        hidden_layers=[1024],
    )

    checkpoint_path = os.path.join(
        defaults.main_checkpoints_dir,
        "GenImage/extensive/MarginContrastiveLoss_CrossEntropy/arniqa/best_model.ckpt",
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found at: {checkpoint_path}\n"
            "Please ensure the trained model checkpoint exists."
        )

    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted checkpoints. On newer PyTorch versions the weights_only default
    # may reject Lightning checkpoints that store extra objects - confirm.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # PyTorch Lightning checkpoints nest the weights under 'state_dict';
    # fall back to treating the loaded object as a bare state dict.
    state_dict = checkpoint.get('state_dict', checkpoint)

    # Strip only a *leading* 'classifier.' (presumably the attribute name of
    # the classifier inside the training module). str.replace would also
    # rewrite inner occurrences, e.g. 'classifier.sub_classifier.w' -> 'sub_w'.
    prefix = 'classifier.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    classifier.load_state_dict(new_state_dict)
    classifier = classifier.to(device)
    classifier.eval()

    print("Model loaded successfully!")
    return feature_extractor, classifier
def preprocess_image(image_pil):
    """
    Build the two input views consumed by the ARNIQA feature extractor.

    Both views are ImageNet-normalized float tensors with a leading batch
    dimension of size 1. The first view is resized to 224x224; the second
    keeps the image's original resolution.

    Args:
        image_pil (PIL.Image): Input image; non-RGB modes are converted to RGB.

    Returns:
        tuple: (img_224, img_full) - the 224x224 view and the full-size view,
        in that order.
    """
    rgb = image_pil if image_pil.mode == 'RGB' else image_pil.convert('RGB')

    # Shared tail of both pipelines: to float tensor, then ImageNet stats.
    imagenet_normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    tensor_steps = [transforms.ToTensor(), imagenet_normalize]

    resized_pipeline = transforms.Compose([transforms.Resize((224, 224))] + tensor_steps)
    full_size_pipeline = transforms.Compose(tensor_steps)

    # unsqueeze(0) adds the batch dimension to each view.
    return resized_pipeline(rgb).unsqueeze(0), full_size_pipeline(rgb).unsqueeze(0)
def predict(image_pil, feature_extractor, classifier, device='cpu'):
    """
    Classify a single image as real or AI-generated.

    Args:
        image_pil (PIL.Image): Input image.
        feature_extractor: ARNIQA feature extractor, called with both views.
        classifier: Trained classifier; returns a tuple whose second item
            holds the logits.
        device (str): Device the input tensors are moved to.

    Returns:
        tuple: (prediction, confidence, (prob_real, prob_fake))
            - prediction (str): "Real" or "Fake"
            - confidence (float): Probability (0-100) of the predicted class
            - prob_real (float): Probability of real (0-100)
            - prob_fake (float): Probability of fake (0-100)
    """
    view_224, view_full = (view.to(device) for view in preprocess_image(image_pil))

    with torch.no_grad():
        raw_features = feature_extractor(view_224, view_full)
        flat_features = torch.flatten(raw_features, start_dim=1).to(torch.float32)
        _, logits = classifier(flat_features)
        # Class index 0 = Real, 1 = Fake; take the single batch row.
        class_probs = F.softmax(logits, dim=1)[0]
        prob_real = class_probs[0].item() * 100
        prob_fake = class_probs[1].item() * 100

    # Strictly above 50% counts as Fake; an exact tie is reported as Real.
    is_fake = prob_fake > 50.0
    label = "Fake" if is_fake else "Real"
    confidence = prob_fake if is_fake else prob_real
    return label, confidence, (prob_real, prob_fake)
# Smoke test when run as a script: succeeding at building both models shows
# the project modules and the checkpoint can be resolved.
if __name__ == "__main__":
    extractor, clf = load_model()
    print("Model test passed!")