CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

Project Overview

Research implementation for detecting AI-generated images using perceptual features from Image Quality Assessment (IQA) models. The core approach trains two-layer classifiers on feature spaces extracted from pretrained IQA models to distinguish between real and synthetic images.

Key Commands

Training

python train.py

Trains classifiers on specified datasets with configured feature extractors. Training settings are controlled through:

  • Config files in configs/ directory (arniqa.yaml, contrique.yaml, hyperiqa.yaml, reiqa.yaml, tres.yaml)
  • In-script settings for dataset type (GenImage, DRCT, UnivFD), loss function, and preprocessing

Testing

python test.py

Evaluates trained models across datasets with various distortions (Gaussian blur, JPEG compression). Tests both in-domain (same dataset) and cross-domain (different datasets) performance.

python test_on_images.py

Runs inference on specific image files. Modify image paths in the script before running.
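
The exact variable name depends on the script; a hypothetical edit near the top of test_on_images.py might look like:

image_paths = [  # hypothetical variable name; check the actual script
    "/data/examples/real_photo.png",
    "/data/examples/generated_sample.png",
]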

Prior Methods Comparison

python prior_methods/prior_test.py

Tests baseline comparison methods (CLIP, DRCT) for benchmarking.

Analysis and Visualization

python analysis/polar_plot.py          # Generate radar plots
python analysis/distortion_plots.py    # Plot robustness curves
python analysis/feature_representations.py  # Generate t-SNE visualizations

Architecture Overview

Three-Stage Pipeline

  1. Feature Extraction (features/):

    • IQA models act as frozen feature extractors
    • Supported models: ARNIQA, CONTRIQUE, HyperIQA, ReIQA, TReS
    • Also supports CLIP (various architectures) and ResNet50
    • Each model in features/ wraps a pretrained backbone
    • Models loaded via networks.get_model() in functions/networks.py
  2. Classification (functions/networks.py):

    • Classifier_Arch2: Two-layer MLP (Linear → ReLU → Linear); see the sketch after this list
    • Input: IQA feature vector (dimension varies by model, specified in config)
    • Hidden layer: Typically 1024 units
    • Output: 2-class logits (real vs. fake)
  3. Training Loop (functions/module.py):

    • PyTorch Lightning-based training
    • Loss functions: CrossEntropy, MarginContrastiveLoss (in loss_optimizers_metrics.py)
    • Feature extractor remains frozen; only classifier is trained
    • Checkpoints saved based on validation loss
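
As a concrete illustration of stage 2, here is a minimal sketch of a two-layer classifier matching the description above; the actual implementation is Classifier_Arch2 in functions/networks.py, with dimensions taken from the config:

import torch.nn as nn

class TwoLayerClassifier(nn.Module):
    """Sketch of the described two-layer MLP; not the exact Classifier_Arch2 code."""
    def __init__(self, input_dim: int = 4096, hidden_dim: int = 1024, num_classes: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),    # project IQA features to hidden space
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),  # 2-class logits (real vs. fake)
        )

    def forward(self, features):
        return self.net(features)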

Dataset Structure

Three primary datasets configured in defaults.py:

  • GenImage: 8 generative models (BigGAN, VQDM, SDv4, SDv5, Wukong, ADM, GLIDE, Midjourney)
  • DRCT: 16 Stable Diffusion variants (various versions, ControlNet, inpainting, turbo)
  • UnivFD: 19 generative models (ProGAN, StyleGAN, CycleGAN, various diffusion models)

Each dataset has separate train/val splits with different generative models.

Data Preprocessing (functions/preprocess.py)

Configurable augmentation pipeline (a sketch follows the list):

  • Gaussian blur (σ = 0-5)
  • JPEG compression (QF = 30-100)
  • Probability-controlled application during training
  • Image normalization specific to each feature extractor
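
A rough sketch of such a pipeline using PIL; parameter names are illustrative, the real code is in functions/preprocess.py, and PIL's blur radius is used here as a stand-in for σ:

import io
import random
from PIL import Image, ImageFilter

def augment(img: Image.Image, p_blur: float = 0.5, p_jpeg: float = 0.5) -> Image.Image:
    """Illustrative train-time augmentation: blur and JPEG, each applied with some probability."""
    if random.random() < p_blur:
        img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0, 5)))
    if random.random() < p_jpeg:
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=random.randint(30, 100))
        buf.seek(0)
        img = Image.open(buf).convert("RGB")
    return img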

Configuration System

YAML config files in configs/ specify per-model settings:

classifier:
  input_dim: 4096        # Feature dimension from backbone
  hidden_layers: [1024]  # Single hidden layer

dataset:
  model_name: "arniqa"   # Feature extractor identifier
  f_model_name: "arniqa" # Used for checkpoint naming

trainer:
  devices: [0]           # GPU indices
  max_epochs: 20
  batch_size: 64

The train.py script overrides certain config values based on in-script settings (dataset_type, loss function, preprocessing level).
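
Conceptually, the override amounts to something like the following; key names beyond those shown above are hypothetical, so see train.py for the real logic:

import yaml

with open("configs/arniqa.yaml") as f:
    cfg = yaml.safe_load(f)

# In-script settings take precedence over the YAML defaults.
cfg["dataset"]["dataset_type"] = "GenImage"                    # hypothetical key
cfg["trainer"]["loss"] = "MarginContrastiveLoss_CrossEntropy"  # hypothetical key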

Path Configuration (CRITICAL)

defaults.py contains hardcoded paths that MUST match your environment:

  • main_dataset_dir: Location of GenImage/UnivFD/DRCT datasets
  • main_checkpoints_dir: Where trained classifier checkpoints are saved
  • main_feature_ckpts_dir: Pretrained IQA model weights
  • main_prior_checkpoints_dir: Prior method checkpoints

The code checks for specific mount points and asserts False if none match (a sketch of the pattern follows the list below). You must either:

  1. Update paths in defaults.py to match your environment
  2. Create the expected directory structure
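
The check in defaults.py follows a pattern roughly like this; the mount points shown are placeholders, not the real paths:

import os

_CANDIDATE_ROOTS = ["/mnt/data1", "/home/user/datasets"]  # placeholders

for root in _CANDIDATE_ROOTS:
    if os.path.isdir(root):
        main_dataset_dir = root
        break
else:
    assert False, "No known mount point found; update defaults.py for your machine"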

Checkpoint Management

Checkpoints organized hierarchically:

checkpoints/
├── GenImage/
│   └── extensive/
│       └── MarginContrastiveLoss_CrossEntropy/
│           └── {model_name}/
│               └── best_model.ckpt
└── DRCT/
    └── extensive/
        └── MarginContrastiveLoss_CrossEntropy/
            └── {model_name}/
                └── best_model.ckpt

Training automatically resumes from best_model.ckpt if it is found in the expected location.
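
A sketch of the resume behavior, assuming PyTorch Lightning's standard ckpt_path mechanism; the exact wiring lives in train.py:

import os

def fit_with_resume(trainer, model, datamodule, ckpt_dir):
    """Resume from best_model.ckpt when present; otherwise start a fresh run."""
    ckpt = os.path.join(ckpt_dir, "best_model.ckpt")
    ckpt_path = ckpt if os.path.exists(ckpt) else None
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)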

Dependencies

Core libraries (see functions/ReIQA/requirements.txt for the full list):

  • PyTorch + torchvision
  • PyTorch Lightning (training framework)
  • timm (model architectures)
  • torchmetrics (evaluation)
  • numpy, scipy, scikit-learn, scikit-image
  • PIL (Pillow) for image loading
  • pyyaml for config parsing
  • tqdm for progress bars

Feature extractor dependencies loaded dynamically (e.g., ARNIQA via torch.hub.load).
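
For example, ARNIQA can be pulled via torch.hub; the arguments below follow the upstream README as best understood and should be verified there:

import torch

# Assumed torch.hub entry point; check the miccunifi/ARNIQA README for exact arguments.
arniqa = torch.hub.load("miccunifi/ARNIQA", "ARNIQA", regressor_dataset="kadid10k")
arniqa.eval()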

Important Implementation Details

Training Script Pattern

Both train.py and test.py redirect stdout to log files in the stdouts/ and results/ directories, so output is not visible in the console by default.
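
The redirection uses the standard sys.stdout reassignment pattern, roughly:

import sys

sys.stdout = open("stdouts/train_run.log", "w")  # illustrative filename
print("now written to the log file, not the console")

To watch progress live, tail the log file (e.g., tail -f stdouts/<name>.log).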

Feature Extraction

In functions/module.py, the global feature_extractor_module function is set before training. During training/validation steps, features are extracted with torch.no_grad() to prevent gradient computation through the frozen backbone.
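
The relevant pattern, as a sketch:

import torch

def extract_features(feature_extractor, images):
    """Frozen backbone: gradients never flow through the extractor."""
    with torch.no_grad():
        feats = feature_extractor(images)
    return feats  # the classifier's computation graph starts from these detached features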

Metrics and Thresholds

  • GenImage/DRCT: Fixed threshold of 0.5 for binary classification
  • UnivFD: Threshold chosen on the validation set to maximize accuracy (sketched below)
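
A minimal sketch of the UnivFD-style threshold selection, assuming per-image probabilities for the "fake" class:

import numpy as np

def best_threshold(val_probs: np.ndarray, val_labels: np.ndarray) -> float:
    """Return the threshold that maximizes accuracy on the validation set."""
    candidates = np.unique(val_probs)
    accs = [((val_probs >= t).astype(int) == val_labels).mean() for t in candidates]
    return float(candidates[int(np.argmax(accs))])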

Cross-Dataset Testing

test.py includes cross-dataset evaluation (e.g., trained on GenImage, tested on DRCT) to measure generalization.

Prior Methods (prior_methods/)

Comparison implementations of baseline detectors:

  • CLIP-based classifiers (various architectures)
  • DRCT (Diffusion Reconstruction Contrastive Training)

These use similar training patterns but different feature extractors, and are organized in a structure parallel to the main codebase.

Results and Analysis

  • results/: CSV files with per-model, per-dataset metrics
  • analysis/plots/: Generated visualizations (polar plots, t-SNE, robustness curves)
  • Log files track training progress and test results

Modifying for New Experiments

  1. Add new feature extractor: Create a wrapper in features/ (see the sketch at the end of this section), add it to get_model() in functions/networks.py
  2. Add new dataset: Update defaults.py with source lists, add getter function in functions/utils.py
  3. Change training settings: Modify settings list in train.py (dataset, loss, augmentation level)
  4. Test new distortions: Add preprocessing settings in test.py preprocess_settings_list
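
For step 1, a hypothetical wrapper following the frozen-backbone pattern; names are illustrative, so mirror an existing file in features/ and the interface expected by get_model():

import torch
import torch.nn as nn
from torchvision import models

class NewFeatureExtractor(nn.Module):
    """Illustrative frozen-backbone wrapper; not an existing class in this repo."""
    def __init__(self):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        backbone.fc = nn.Identity()  # expose the 2048-d pooled features
        for p in backbone.parameters():
            p.requires_grad = False  # keep the extractor frozen
        self.backbone = backbone.eval()

    @torch.no_grad()
    def forward(self, x):
        return self.backbone(x)

The matching YAML config would then set input_dim to the backbone's feature dimension (2048 here).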