Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,702 Bytes
2cab2eb 68f94ad 2cab2eb 68f94ad 2cab2eb 68f94ad 2cab2eb 68f94ad 2cab2eb 68f94ad 2cab2eb 68f94ad 2cab2eb 01433ab 2cab2eb 21e189f 2cab2eb 68f94ad 2cab2eb 68f94ad 2cab2eb 68f94ad 2cab2eb 68f94ad 2cab2eb 68f94ad 2cab2eb 68f94ad 01433ab 2cab2eb 68f94ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
"""
Simple Gradio App: AI-Generated Image Detector
"""
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import sys
# Add features directory to path
sys.path.insert(0, os.path.dirname(__file__))
import features.ARNIQA as ARNIQA
import functions.networks as networks
# Module-level state shared by load_models() and predict_image().
# Device string handed to the feature extractor, torch.load and .to().
device = "cpu"
# Both models start out unset; predict_image() reports an error while
# either one is still None.
classifier = None
feature_extractor = None
def load_models():
    """Load the ARNIQA feature extractor and the trained classifier.

    The module-level ``feature_extractor`` and ``classifier`` globals are
    only assigned once *everything* has loaded successfully. If any step
    fails they keep their previous values, so a failed load can never leave
    an untrained classifier visible to ``predict_image``.

    Returns:
        The ``(feature_extractor, classifier)`` pair that was just loaded,
        both switched to eval mode.

    Raises:
        FileNotFoundError: If the classifier checkpoint file does not exist.
    """
    global feature_extractor, classifier, device

    print("Loading ARNIQA feature extractor...")
    extractor = ARNIQA.Compute_ARNIQA(device=device)
    extractor.eval()

    print("Loading classifier...")
    model = networks.Classifier_Arch2(
        input_dim=4096,
        hidden_layers=[1024]
    )

    # The checkpoint is resolved relative to this file, not the working dir.
    checkpoint_path = os.path.join(
        os.path.dirname(__file__),
        "checkpoints/GenImage/extensive/MarginContrastiveLoss_CrossEntropy/arniqa/best_model.ckpt"
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at: {checkpoint_path}")

    # NOTE(security): weights_only=False makes torch.load unpickle arbitrary
    # Python objects. That is only acceptable because the checkpoint ships
    # with the app; never point this at an untrusted file.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['state_dict']

    # Strip only a *leading* 'classifier.' from each key. str.replace() would
    # also delete later occurrences and corrupt nested parameter names.
    prefix = 'classifier.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(new_state_dict)
    model = model.to(device)
    model.eval()

    # Publish the models only now that both are fully initialised.
    feature_extractor = extractor
    classifier = model

    print("Models loaded successfully!")
    return feature_extractor, classifier
def preprocess_image(image_pil):
    """Turn a PIL image into the two normalized tensors fed to ARNIQA.

    Args:
        image_pil: Input PIL image; converted to RGB when it is in any
            other mode.

    Returns:
        A ``(img1, img2)`` tuple. ``img1`` comes from the image resized to
        224x224 and ``img2`` from the image at its original size. Each has
        a leading batch dimension of size 1 added.
    """
    rgb_image = image_pil if image_pil.mode == 'RGB' else image_pil.convert('RGB')

    # Shared tail of both pipelines: tensor conversion followed by per-channel
    # normalization (these values match the usual ImageNet statistics).
    tensor_and_normalize = [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]

    # Two scales: a fixed 224x224 view and the untouched original resolution.
    resized_pipeline = transforms.Compose(
        [transforms.Resize((224, 224))] + tensor_and_normalize
    )
    original_pipeline = transforms.Compose(tensor_and_normalize)

    return (
        resized_pipeline(rgb_image).unsqueeze(0),
        original_pipeline(rgb_image).unsqueeze(0),
    )
def predict_image(image):
    """Classify an uploaded image as real or AI-generated.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``.
            ``None`` (nothing was uploaded) is reported as an error.

    Returns:
        A multi-line, human-readable verdict with the percentage scores for
        both classes, or a string starting with ``"Error:"`` when no image
        was given, the models are not loaded, or inference raised.
    """
    # The UI passes None when the button is pressed without an upload;
    # answer with a clear message instead of a cryptic conversion error.
    if image is None:
        return "Error: No image provided"

    if feature_extractor is None or classifier is None:
        return "Error: Models not loaded"

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Preprocess into the two scales the feature extractor takes.
        img1, img2 = preprocess_image(image)
        img1 = img1.to(device)
        img2 = img2.to(device)

        # Extract features and classify without tracking gradients.
        with torch.no_grad():
            features = feature_extractor(img1, img2)
            features = torch.flatten(features, start_dim=1).to(torch.float32)
            # The classifier returns a pair; only the logits are used here.
            _, logits = classifier(features)
            probs = F.softmax(logits, dim=1)

        # Column 0 is treated as "real" and column 1 as "AI-generated".
        prob_real = probs[0, 0].item() * 100
        prob_fake = probs[0, 1].item() * 100

        # Format output
        if prob_fake > 50.0:
            result = f"🚨 AI-Generated (Fake)\nConfidence: {prob_fake:.1f}%\n\nScores:\nReal: {prob_real:.1f}%\nAI-Generated: {prob_fake:.1f}%"
        else:
            result = f"✅ Real Image\nConfidence: {prob_real:.1f}%\n\nScores:\nReal: {prob_real:.1f}%\nAI-Generated: {prob_fake:.1f}%"
        return result
    except Exception as e:
        # UI boundary: surface the failure as text rather than crashing.
        return f"Error: {str(e)}"
# Load the models once at import time. A failure is reported on the console
# with its traceback but does not stop the app from starting.
print("Initializing models...")
try:
    load_models()
except Exception as exc:
    print(f"Error loading models: {exc}")
    import traceback
    traceback.print_exc()
else:
    print("Models loaded successfully!")
# Build the Gradio UI with Blocks: the first column holds the image upload
# and the submit button, the second column shows the textual verdict.
with gr.Blocks(title="AI Image Detector") as demo:
    gr.Markdown("# 🔍 Real vs Fake: AI Image Detector")
    gr.Markdown("""
Upload an image to detect if it's real or AI-generated.
This detector uses ARNIQA perceptual features to identify synthetic images.
""")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(label="Result", lines=6)

    # Clicking the button runs predict_image on the uploaded image and
    # writes its returned string into the result box.
    submit_btn.click(predict_image, inputs=image_input, outputs=output_text)


# Start the web server only when run as a script: all interfaces, port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
|