File size: 2,613 Bytes
6bfb0e3
 
 
94421ed
 
6bfb0e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Streamlit Demo: AI-Generated Image Detector
Simple web interface for detecting AI-generated images using the ARNIQA model.

Run with:
    python3 -m streamlit run app.py --server.port=25000 --server.address=0.0.0.0
"""
import logging

import streamlit as st
from PIL import Image

import inference

# Configure the browser tab and page layout before anything else is rendered.
st.set_page_config(
    page_title="Real vs Fake - AI Image Detector",
    page_icon="🔍",
    layout="centered",
)

# Page heading: app name, tagline, then a horizontal rule separating it
# from the content below.
st.title("Real vs Fake")
st.markdown("### Detect AI-Generated Images")
st.markdown("---")

# Load model once; st.cache_resource reuses the same objects across reruns.
@st.cache_resource(show_spinner="Loading AI detection model...")
def load_models():
    """Load the ARNIQA feature extractor and classifier on the CPU.

    The loading message is passed through ``show_spinner`` instead of a
    nested ``st.spinner``. A nested spinner would be shown in addition to
    the cache decorator's own default spinner, which doubles the indicator.

    Returns:
        tuple: ``(feature_extractor, classifier)`` exactly as returned by
        ``inference.load_model(device='cpu')``.
    """
    feature_extractor, classifier = inference.load_model(device='cpu')
    return feature_extractor, classifier

# Load the model up front; the rest of the page is only rendered when this
# succeeds (see the ``model_loaded`` check below).
try:
    feature_extractor, classifier = load_models()
except Exception:
    # Top-level UI boundary: keep the full traceback in the server log so
    # the failure can be diagnosed, while users only see a generic message.
    logging.getLogger(__name__).exception("Failed to load detection model")
    st.error("Error loading detection model. Please contact support.")
    model_loaded = False
else:
    model_loaded = True

# File uploader and prediction UI, rendered only when the model loaded.
if model_loaded:
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG'],
        help="Upload an image in PNG or JPEG format"
    )

    if uploaded_file is not None:
        try:
            # Decode the upload; PIL raises if it cannot identify the format.
            image = Image.open(uploaded_file)

            # Two equal-width columns: image on the left, verdict on the right.
            col1, col2 = st.columns([1, 1])

            with col1:
                st.image(image, caption="Uploaded Image", use_container_width=True)

            with col2:
                with st.spinner("Analyzing image..."):
                    prediction, confidence, (prob_real, prob_fake) = inference.predict(
                        image, feature_extractor, classifier, device='cpu'
                    )

                st.subheader("Results")

                # Verdict banner: success styling for "Real", error styling
                # for any other prediction.
                if prediction == "Real":
                    st.success("**Real**")
                else:
                    st.error("**Fake**")
                # NOTE(review): values are rendered with a trailing "%", so
                # inference.predict presumably returns percentages (0-100) —
                # confirm against inference.py.
                st.metric("Confidence", f"{confidence:.1f}%")

                # Per-class probability breakdown.
                st.markdown("---")
                st.write("**Probability Breakdown:**")
                st.write(f"- Real: **{prob_real:.1f}%**")
                st.write(f"- Fake: **{prob_fake:.1f}%**")

        except Exception as e:
            # Top-level UI boundary: record the traceback server-side, then
            # show a short message instead of crashing the page.
            logging.getLogger(__name__).exception("Failed to process uploaded image")
            st.error(f"Error processing image: {e}")
            st.write("Please try uploading a different image.")