# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Research implementation for detecting AI-generated images using perceptual features from Image Quality Assessment (IQA) models. The core approach trains two-layer classifiers on feature spaces extracted from pretrained IQA models to distinguish between real and synthetic images.

## Key Commands

### Training

```bash
python train.py
```

Trains classifiers on the specified datasets with the configured feature extractors. Training settings are controlled through:

- Config files in the `configs/` directory (`arniqa.yaml`, `contrique.yaml`, `hyperiqa.yaml`, `reiqa.yaml`, `tres.yaml`)
- In-script settings for dataset type (GenImage, DRCT, UnivFD), loss function, and preprocessing

### Testing

```bash
python test.py
```

Evaluates trained models across datasets under various distortions (Gaussian blur, JPEG compression). Tests both in-domain (same dataset) and cross-domain (different datasets) performance.

```bash
python test_on_images.py
```

Runs inference on specific image files. Modify the image paths in the script before running.

### Prior Methods Comparison

```bash
python prior_methods/prior_test.py
```

Tests baseline comparison methods (CLIP, DRCT) for benchmarking.

### Analysis and Visualization

```bash
python analysis/polar_plot.py                # Generate radar plots
python analysis/distortion_plots.py          # Plot robustness curves
python analysis/feature_representations.py   # Generate t-SNE visualizations
```

## Architecture Overview

### Three-Stage Pipeline

1. **Feature Extraction** (`features/`):
   - IQA models act as frozen feature extractors
   - Supported models: ARNIQA, CONTRIQUE, HyperIQA, ReIQA, TReS
   - Also supports CLIP (various architectures) and ResNet50
   - Each model in `features/` wraps a pretrained backbone
   - Models loaded via `networks.get_model()` in `functions/networks.py`

2. **Classification** (`functions/networks.py`):
   - `Classifier_Arch2`: Two-layer MLP (Linear → ReLU → Linear)
   - Input: IQA feature vector (dimension varies by model, specified in config)
   - Hidden layer: typically 1024 units
   - Output: 2-class logits (real vs. fake)

3. **Training Loop** (`functions/module.py`):
   - PyTorch Lightning-based training
   - Loss functions: CrossEntropy, MarginContrastiveLoss (in `loss_optimizers_metrics.py`)
   - Feature extractor remains frozen; only the classifier is trained
   - Checkpoints saved based on validation loss

### Dataset Structure

Three primary datasets configured in `defaults.py`:

- **GenImage**: 8 generative models (BigGAN, VQDM, SDv4, SDv5, Wukong, ADM, GLIDE, Midjourney)
- **DRCT**: 16 Stable Diffusion variants (various versions, ControlNet, inpainting, turbo)
- **UnivFD**: 19 generative models (ProGAN, StyleGAN, CycleGAN, various diffusion models)

Each dataset has separate train/val splits with different generative models.
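For orientation, the stage-2 classifier from the pipeline above is small enough to sketch in full. Below is a minimal PyTorch version matching that description; the class name and constructor signature here are illustrative, and the repo's actual `Classifier_Arch2` may differ in naming and extras (dropout, weight init, etc.):

```python
import torch
import torch.nn as nn

class TwoLayerClassifier(nn.Module):
    """Sketch of the two-layer MLP described above: Linear -> ReLU -> Linear.

    `input_dim` comes from the feature extractor (e.g., 4096 for ARNIQA per
    configs/arniqa.yaml), the hidden layer is typically 1024 units, and the
    output is 2-class logits (real vs. fake).
    """

    def __init__(self, input_dim: int, hidden_dim: int = 1024, num_classes: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: (batch, input_dim) vectors from the frozen IQA backbone
        return self.net(features)
```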
### Data Preprocessing (`functions/preprocess.py`)

Configurable augmentation pipeline:

- Gaussian blur (σ = 0-5)
- JPEG compression (QF = 30-100)
- Probability-controlled application during training
- Image normalization specific to each feature extractor

## Configuration System

YAML config files in `configs/` specify per-model settings:

```yaml
classifier:
  input_dim: 4096        # Feature dimension from backbone
  hidden_layers: [1024]  # Single hidden layer
dataset:
  model_name: "arniqa"   # Feature extractor identifier
  f_model_name: "arniqa" # Used for checkpoint naming
trainer:
  devices: [0]           # GPU indices
  max_epochs: 20
  batch_size: 64
```

The `train.py` script overrides certain config values based on in-script settings (dataset type, loss function, preprocessing level).

## Path Configuration (CRITICAL)

`defaults.py` contains hardcoded paths that MUST match your environment:

- `main_dataset_dir`: Location of the GenImage/UnivFD/DRCT datasets
- `main_checkpoints_dir`: Where trained classifier checkpoints are saved
- `main_feature_ckpts_dir`: Pretrained IQA model weights
- `main_prior_checkpoints_dir`: Prior method checkpoints

**The code checks for specific mount points and will `assert False` if none match.** You must either:

1. Update the paths in `defaults.py` to match your environment
2. Create the expected directory structure

## Checkpoint Management

Checkpoints are organized hierarchically:

```
checkpoints/
├── GenImage/
│   └── extensive/
│       └── MarginContrastiveLoss_CrossEntropy/
│           └── {model_name}/
│               └── best_model.ckpt
└── DRCT/
    └── extensive/
        └── MarginContrastiveLoss_CrossEntropy/
            └── {model_name}/
                └── best_model.ckpt
```

Training automatically resumes from `best_model.ckpt` if it is found in the expected location.

## Dependencies

Core libraries (see `functions/ReIQA/requirements.txt` for the full list):

- PyTorch + torchvision
- PyTorch Lightning (training framework)
- timm (model architectures)
- torchmetrics (evaluation)
- numpy, scipy, scikit-learn, scikit-image
- PIL (Pillow) for image loading
- pyyaml for config parsing
- tqdm for progress bars

Feature extractor dependencies are loaded dynamically (e.g., ARNIQA via `torch.hub.load`).

## Important Implementation Details

### Training Script Pattern

Both `train.py` and `test.py` redirect stdout to log files in the `stdouts/` and `results/` directories. Output is not visible in the console by default.

### Feature Extraction

In `functions/module.py`, the global `feature_extractor_module` function is set before training. During training/validation steps, features are extracted under `torch.no_grad()` to prevent gradient computation through the frozen backbone.

### Metrics and Thresholds

- **GenImage/DRCT**: Fixed threshold of 0.5 for binary classification
- **UnivFD**: Threshold determined on the validation set for optimal accuracy

### Cross-Dataset Testing

`test.py` includes cross-dataset evaluation (e.g., trained on GenImage, tested on DRCT) to measure generalization.

## Prior Methods (`prior_methods/`)

Comparison implementations of baseline detectors:

- CLIP-based classifiers (various architectures)
- DRCT (Diffusion Reconstruction Contrastive Training)

These use similar training patterns but different feature extractors, and are organized in a structure parallel to the main codebase.

## Results and Analysis

- `results/`: CSV files with per-model, per-dataset metrics
- `analysis/plots/`: Generated visualizations (polar plots, t-SNE, robustness curves)
- Log files track training progress and test results
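As a concrete illustration of the UnivFD thresholding described under Metrics and Thresholds above, here is a minimal, self-contained sketch of picking a decision threshold on validation scores. The function name and candidate set are assumptions; the repo may implement this step differently:

```python
import numpy as np

def best_threshold(scores: np.ndarray, labels: np.ndarray) -> float:
    """Pick the decision threshold that maximizes validation accuracy.

    scores: per-image "fake" scores, shape (N,)
    labels: ground truth, 1 = fake, 0 = real, shape (N,)

    Illustrative only; the repo's UnivFD threshold search may use a
    different criterion or candidate set.
    """
    candidates = np.unique(scores)  # every observed score is a candidate cut
    accs = [((scores >= t).astype(int) == labels).mean() for t in candidates]
    return float(candidates[int(np.argmax(accs))])

# Threshold is chosen on validation scores, then reused at test time.
val_scores = np.array([0.10, 0.40, 0.35, 0.80, 0.90, 0.70])
val_labels = np.array([0,    0,    0,    1,    1,    1])
print(f"optimal threshold: {best_threshold(val_scores, val_labels):.2f}")  # -> 0.70
```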
## Modifying for New Experiments

1. **Add new feature extractor**: Create a wrapper in `features/` and add it to `get_model()` in `functions/networks.py` (see the sketch after this list)
2. **Add new dataset**: Update `defaults.py` with source lists and add a getter function in `functions/utils.py`
3. **Change training settings**: Modify the settings list in `train.py` (dataset, loss, augmentation level)
4. **Test new distortions**: Add preprocessing settings to `preprocess_settings_list` in `test.py`
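To illustrate item 1, a skeleton for a new feature-extractor wrapper is shown below. The class name and backbone choice are hypothetical; mirror an existing wrapper in `features/` and the `get_model()` dispatch in `functions/networks.py` for the real interface:

```python
import torch
import torch.nn as nn
import torchvision.models as tv_models

class MyBackboneFeatures(nn.Module):
    """Hypothetical wrapper in features/ exposing a frozen backbone.

    Follows the pattern described above: load pretrained weights, freeze
    all parameters, and return one flat feature vector per image.
    """

    def __init__(self):
        super().__init__()
        # Example backbone: a torchvision ResNet50 with its classification
        # head removed. Swap in your own pretrained IQA model here.
        backbone = tv_models.resnet50(weights=tv_models.ResNet50_Weights.DEFAULT)
        backbone.fc = nn.Identity()          # keep the 2048-d pooled features
        self.backbone = backbone.eval()
        for p in self.backbone.parameters():
            p.requires_grad = False          # frozen: only the classifier trains

    @torch.no_grad()
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # images: (batch, 3, H, W), normalized as the extractor expects
        return self.backbone(images)         # (batch, 2048)
```

The new extractor would then be registered in `get_model()` and given a YAML config in `configs/` whose `input_dim` matches the wrapper's output dimension (2048 in this sketch).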