"""
Simple Gradio App: AI-Generated Image Detector
"""
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import sys

# Add this script's directory to the import path so the local packages resolve
sys.path.insert(0, os.path.dirname(__file__))

import features.ARNIQA as ARNIQA
import functions.networks as networks
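
# Local modules: `features.ARNIQA` wraps the ARNIQA perceptual feature extractor,
# and `functions.networks` defines the classifier architectures used below.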

# Global variables
feature_extractor = None
classifier = None
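# Inference runs on the CPU by default; switch to "cuda" if a GPU is available.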
device = "cpu"

def load_models():
    """Load ARNIQA feature extractor and trained classifier"""
    global feature_extractor, classifier, device

    print("Loading ARNIQA feature extractor...")
    feature_extractor = ARNIQA.Compute_ARNIQA(device=device)
    feature_extractor.eval()

    print("Loading classifier...")
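    # input_dim=4096 assumes the usual ARNIQA setup: two 2048-d ResNet-50
    # embeddings (one per input scale) concatenated into a single vector.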
    classifier = networks.Classifier_Arch2(
        input_dim=4096,
        hidden_layers=[1024]
    )

    # Load checkpoint
    checkpoint_path = os.path.join(
        os.path.dirname(__file__),
        "checkpoints/GenImage/extensive/MarginContrastiveLoss_CrossEntropy/arniqa/best_model.ckpt"
    )

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at: {checkpoint_path}")

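    # Load the full checkpoint and pull out the model weights stored under 'state_dict'.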
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['state_dict']

    # Remove 'classifier.' prefix from keys if present
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('classifier.'):
            new_state_dict[key.replace('classifier.', '')] = value
        else:
            new_state_dict[key] = value

    classifier.load_state_dict(new_state_dict)
    classifier = classifier.to(device)
    classifier.eval()

    print("Models loaded successfully!")
    return feature_extractor, classifier

def preprocess_image(image_pil):
    """Preprocess image for ARNIQA"""
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
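    # The mean/std above are the standard ImageNet normalization statistics.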

    if image_pil.mode != 'RGB':
        image_pil = image_pil.convert('RGB')

    # Two scales: 224x224 and original
    transform_224 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize
    ])

    transform_original = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

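    # Add a batch dimension so each tensor is shaped (1, 3, H, W).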
    img1 = transform_224(image_pil).unsqueeze(0)
    img2 = transform_original(image_pil).unsqueeze(0)

    return img1, img2

def predict_image(image):
    """Run inference on uploaded image"""
    global feature_extractor, classifier, device

    if feature_extractor is None or classifier is None:
        return "Error: Models not loaded"

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Preprocess
        img1, img2 = preprocess_image(image)
        img1 = img1.to(device)
        img2 = img2.to(device)

        # Extract features and classify
        with torch.no_grad():
            features = feature_extractor(img1, img2)
            features = torch.flatten(features, start_dim=1).to(torch.float32)
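            # Classifier_Arch2 returns a tuple; the second element holds the class logits.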
            _, logits = classifier(features)
            probs = F.softmax(logits, dim=1)

            prob_real = probs[0, 0].item() * 100
            prob_fake = probs[0, 1].item() * 100

        # Format output
        if prob_fake > 50.0:
            result = f"🚨 AI-Generated (Fake)\nConfidence: {prob_fake:.1f}%\n\nScores:\nReal: {prob_real:.1f}%\nAI-Generated: {prob_fake:.1f}%"
        else:
            result = f"✅ Real Image\nConfidence: {prob_real:.1f}%\n\nScores:\nReal: {prob_real:.1f}%\nAI-Generated: {prob_fake:.1f}%"

        return result

    except Exception as e:
        return f"Error: {str(e)}"

# Load models on startup
print("Initializing models...")
try:
    load_models()
except Exception as e:
    print(f"Error loading models: {e}")
    import traceback
    traceback.print_exc()

# Create Gradio interface using Blocks
with gr.Blocks(title="AI Image Detector") as demo:
    gr.Markdown("# 🔍 Real vs Fake: AI Image Detector")
    gr.Markdown("""
    Upload an image to detect if it's real or AI-generated.

    This detector uses ARNIQA perceptual features to identify synthetic images.
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            submit_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="Result", lines=6)

    submit_btn.click(fn=predict_image, inputs=image_input, outputs=output_text)

if __name__ == "__main__":
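    # 0.0.0.0 binds all interfaces (useful inside a container); 7860 is Gradio's default port.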
    demo.launch(server_name="0.0.0.0", server_port=7860)