"""
Inference module for AI-Generated Image Detection using ARNIQA model.
Simplified for demo purposes - CPU-only, no training dependencies.
"""

import os
import sys
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np

# Add parent directory to path to import existing modules
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parent_dir)

import features.ARNIQA as ARNIQA
import functions.networks as networks
import defaults


# Prefix carried by classifier weights inside the Lightning checkpoint.
_CLASSIFIER_PREFIX = 'classifier.'


def _strip_key_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with a leading ``prefix`` cut from keys.

    Only a prefix at the very start of a key is removed; later occurrences
    of the same text inside the key are left untouched. Keys that do not
    start with ``prefix`` are copied unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(device='cpu'):
    """
    Load ARNIQA feature extractor and classifier.

    Args:
        device (str): Device to load model on ('cpu' or 'cuda')

    Returns:
        tuple: (feature_extractor, classifier) - both models ready for inference

    Raises:
        FileNotFoundError: If the classifier checkpoint file does not exist.
        KeyError: If the loaded checkpoint has no 'state_dict' entry.
    """
    print("Loading ARNIQA model...")

    # Load ARNIQA feature extractor
    feature_extractor = ARNIQA.Compute_ARNIQA(device=device)
    feature_extractor.eval()

    # Load classifier
    classifier = networks.Classifier_Arch2(
        input_dim=4096,  # ARNIQA produces 4096-dim features
        hidden_layers=[1024]
    )

    # Locate checkpoint
    checkpoint_path = os.path.join(
        defaults.main_checkpoints_dir,
        "GenImage/extensive/MarginContrastiveLoss_CrossEntropy/arniqa/best_model.ckpt"
    )

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found at: {checkpoint_path}\n"
            f"Please ensure the trained model checkpoint exists."
        )

    # Load checkpoint (PyTorch Lightning format).
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; consider
    # passing weights_only=True if the checkpoint contents allow it.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Extract state dict from Lightning checkpoint
    state_dict = checkpoint['state_dict']

    # Remove the leading 'classifier.' prefix from keys if present. Slicing
    # (rather than str.replace) guarantees that only the prefix is removed
    # and any later occurrence inside the key is preserved.
    new_state_dict = _strip_key_prefix(state_dict, _CLASSIFIER_PREFIX)

    classifier.load_state_dict(new_state_dict)
    classifier = classifier.to(device)
    classifier.eval()

    print("Model loaded successfully!")
    return feature_extractor, classifier


def preprocess_image(image_pil):
    """
    Preprocess image for ARNIQA feature extraction.
    Creates two scales: a 224x224 resized copy and the original size.

    Args:
        image_pil (PIL.Image): Input image (converted to RGB if needed)

    Returns:
        tuple: (img1_tensor, img2_tensor) - the 224x224 version first, then
            the original-size version; each is normalized and has a leading
            batch dimension of 1
    """
    # Normalization (ImageNet stats)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    # Transform for original size
    transform_original = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    # Transform for 224x224
    transform_224 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize
    ])

    # Convert to RGB if needed
    if image_pil.mode != 'RGB':
        image_pil = image_pil.convert('RGB')

    # Apply transforms
    img1 = transform_224(image_pil).unsqueeze(0)  # Add batch dimension
    img2 = transform_original(image_pil).unsqueeze(0)

    return img1, img2


def predict(image_pil, feature_extractor, classifier, device='cpu'):
    """
    Run inference on a single image.

    Args:
        image_pil (PIL.Image): Input image
        feature_extractor: ARNIQA feature extractor
        classifier: Trained classifier
        device (str): Device to run inference on

    Returns:
        tuple: (prediction, confidence, (prob_real, prob_fake))
            - prediction (str): "Real" or "Fake"
            - confidence (float): Confidence percentage (0-100)
            - prob_real (float): Probability of real (0-100)
            - prob_fake (float): Probability of fake (0-100)
    """
    # Preprocess
    img1, img2 = preprocess_image(image_pil)
    img1 = img1.to(device)
    img2 = img2.to(device)

    # Inference only: no gradients are needed
    with torch.no_grad():
        # Extract features
        features = feature_extractor(img1, img2)
        features = torch.flatten(features, start_dim=1).to(torch.float32)

        # Classify; the classifier returns a pair and only the second
        # element (the logits) is used here
        _, logits = classifier(features)

        # Get probabilities
        probs = F.softmax(logits, dim=1)
        prob_real = probs[0, 0].item() * 100  # Class 0 = Real
        prob_fake = probs[0, 1].item() * 100  # Class 1 = Fake

    # Determine prediction (threshold = 0.5)
    if prob_fake > 50.0:
        prediction = "Fake"
        confidence = prob_fake
    else:
        prediction = "Real"
        confidence = prob_real

    return prediction, confidence, (prob_real, prob_fake)


# Test function
if __name__ == "__main__":
    # Test the model loading
    feature_extractor, classifier = load_model()
    print("Model test passed!")