""" Simple Gradio App: AI-Generated Image Detector """ import gradio as gr import torch import torch.nn.functional as F from torchvision import transforms from PIL import Image import os import sys # Add features directory to path sys.path.insert(0, os.path.dirname(__file__)) import features.ARNIQA as ARNIQA import functions.networks as networks # Global variables feature_extractor = None classifier = None device = "cpu" def load_models(): """Load ARNIQA feature extractor and trained classifier""" global feature_extractor, classifier, device print("Loading ARNIQA feature extractor...") feature_extractor = ARNIQA.Compute_ARNIQA(device=device) feature_extractor.eval() print("Loading classifier...") classifier = networks.Classifier_Arch2( input_dim=4096, hidden_layers=[1024] ) # Load checkpoint checkpoint_path = os.path.join( os.path.dirname(__file__), "checkpoints/GenImage/extensive/MarginContrastiveLoss_CrossEntropy/arniqa/best_model.ckpt" ) if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found at: {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) state_dict = checkpoint['state_dict'] # Remove 'classifier.' prefix from keys if present new_state_dict = {} for key, value in state_dict.items(): if key.startswith('classifier.'): new_state_dict[key.replace('classifier.', '')] = value else: new_state_dict[key] = value classifier.load_state_dict(new_state_dict) classifier = classifier.to(device) classifier.eval() print("Models loaded successfully!") return feature_extractor, classifier def preprocess_image(image_pil): """Preprocess image for ARNIQA""" normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) if image_pil.mode != 'RGB': image_pil = image_pil.convert('RGB') # Two scales: 224x224 and original transform_224 = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), normalize ]) transform_original = transforms.Compose([ transforms.ToTensor(), normalize ]) img1 = transform_224(image_pil).unsqueeze(0) img2 = transform_original(image_pil).unsqueeze(0) return img1, img2 def predict_image(image): """Run inference on uploaded image""" global feature_extractor, classifier, device if feature_extractor is None or classifier is None: return "Error: Models not loaded" try: if not isinstance(image, Image.Image): image = Image.fromarray(image) # Preprocess img1, img2 = preprocess_image(image) img1 = img1.to(device) img2 = img2.to(device) # Extract features and classify with torch.no_grad(): features = feature_extractor(img1, img2) features = torch.flatten(features, start_dim=1).to(torch.float32) _, logits = classifier(features) probs = F.softmax(logits, dim=1) prob_real = probs[0, 0].item() * 100 prob_fake = probs[0, 1].item() * 100 # Format output if prob_fake > 50.0: result = f"🚨 AI-Generated (Fake)\nConfidence: {prob_fake:.1f}%\n\nScores:\nReal: {prob_real:.1f}%\nAI-Generated: {prob_fake:.1f}%" else: result = f"✅ Real Image\nConfidence: {prob_real:.1f}%\n\nScores:\nReal: {prob_real:.1f}%\nAI-Generated: {prob_fake:.1f}%" return result except Exception as e: return f"Error: {str(e)}" # Load models on startup print("Initializing models...") try: load_models() print("Models loaded successfully!") except Exception as e: print(f"Error loading models: {e}") import traceback traceback.print_exc() # Create Gradio interface using Blocks with gr.Blocks(title="AI Image Detector") as demo: gr.Markdown("# 🔍 Real vs Fake: AI Image Detector") gr.Markdown(""" Upload an image to 
detect if it's real or AI-generated. This detector uses ARNIQA perceptual features to identify synthetic images. """) with gr.Row(): with gr.Column(): image_input = gr.Image(type="pil", label="Upload Image") submit_btn = gr.Button("Analyze Image", variant="primary") with gr.Column(): output_text = gr.Textbox(label="Result", lines=6) submit_btn.click(fn=predict_image, inputs=image_input, outputs=output_text) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860)
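
# Quick sanity check without the web UI (a minimal sketch: "sample.jpg" is a
# hypothetical local file, and the models must have loaded successfully above):
#
#   print(predict_image(Image.open("sample.jpg")))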