Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,867 Bytes
6bfb0e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
"""
Inference module for AI-Generated Image Detection using ARNIQA model.
Simplified for demo purposes - CPU-only, no training dependencies.
"""
import os
import sys
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
# Add parent directory to path to import existing modules
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parent_dir)
import features.ARNIQA as ARNIQA
import functions.networks as networks
import defaults
def load_model(device='cpu'):
    """
    Load the ARNIQA feature extractor and the trained classifier head.

    Args:
        device (str): Device to load the models on ('cpu' or 'cuda').

    Returns:
        tuple: (feature_extractor, classifier) - both switched to eval mode
        and ready for inference.

    Raises:
        FileNotFoundError: If the classifier checkpoint file does not exist.
    """
    print("Loading ARNIQA model...")

    # Feature extractor is constructed directly on the requested device.
    feature_extractor = ARNIQA.Compute_ARNIQA(device=device)
    feature_extractor.eval()

    # Classifier head; input_dim must match the flattened ARNIQA feature size.
    classifier = networks.Classifier_Arch2(
        input_dim=4096,  # ARNIQA produces 4096-dim features
        hidden_layers=[1024],
    )

    checkpoint_path = os.path.join(
        defaults.main_checkpoints_dir,
        "GenImage/extensive/MarginContrastiveLoss_CrossEntropy/arniqa/best_model.ckpt",
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found at: {checkpoint_path}\n"
            "Please ensure the trained model checkpoint exists."
        )

    # NOTE(review): since PyTorch 2.6, torch.load defaults to weights_only=True.
    # If this Lightning checkpoint stores non-tensor objects (e.g. hyper-params),
    # loading may require weights_only=False -- only acceptable for trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # PyTorch Lightning nests the weights under 'state_dict'; fall back to the
    # loaded object itself so a plain state-dict file is accepted as well.
    state_dict = checkpoint.get('state_dict', checkpoint)

    # Keys of the classifier presumably carry a leading 'classifier.' (the
    # attribute name of the head in the training module). Strip ONLY that
    # leading prefix: str.replace would also delete any nested 'classifier.'
    # segment (e.g. 'classifier.classifier.weight' -> 'weight').
    prefix = 'classifier.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    classifier.load_state_dict(new_state_dict)
    classifier = classifier.to(device)
    classifier.eval()

    print("Model loaded successfully!")
    return feature_extractor, classifier
def preprocess_image(image_pil):
    """
    Turn a PIL image into the two normalized tensors consumed by ARNIQA.

    Args:
        image_pil (PIL.Image): Input image; converted to RGB first when it is
            in any other mode.

    Returns:
        tuple: (img1, img2), each with a leading batch dimension of size 1:
            - img1: the image resized to 224x224
            - img2: the image at its original resolution
        Both are normalized with the ImageNet mean/std.
    """
    rgb = image_pil if image_pil.mode == 'RGB' else image_pil.convert('RGB')

    # ImageNet channel statistics.
    imagenet_norm = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    tensor_steps = [transforms.ToTensor(), imagenet_norm]

    resized_pipeline = transforms.Compose([transforms.Resize((224, 224)), *tensor_steps])
    full_res_pipeline = transforms.Compose(tensor_steps)

    # `[None]` prepends the batch axis (same as unsqueeze(0)).
    resized = resized_pipeline(rgb)[None]
    full_res = full_res_pipeline(rgb)[None]
    return resized, full_res
def predict(image_pil, feature_extractor, classifier, device='cpu'):
    """
    Classify a single image as "Real" or "Fake".

    Args:
        image_pil (PIL.Image): Input image.
        feature_extractor: ARNIQA feature extractor (called with two image tensors).
        classifier: Trained classifier; its output is unpacked as (_, logits).
        device (str): Device the input tensors are moved to.

    Returns:
        tuple: (prediction, confidence, (prob_real, prob_fake)), where every
        probability is a percentage in [0, 100]:
            - prediction (str): "Fake" when prob_fake > 50, otherwise "Real"
            - confidence (float): probability of the predicted class
            - prob_real (float): probability of class 0 (Real)
            - prob_fake (float): probability of class 1 (Fake)
    """
    scaled, full_res = (t.to(device) for t in preprocess_image(image_pil))

    with torch.no_grad():
        feats = feature_extractor(scaled, full_res).flatten(start_dim=1).float()
        _, logits = classifier(feats)
        class_probs = F.softmax(logits, dim=1)
        # Only the first batch item is read: index 0 = Real, 1 = Fake.
        real_pct = class_probs[0, 0].item() * 100
        fake_pct = class_probs[0, 1].item() * 100

    # Decision threshold is 0.5 (50%); an exact tie falls back to "Real".
    is_fake = fake_pct > 50.0
    label = "Fake" if is_fake else "Real"
    score = fake_pct if is_fake else real_pct
    return label, score, (real_pct, fake_pct)
# Smoke test: running this module directly only checks that the models load.
if __name__ == "__main__":
    _extractor, _classifier = load_model()
    print("Model test passed!")
|