| """ | |
| Inference module for AI-Generated Image Detection using ARNIQA model. | |
| Simplified for demo purposes - CPU-only, no training dependencies. | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| # Add parent directory to path to import existing modules | |
| parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| sys.path.insert(0, parent_dir) | |
| import features.ARNIQA as ARNIQA | |
| import functions.networks as networks | |
| import defaults | |
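# Note: these imports assume this file lives one directory below the project
# root, alongside the existing `features/`, `functions/`, and `defaults`
# modules.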

def load_model(device='cpu'):
    """
    Load the ARNIQA feature extractor and the trained classifier.

    Args:
        device (str): Device to load the models on ('cpu' or 'cuda').

    Returns:
        tuple: (feature_extractor, classifier) - both models in eval mode,
            ready for inference.
    """
    print("Loading ARNIQA model...")

    # Load the ARNIQA feature extractor.
    feature_extractor = ARNIQA.Compute_ARNIQA(device=device)
    feature_extractor.eval()

    # Build the classifier head.
    classifier = networks.Classifier_Arch2(
        input_dim=4096,  # ARNIQA produces 4096-dim features
        hidden_layers=[1024]
    )

    # Locate the trained checkpoint.
    checkpoint_path = os.path.join(
        defaults.main_checkpoints_dir,
        "GenImage/extensive/MarginContrastiveLoss_CrossEntropy/arniqa/best_model.ckpt"
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found at: {checkpoint_path}\n"
            f"Please ensure the trained model checkpoint exists."
        )

    # Load the checkpoint (PyTorch Lightning format).
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Extract the state dict from the Lightning checkpoint.
    state_dict = checkpoint['state_dict']

    # Strip the leading 'classifier.' prefix that Lightning adds to submodule
    # keys; only the first occurrence is replaced so nested names stay intact.
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('classifier.'):
            new_state_dict[key.replace('classifier.', '', 1)] = value
        else:
            new_state_dict[key] = value

    classifier.load_state_dict(new_state_dict)
    classifier = classifier.to(device)
    classifier.eval()

    print("Model loaded successfully!")
    return feature_extractor, classifier
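
# Usage sketch (illustrative): load the models once at startup and reuse them
# for every request; load_model() raises FileNotFoundError if the checkpoint
# is missing.
#
#     feature_extractor, classifier = load_model(device='cpu')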

def preprocess_image(image_pil):
    """
    Preprocess an image for ARNIQA feature extraction.

    ARNIQA consumes the image at two scales: a fixed 224x224 resize and the
    original resolution.

    Args:
        image_pil (PIL.Image): Input image.

    Returns:
        tuple: (img1_tensor, img2_tensor) - the 224x224 scale and the
            original-resolution scale, each with a leading batch dimension.
    """
    # Normalization with ImageNet statistics.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    # Transform that keeps the original resolution.
    transform_original = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    # Transform that resizes to 224x224.
    transform_224 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize
    ])

    # Convert to RGB if needed (e.g. grayscale or RGBA inputs).
    if image_pil.mode != 'RGB':
        image_pil = image_pil.convert('RGB')

    # Apply the transforms and add a batch dimension.
    img1 = transform_224(image_pil).unsqueeze(0)
    img2 = transform_original(image_pil).unsqueeze(0)
    return img1, img2
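
# Shape sketch (illustrative; Pillow sizes are (width, height)): for a
# 512x384 input, the two returned tensors are batched (N, C, H, W):
#
#     img1, img2 = preprocess_image(Image.new('RGB', (512, 384)))
#     assert tuple(img1.shape) == (1, 3, 224, 224)
#     assert tuple(img2.shape) == (1, 3, 384, 512)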

def predict(image_pil, feature_extractor, classifier, device='cpu'):
    """
    Run inference on a single image.

    Args:
        image_pil (PIL.Image): Input image.
        feature_extractor: ARNIQA feature extractor.
        classifier: Trained classifier.
        device (str): Device to run inference on.

    Returns:
        tuple: (prediction, confidence, (prob_real, prob_fake))
            - prediction (str): "Real" or "Fake"
            - confidence (float): Confidence percentage (0-100)
            - prob_real (float): Probability of real (0-100)
            - prob_fake (float): Probability of fake (0-100)
    """
    # Preprocess the image at both scales and move the tensors to the device.
    img1, img2 = preprocess_image(image_pil)
    img1 = img1.to(device)
    img2 = img2.to(device)

    with torch.no_grad():
        # Extract ARNIQA features and flatten them for the classifier.
        features = feature_extractor(img1, img2)
        features = torch.flatten(features, start_dim=1).to(torch.float32)

        # Classify.
        _, logits = classifier(features)

    # Convert logits to class probabilities, expressed as percentages.
    probs = F.softmax(logits, dim=1)
    prob_real = probs[0, 0].item() * 100  # class 0 = Real
    prob_fake = probs[0, 1].item() * 100  # class 1 = Fake

    # Pick the class above the 50% decision threshold.
    if prob_fake > 50.0:
        prediction = "Fake"
        confidence = prob_fake
    else:
        prediction = "Real"
        confidence = prob_real

    return prediction, confidence, (prob_real, prob_fake)
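
# Example usage (a minimal sketch; "photo.png" is a hypothetical file):
#
#     image = Image.open("photo.png")
#     label, conf, (p_real, p_fake) = predict(image, feature_extractor, classifier)
#     print(f"{label} ({conf:.1f}% confident)")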

# Quick self-test when run directly.
if __name__ == "__main__":
    # Test model loading.
    feature_extractor, classifier = load_model()
    print("Model test passed!")