# raghav
# Switch back to Gradio with simple Blocks API - no schema issues
# 68f94ad
"""
Simple Gradio App: AI-Generated Image Detector
"""
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import sys
# Add this file's directory to sys.path so the local `features` and `functions` packages can be imported
sys.path.insert(0, os.path.dirname(__file__))
import features.ARNIQA as ARNIQA
import functions.networks as networks
# Module-level state shared by load_models() and predict_image().
# Both model handles stay None until load_models() assigns them.
feature_extractor = classifier = None
# Device string handed to the feature extractor, torch.load and .to().
device = "cpu"
def load_models():
    """Load the ARNIQA feature extractor and the trained classifier.

    Assigns the module-level ``feature_extractor`` and ``classifier``
    globals, calls ``.eval()`` on both, and returns them.

    Returns:
        tuple: ``(feature_extractor, classifier)``.

    Raises:
        FileNotFoundError: If the classifier checkpoint file does not exist.
    """
    global feature_extractor, classifier, device

    print("Loading ARNIQA feature extractor...")
    feature_extractor = ARNIQA.Compute_ARNIQA(device=device)
    feature_extractor.eval()

    print("Loading classifier...")
    classifier = networks.Classifier_Arch2(
        input_dim=4096,
        hidden_layers=[1024],
    )

    # The checkpoint is resolved relative to this file, not the working dir.
    checkpoint_path = os.path.join(
        os.path.dirname(__file__),
        "checkpoints/GenImage/extensive/MarginContrastiveLoss_CrossEntropy/arniqa/best_model.ckpt"
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at: {checkpoint_path}")

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only safe because the checkpoint is a
    # trusted file shipped with the app; never point this at untrusted data.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['state_dict']

    # Strip only a *leading* 'classifier.' from each key. Slicing (rather than
    # str.replace) leaves later occurrences of the substring untouched, so a
    # key like 'classifier.classifier.0.weight' loses just the first segment.
    prefix = 'classifier.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    classifier.load_state_dict(new_state_dict)
    classifier = classifier.to(device)
    classifier.eval()

    print("Models loaded successfully!")
    return feature_extractor, classifier
def preprocess_image(image_pil):
    """Turn a PIL image into the two tensors passed to the feature extractor.

    The image is converted to RGB when its mode is anything else, then run
    through two pipelines that share the same per-channel normalization:
    one that first resizes to 224x224 and one that keeps the original size.

    Returns:
        tuple: ``(resized, original)`` tensors, each with a leading batch
        axis added via ``unsqueeze(0)``.
    """
    rgb_image = image_pil if image_pil.mode == 'RGB' else image_pil.convert('RGB')

    to_tensor = transforms.ToTensor()
    # Per-channel mean/std (these look like the usual ImageNet statistics).
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Scale 1: fixed 224x224 view of the image.
    resized_pipeline = transforms.Compose(
        [transforms.Resize((224, 224)), to_tensor, normalize]
    )
    # Scale 2: the image at its original resolution.
    original_pipeline = transforms.Compose([to_tensor, normalize])

    resized = resized_pipeline(rgb_image).unsqueeze(0)
    original = original_pipeline(rgb_image).unsqueeze(0)
    return resized, original
def predict_image(image) -> str:
    """Classify an uploaded image as real or AI-generated.

    Args:
        image: A PIL image, or an array-like accepted by ``Image.fromarray``.
            ``None`` (nothing uploaded) is reported as an error message.

    Returns:
        A human-readable, multi-line result string with the verdict and both
        class percentages, or a string starting with ``"Error: "`` when the
        input is missing, the models are not loaded, or inference fails.
    """
    global feature_extractor, classifier, device

    # Validate the input first: the Gradio Image component hands over None
    # when the button is clicked without an upload, and Image.fromarray(None)
    # would otherwise surface as a cryptic attribute error.
    if image is None:
        return "Error: No image provided. Please upload an image first."

    if feature_extractor is None or classifier is None:
        return "Error: Models not loaded"

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Preprocess into the two scales and move them to the model device.
        img1, img2 = preprocess_image(image)
        img1 = img1.to(device)
        img2 = img2.to(device)

        # Extract features and classify without tracking gradients.
        with torch.no_grad():
            features = feature_extractor(img1, img2)
            features = torch.flatten(features, start_dim=1).to(torch.float32)
            # The classifier returns a pair; only the second item (logits)
            # is used here.
            _, logits = classifier(features)
            probs = F.softmax(logits, dim=1)

        # Column 0 is treated as "real", column 1 as "AI-generated".
        prob_real = probs[0, 0].item() * 100
        prob_fake = probs[0, 1].item() * 100

        # Format output
        if prob_fake > 50.0:
            result = f"🚨 AI-Generated (Fake)\nConfidence: {prob_fake:.1f}%\n\nScores:\nReal: {prob_real:.1f}%\nAI-Generated: {prob_fake:.1f}%"
        else:
            result = f"✅ Real Image\nConfidence: {prob_real:.1f}%\n\nScores:\nReal: {prob_real:.1f}%\nAI-Generated: {prob_fake:.1f}%"
        return result
    except Exception as e:
        # UI boundary: report the failure in the result box instead of
        # letting the exception propagate to Gradio.
        return f"Error: {str(e)}"
# Load the models once at import time. A failure is printed with its
# traceback but not re-raised, so the rest of the module still executes.
print("Initializing models...")
try:
    load_models()
except Exception as load_error:
    print(f"Error loading models: {load_error}")
    import traceback
    traceback.print_exc()
else:
    print("Models loaded successfully!")
# Create Gradio interface using Blocks
# Layout: a title and intro text, then one row with two columns —
# image upload + button on the left, the result textbox on the right.
with gr.Blocks(title="AI Image Detector") as demo:
    gr.Markdown("# 🔍 Real vs Fake: AI Image Detector")
    # The markdown text below is kept flush-left so the string's contents
    # are exactly what is written here.
    gr.Markdown("""
Upload an image to detect if it's real or AI-generated.
This detector uses ARNIQA perceptual features to identify synthetic images.
""")
    with gr.Row():
        with gr.Column():
            # type="pil" asks Gradio to pass the upload as a PIL image.
            image_input = gr.Image(type="pil", label="Upload Image")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(label="Result", lines=6)
    # Clicking the button runs predict_image on the uploaded image and
    # writes the returned string into the result textbox.
    submit_btn.click(fn=predict_image, inputs=image_input, outputs=output_text)
# Launch the server only when this file is run as a script; it listens on
# address 0.0.0.0 (all interfaces), port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)