CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
Project Overview
Research implementation for detecting AI-generated images using perceptual features from Image Quality Assessment (IQA) models. The core approach trains two-layer classifiers on feature spaces extracted from pretrained IQA models to distinguish between real and synthetic images.
Key Commands
Training
```bash
python train.py
```
Trains classifiers on specified datasets with configured feature extractors. Training settings are controlled through:
- Config files in the `configs/` directory (`arniqa.yaml`, `contrique.yaml`, `hyperiqa.yaml`, `reiqa.yaml`, `tres.yaml`)
- In-script settings for dataset type (GenImage, DRCT, UnivFD), loss function, and preprocessing
Testing
```bash
python test.py
```
Evaluates trained models across datasets with various distortions (Gaussian blur, JPEG compression). Tests both in-domain (same dataset) and cross-domain (different datasets) performance.
```bash
python test_on_images.py
```
Runs inference on specific image files. Modify image paths in the script before running.
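For quick experiments outside the repo's scripts, the inference path looks roughly like the following. This is a hypothetical, self-contained sketch: a torchvision ResNet50 stands in for the frozen IQA backbone and an untrained two-layer head stands in for the trained checkpoint; in the repo these come from `networks.get_model()` and a saved `best_model.ckpt`.

```python
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# Stand-in for the frozen IQA feature extractor (hypothetical).
backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
backbone.fc = nn.Identity()  # expose pooled 2048-d features
backbone.eval()

# Stand-in for the trained two-layer head; in the repo this is
# Classifier_Arch2 restored from best_model.ckpt.
classifier = nn.Sequential(nn.Linear(2048, 1024), nn.ReLU(), nn.Linear(1024, 2))
classifier.eval()

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

x = preprocess(Image.open("my_image.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    logits = classifier(backbone(x))
    prob_fake = torch.softmax(logits, dim=1)[0, 1].item()
print(f"P(fake) = {prob_fake:.3f}")
```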
Prior Methods Comparison
```bash
python prior_methods/prior_test.py
```
Tests baseline comparison methods (CLIP, DRCT) for benchmarking.
Analysis and Visualization
```bash
python analysis/polar_plot.py              # Generate radar plots
python analysis/distortion_plots.py        # Plot robustness curves
python analysis/feature_representations.py # Generate t-SNE visualizations
```
Architecture Overview
Three-Stage Pipeline
Feature Extraction (`features/`):
- IQA models act as frozen feature extractors
- Supported models: ARNIQA, CONTRIQUE, HyperIQA, ReIQA, TReS
- Also supports CLIP (various architectures) and ResNet50
- Each model in `features/` wraps a pretrained backbone
- Models loaded via `networks.get_model()` in `functions/networks.py`
Classification (`functions/networks.py`):
- `Classifier_Arch2`: Two-layer MLP (Linear → ReLU → Linear); see the sketch after this list
- Input: IQA feature vector (dimension varies by model, specified in config)
- Hidden layer: typically 1024 units
- Output: 2-class logits (real vs. fake)
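A minimal sketch of such a head, assuming it matches the description above (the repo's actual `Classifier_Arch2` may differ in naming and details such as dropout or initialization):

```python
import torch.nn as nn

class TwoLayerClassifier(nn.Module):
    """Sketch of a two-layer MLP head as described above; the repo's
    Classifier_Arch2 may differ in details (e.g., dropout, init)."""
    def __init__(self, input_dim: int = 4096, hidden_dim: int = 1024,
                 num_classes: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),   # Linear
            nn.ReLU(),                          # -> ReLU
            nn.Linear(hidden_dim, num_classes), # -> 2-class logits
        )

    def forward(self, features):
        return self.net(features)
```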
Training Loop (`functions/module.py`):
- PyTorch Lightning-based training (sketched after this list)
- Loss functions: CrossEntropy, MarginContrastiveLoss (in `loss_optimizers_metrics.py`)
- Feature extractor remains frozen; only the classifier is trained
- Checkpoints saved based on validation loss
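The pattern above can be summarized in an illustrative LightningModule. This is not the repo's exact module: for instance, the repo holds the extractor in a module-level global (`feature_extractor_module`) rather than inside the LightningModule, and supports MarginContrastiveLoss in addition to the plain cross-entropy shown here.

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class ClassifierModule(pl.LightningModule):
    """Illustrative module: frozen extractor, trainable two-layer head,
    plain cross-entropy. Not the repo's exact functions/module.py."""
    def __init__(self, feature_extractor, classifier, lr: float = 1e-3):
        super().__init__()
        self.feature_extractor = feature_extractor.eval()
        for p in self.feature_extractor.parameters():
            p.requires_grad = False  # backbone stays frozen
        self.classifier = classifier
        self.lr = lr

    def training_step(self, batch, batch_idx):
        images, labels = batch
        with torch.no_grad():  # no gradients through the frozen backbone
            feats = self.feature_extractor(images)
        loss = F.cross_entropy(self.classifier(feats), labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        feats = self.feature_extractor(images)
        loss = F.cross_entropy(self.classifier(feats), labels)
        self.log("val_loss", loss)  # drives checkpoint selection

    def configure_optimizers(self):
        return torch.optim.Adam(self.classifier.parameters(), lr=self.lr)
```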
Dataset Structure
Three primary datasets configured in `defaults.py`:
- GenImage: 8 generative models (BigGAN, VQDM, SDv4, SDv5, Wukong, ADM, GLIDE, Midjourney)
- DRCT: 16 Stable Diffusion variants (various versions, ControlNet, inpainting, turbo)
- UnivFD: 19 generative models (ProGAN, StyleGAN, CycleGAN, various diffusion models)
Each dataset has separate train/val splits with different generative models.
Data Preprocessing (functions/preprocess.py)
Configurable augmentation pipeline (sketched below):
- Gaussian blur (σ = 0–5)
- JPEG compression (QF = 30–100)
- Probability-controlled application during training
- Image normalization specific to each feature extractor
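A minimal sketch of such an augmentation step with PIL, assuming independent per-distortion probabilities (the repo's `functions/preprocess.py` may wire this differently):

```python
import io
import random
from PIL import Image, ImageFilter

def degrade(img: Image.Image, p: float = 0.5) -> Image.Image:
    """Apply the distortions above with probability p each:
    Gaussian blur with sigma in [0, 5], then JPEG with QF in [30, 100]."""
    if random.random() < p:
        img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0, 5)))
    if random.random() < p:
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=random.randint(30, 100))
        buf.seek(0)
        img = Image.open(buf).convert("RGB")
    return img
```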
Configuration System
YAML config files in configs/ specify per-model settings:
```yaml
classifier:
  input_dim: 4096        # Feature dimension from backbone
  hidden_layers: [1024]  # Single hidden layer
dataset:
  model_name: "arniqa"   # Feature extractor identifier
  f_model_name: "arniqa" # Used for checkpoint naming
trainer:
  devices: [0]           # GPU indices
  max_epochs: 20
  batch_size: 64
```
The train.py script overrides certain config values based on in-script settings (dataset_type, loss function, preprocessing level).
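The override pattern presumably looks something like the following; key names beyond those shown in the config excerpt above are assumptions, not the repo's exact code:

```python
import yaml

# Illustrative override pattern; the keys below are assumed for
# illustration, not taken from the repo.
with open("configs/arniqa.yaml") as f:
    config = yaml.safe_load(f)

dataset_type = "GenImage"  # in-script setting
config["dataset"]["dataset_type"] = dataset_type       # assumed key
config["loss"] = "MarginContrastiveLoss_CrossEntropy"  # assumed key
config["preprocess_level"] = "extensive"               # assumed key
```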
Path Configuration (CRITICAL)
defaults.py contains hardcoded paths that MUST match your environment:
- `main_dataset_dir`: Location of GenImage/UnivFD/DRCT datasets
- `main_checkpoints_dir`: Where trained classifier checkpoints are saved
- `main_feature_ckpts_dir`: Pretrained IQA model weights
- `main_prior_checkpoints_dir`: Prior method checkpoints
The code checks for specific mount points and will `assert False` if none match (illustrated below). You must either:
- Update paths in `defaults.py` to match your environment, or
- Create the expected directory structure
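The mount-point check amounts to something like this (candidate paths are illustrative; the real ones live in `defaults.py`):

```python
import os

# Illustrative version of the mount-point check; the real candidate paths
# live in defaults.py and will differ per machine.
candidate_roots = ["/mnt/data", "/data", os.path.expanduser("~/datasets")]
main_dataset_dir = next((r for r in candidate_roots if os.path.isdir(r)), None)
assert main_dataset_dir is not None, "No known mount point found; edit defaults.py"
```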
Checkpoint Management
Checkpoints organized hierarchically:
```
checkpoints/
├── GenImage/
│   └── extensive/
│       └── MarginContrastiveLoss_CrossEntropy/
│           └── {model_name}/
│               └── best_model.ckpt
└── DRCT/
    └── extensive/
        └── MarginContrastiveLoss_CrossEntropy/
            └── {model_name}/
                └── best_model.ckpt
```
Training automatically resumes from best_model.ckpt if found in expected location.
Dependencies
Core libraries (see functions/ReIQA/requirements.txt for full list):
- PyTorch + torchvision
- PyTorch Lightning (training framework)
- timm (model architectures)
- torchmetrics (evaluation)
- numpy, scipy, scikit-learn, scikit-image
- PIL (Pillow) for image loading
- pyyaml for config parsing
- tqdm for progress bars
Feature extractor dependencies are loaded dynamically (e.g., ARNIQA via `torch.hub.load`).
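For ARNIQA specifically, the upstream README documents `torch.hub` loading roughly as follows (the `repo_or_dir` string and keyword arguments follow the upstream instructions and may change between releases):

```python
import torch

# Per the upstream ARNIQA README (arguments may change between releases).
model = torch.hub.load(
    repo_or_dir="miccunifi/ARNIQA", source="github",
    model="ARNIQA", regressor_dataset="kadid10k",
)
model.eval()
```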
Important Implementation Details
Training Script Pattern
Both `train.py` and `test.py` redirect stdout to log files in the `stdouts/` and `results/` directories, so output is not visible in the console by default.
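The redirect is the standard `sys.stdout` reassignment; a minimal sketch (file name is illustrative):

```python
import sys

# Standard redirect pattern (file name illustrative). Comment this out to
# keep output in the console.
sys.stdout = open("stdouts/train_run.log", "w")
print("This now goes to the log file, not the terminal.")
```

To follow progress live, run `tail -f` on the corresponding file in `stdouts/` or `results/`.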
Feature Extraction
In `functions/module.py`, the global `feature_extractor_module` is set before training. During training/validation steps, features are extracted under `torch.no_grad()` to prevent gradient computation through the frozen backbone.
Metrics and Thresholds
- GenImage/DRCT: Fixed threshold of 0.5 for binary classification
- UnivFD: Threshold chosen on the validation set for optimal accuracy (see the sketch below)
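The validation-derived threshold amounts to sweeping candidate cutoffs over validation scores and keeping the accuracy-maximizing one. An illustrative sketch, not the repo's exact implementation:

```python
import numpy as np

def best_threshold(scores: np.ndarray, labels: np.ndarray) -> float:
    """Return the threshold on P(fake) that maximizes validation accuracy.
    Illustrative; not the repo's exact implementation."""
    candidates = np.unique(scores)
    accs = [((scores >= t).astype(int) == labels).mean() for t in candidates]
    return float(candidates[int(np.argmax(accs))])

# Toy example: labels 0 = real, 1 = fake
scores = np.array([0.10, 0.40, 0.35, 0.80, 0.70])
labels = np.array([0, 0, 1, 1, 1])
print(best_threshold(scores, labels))  # -> 0.35
```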
Cross-Dataset Testing
test.py includes cross-dataset evaluation (e.g., trained on GenImage, tested on DRCT) to measure generalization.
Prior Methods (prior_methods/)
Comparison implementations of baseline detectors:
- CLIP-based classifiers (various architectures)
- DRCT (Diffusion Reconstruction Contrastive Training)
These use similar training patterns but different feature extractors, and are organized in a structure parallel to the main codebase.
Results and Analysis
- `results/`: CSV files with per-model, per-dataset metrics
- `analysis/plots/`: Generated visualizations (polar plots, t-SNE, robustness curves)
- Log files track training progress and test results
Modifying for New Experiments
- Add a new feature extractor: Create a wrapper in `features/` and add it to `get_model()` in `functions/networks.py` (see the sketch after this list)
- Add a new dataset: Update `defaults.py` with source lists and add a getter function in `functions/utils.py`
- Change training settings: Modify the settings list in `train.py` (dataset, loss, augmentation level)
- Test new distortions: Add preprocessing settings to `preprocess_settings_list` in `test.py`
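A hypothetical wrapper illustrating the `features/` pattern (the class name and backbone are placeholders; a real wrapper would load its own IQA weights from `main_feature_ckpts_dir`):

```python
import torch
import torch.nn as nn
from torchvision import models

class MyIQAFeatures(nn.Module):
    """Hypothetical wrapper following the features/ pattern: a frozen
    pretrained backbone that returns a flat feature vector. A real wrapper
    would load the IQA model's own weights from main_feature_ckpts_dir."""
    def __init__(self):
        super().__init__()
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.backbone.fc = nn.Identity()  # expose pooled 2048-d features
        self.backbone.eval()
        for p in self.backbone.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x)  # [B, 2048]
```

After registering the wrapper in `get_model()`, set `input_dim` in the new model's YAML config to match the feature dimension it returns (2048 in this sketch).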