Spaces: Running on Zero
| """ | |
| Simple Gradio App: AI-Generated Image Detector | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| import os | |
| import sys | |
| # Add this script's directory to sys.path so the local features/ and functions/ packages can be imported | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| import features.ARNIQA as ARNIQA | |
| import functions.networks as networks | |
# Module-level model state. load_models() assigns both models;
# predict_image() refuses to run while either one is still None.
feature_extractor = classifier = None
# Device every model and input tensor is moved to (CPU as written).
device = "cpu"
def load_models():
    """Load the ARNIQA feature extractor and the trained classifier.

    Both models are built in local variables first. They are published to the
    module-level ``feature_extractor`` / ``classifier`` globals only after the
    checkpoint has been loaded successfully. Assigning the globals up front
    would leave ``classifier`` pointing at a randomly initialised network when
    the checkpoint is missing or fails to load. ``predict_image`` only checks
    for ``None``, so it would then serve meaningless predictions.

    Returns:
        tuple: ``(feature_extractor, classifier)``, both switched to eval mode;
        the classifier is moved to ``device``.

    Raises:
        FileNotFoundError: if the classifier checkpoint file does not exist.
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
        Errors raised by ``torch.load`` / ``load_state_dict`` propagate as-is.
    """
    global feature_extractor, classifier

    print("Loading ARNIQA feature extractor...")
    extractor = ARNIQA.Compute_ARNIQA(device=device)
    extractor.eval()

    print("Loading classifier...")
    model = networks.Classifier_Arch2(
        input_dim=4096,
        hidden_layers=[1024],
    )

    checkpoint_path = os.path.join(
        os.path.dirname(__file__),
        "checkpoints/GenImage/extensive/MarginContrastiveLoss_CrossEntropy/arniqa/best_model.ckpt",
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at: {checkpoint_path}")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # That is only acceptable because this checkpoint ships with the app;
    # never point this call at user-supplied files.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['state_dict']

    # Keys may carry a leading 'classifier.' (presumably the attribute name of
    # the network inside the training wrapper — TODO confirm). Strip it only as
    # a prefix: str.replace() would also delete the substring mid-key.
    prefix = 'classifier.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(new_state_dict)
    model = model.to(device)
    model.eval()

    # Publish only once every step above has succeeded.
    feature_extractor, classifier = extractor, model
    print("Models loaded successfully!")
    return feature_extractor, classifier
def preprocess_image(image_pil):
    """Turn a PIL image into the two normalised input tensors fed to ARNIQA.

    The image is converted to RGB if needed and then produced at two scales:
    a fixed 224x224 resize and the image at its original resolution. Both are
    normalised with the mean/std constants below and given a leading batch
    dimension of size 1.

    Args:
        image_pil: Input ``PIL.Image.Image`` in any mode.

    Returns:
        tuple: ``(img1, img2)`` where ``img1`` is the 224x224 tensor and
        ``img2`` the original-resolution tensor, each shaped ``(1, 3, H, W)``.
    """
    rgb = image_pil if image_pil.mode == 'RGB' else image_pil.convert('RGB')

    # Common tail of both scales: PIL image -> float tensor -> per-channel normalisation.
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    resize_224 = transforms.Resize((224, 224))

    scaled = to_normalized_tensor(resize_224(rgb))
    full = to_normalized_tensor(rgb)
    return scaled.unsqueeze(0), full.unsqueeze(0)
def predict_image(image):
    """Classify an uploaded image as real or AI-generated.

    Args:
        image: Value from the Gradio image input. The UI component is declared
            with ``type="pil"``, so this is normally a ``PIL.Image.Image``;
            any other non-None value is converted with ``Image.fromarray``.
            ``None`` when the button is clicked without an uploaded image.

    Returns:
        str: A human-readable verdict followed by the real / AI-generated
        scores in percent, or a message starting with ``"Error:"`` when
        inference cannot run.
    """
    # Globals are only read here, so no `global` declaration is needed.
    if feature_extractor is None or classifier is None:
        return "Error: Models not loaded"
    if image is None:
        return "Error: Please upload an image first."

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        img1, img2 = preprocess_image(image)
        img1 = img1.to(device)
        img2 = img2.to(device)

        with torch.no_grad():
            features = feature_extractor(img1, img2)
            features = torch.flatten(features, start_dim=1).to(torch.float32)
            # The classifier returns a pair; only the second item (logits) is used.
            _, logits = classifier(features)
            probs = F.softmax(logits, dim=1)

        # Column 0 is treated as "real", column 1 as "AI-generated".
        prob_real = probs[0, 0].item() * 100
        prob_fake = probs[0, 1].item() * 100
    except Exception as e:
        # UI boundary: show the error to the user, but keep the traceback in
        # the server logs instead of silently discarding it.
        import traceback
        traceback.print_exc()
        return f"Error: {str(e)}"

    if prob_fake > 50.0:
        verdict = f"🚨 AI-Generated (Fake)\nConfidence: {prob_fake:.1f}%"
    else:
        verdict = f"✅ Real Image\nConfidence: {prob_real:.1f}%"
    return f"{verdict}\n\nScores:\nReal: {prob_real:.1f}%\nAI-Generated: {prob_fake:.1f}%"
# Load the models once at import time. A failure is logged but not fatal, so
# the UI still starts and predict_image() answers "Error: Models not loaded".
print("Initializing models...")
try:
    # load_models() prints its own success message, so none is repeated here.
    load_models()
except Exception as e:
    print(f"Error loading models: {e}")
    import traceback
    traceback.print_exc()
    # load_models() may have assigned some globals before failing (e.g. an
    # untrained classifier when the checkpoint is missing). Reset both so
    # predict_image() refuses to run on half-initialised models.
    feature_extractor = None
    classifier = None
# Gradio UI: image upload and analyze button on the left, text verdict on the right.
with gr.Blocks(title="AI Image Detector") as demo:
    gr.Markdown(value="# 🔍 Real vs Fake: AI Image Detector")
    gr.Markdown(
        value=(
            "\n"
            "Upload an image to detect if it's real or AI-generated.\n"
            "This detector uses ARNIQA perceptual features to identify synthetic images.\n"
        )
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")
            submit_btn = gr.Button(value="Analyze Image", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(lines=6, label="Result")

    # Clicking the button runs predict_image on the uploaded image.
    submit_btn.click(
        fn=predict_image,
        inputs=image_input,
        outputs=output_text,
    )

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")