gh-rgupta Claude commited on
Commit
d9c7b8a
·
1 Parent(s): 2cda712

Add CPU compatibility for Mac and testing improvements

Browse files

- Modified device handling to use CPU instead of CUDA for Mac compatibility
- Updated test_on_images.py to test all images in new_images_to_test folder
- Added test_all_models.py for testing multiple IQA models
- Fixed PyTorch Lightning trainer to use CPU accelerator
- Added .gitignore for checkpoints, logs, and cache files
- Added CLAUDE.md documentation for project setup and usage

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

Files changed (6) hide show
  1. .gitignore +41 -0
  2. CLAUDE.md +185 -0
  3. defaults.py +12 -1
  4. functions/run_on_images_fn.py +8 -2
  5. test_all_models.py +169 -0
  6. test_on_images.py +18 -10
.gitignore ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+
8
+ # Virtual environments
9
+ venv/
10
+ env/
11
+ ENV/
12
+
13
+ # IDE
14
+ .vscode/
15
+ .idea/
16
+ *.swp
17
+ *.swo
18
+ *~
19
+
20
+ # Model checkpoints and weights
21
+ checkpoints/
22
+ feature_extractor_checkpoints/
23
+ prior_methods_checkpoints/
24
+
25
+ # Training logs and results
26
+ lightning_logs/
27
+ results/*.log
28
+ stdouts/
29
+
30
+ # Test images
31
+ new_images_to_test/
32
+
33
+ # Output files
34
+ *.txt
35
+ test_all_models_output.txt
36
+
37
+ # macOS
38
+ .DS_Store
39
+
40
+ # Jupyter
41
+ .ipynb_checkpoints/
CLAUDE.md ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLAUDE.md
2
+
3
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4
+
5
+ ## Project Overview
6
+
7
+ Research implementation for detecting AI-generated images using perceptual features from Image Quality Assessment (IQA) models. The core approach trains two-layer classifiers on feature spaces extracted from pretrained IQA models to distinguish between real and synthetic images.
8
+
9
+ ## Key Commands
10
+
11
+ ### Training
12
+ ```bash
13
+ python train.py
14
+ ```
15
+ Trains classifiers on specified datasets with configured feature extractors. Training settings are controlled through:
16
+ - Config files in `configs/` directory (arniqa.yaml, contrique.yaml, hyperiqa.yaml, reiqa.yaml, tres.yaml)
17
+ - In-script settings for dataset type (GenImage, DRCT, UnivFD), loss function, and preprocessing
18
+
19
+ ### Testing
20
+ ```bash
21
+ python test.py
22
+ ```
23
+ Evaluates trained models across datasets with various distortions (Gaussian blur, JPEG compression). Tests both in-domain (same dataset) and cross-domain (different datasets) performance.
24
+
25
+ ```bash
26
+ python test_on_images.py
27
+ ```
28
+ Runs inference on specific image files. Modify image paths in the script before running.
29
+
30
+ ### Prior Methods Comparison
31
+ ```bash
32
+ python prior_methods/prior_test.py
33
+ ```
34
+ Tests baseline comparison methods (CLIP, DRCT) for benchmarking.
35
+
36
+ ### Analysis and Visualization
37
+ ```bash
38
+ python analysis/polar_plot.py # Generate radar plots
39
+ python analysis/distortion_plots.py # Plot robustness curves
40
+ python analysis/feature_representations.py # Generate t-SNE visualizations
41
+ ```
42
+
43
+ ## Architecture Overview
44
+
45
+ ### Three-Stage Pipeline
46
+
47
+ 1. **Feature Extraction** (`features/`):
48
+ - IQA models act as frozen feature extractors
49
+ - Supported models: ARNIQA, CONTRIQUE, HyperIQA, ReIQA, TReS
50
+ - Also supports CLIP (various architectures) and ResNet50
51
+ - Each model in `features/` wraps a pretrained backbone
52
+ - Models loaded via `networks.get_model()` in `functions/networks.py`
53
+
54
+ 2. **Classification** (`functions/networks.py`):
55
+ - `Classifier_Arch2`: Two-layer MLP (Linear → ReLU → Linear)
56
+ - Input: IQA feature vector (dimension varies by model, specified in config)
57
+ - Hidden layer: Typically 1024 units
58
+ - Output: 2-class logits (real vs. fake)
59
+
60
+ 3. **Training Loop** (`functions/module.py`):
61
+ - PyTorch Lightning-based training
62
+ - Loss functions: CrossEntropy, MarginContrastiveLoss (in `loss_optimizers_metrics.py`)
63
+ - Feature extractor remains frozen; only classifier is trained
64
+ - Checkpoints saved based on validation loss
65
+
66
+ ### Dataset Structure
67
+
68
+ Three primary datasets configured in `defaults.py`:
69
+
70
+ - **GenImage**: 8 generative models (BigGAN, VQDM, SDv4, SDv5, Wukong, ADM, GLIDE, Midjourney)
71
+ - **DRCT**: 16 Stable Diffusion variants (various versions, ControlNet, inpainting, turbo)
72
+ - **UnivFD**: 19 generative models (ProGAN, StyleGAN, CycleGAN, various diffusion models)
73
+
74
+ Each dataset has separate train/val splits with different generative models.
75
+
76
+ ### Data Preprocessing (`functions/preprocess.py`)
77
+
78
+ Configurable augmentation pipeline:
79
+ - Gaussian blur (σ=0-5)
80
+ - JPEG compression (QF=30-100)
81
+ - Probability-controlled application during training
82
+ - Image normalization specific to each feature extractor
83
+
84
+ ## Configuration System
85
+
86
+ YAML config files in `configs/` specify per-model settings:
87
+
88
+ ```yaml
89
+ classifier:
90
+ input_dim: 4096 # Feature dimension from backbone
91
+ hidden_layers: [1024] # Single hidden layer
92
+
93
+ dataset:
94
+ model_name: "arniqa" # Feature extractor identifier
95
+ f_model_name: "arniqa" # Used for checkpoint naming
96
+
97
+ trainer:
98
+ devices: [0] # GPU indices
99
+ max_epochs: 20
100
+ batch_size: 64
101
+ ```
102
+
103
+ The `train.py` script overrides certain config values based on in-script settings (dataset_type, loss function, preprocessing level).
104
+
105
+ ## Path Configuration (CRITICAL)
106
+
107
+ `defaults.py` contains hardcoded paths that MUST match your environment:
108
+
109
+ - `main_dataset_dir`: Location of GenImage/UnivFD/DRCT datasets
110
+ - `main_checkpoints_dir`: Where trained classifier checkpoints are saved
111
+ - `main_feature_ckpts_dir`: Pretrained IQA model weights
112
+ - `main_prior_checkpoints_dir`: Prior method checkpoints
113
+
114
+ **The code checks for specific mount points and will assert False if none match.** You must either:
115
+ 1. Update paths in `defaults.py` to match your environment
116
+ 2. Create the expected directory structure
117
+
118
+ ## Checkpoint Management
119
+
120
+ Checkpoints organized hierarchically:
121
+ ```
122
+ checkpoints/
123
+ ├── GenImage/
124
+ │ └── extensive/
125
+ │ └── MarginContrastiveLoss_CrossEntropy/
126
+ │ └── {model_name}/
127
+ │ └── best_model.ckpt
128
+ └── DRCT/
129
+ └── extensive/
130
+ └── MarginContrastiveLoss_CrossEntropy/
131
+ └── {model_name}/
132
+ └── best_model.ckpt
133
+ ```
134
+
135
+ Training automatically resumes from `best_model.ckpt` if found in expected location.
136
+
137
+ ## Dependencies
138
+
139
+ Core libraries (see `functions/ReIQA/requirements.txt` for full list):
140
+ - PyTorch + torchvision
141
+ - PyTorch Lightning (training framework)
142
+ - timm (model architectures)
143
+ - torchmetrics (evaluation)
144
+ - numpy, scipy, scikit-learn, scikit-image
145
+ - PIL (Pillow) for image loading
146
+ - pyyaml for config parsing
147
+ - tqdm for progress bars
148
+
149
+ Feature extractor dependencies loaded dynamically (e.g., ARNIQA via `torch.hub.load`).
150
+
151
+ ## Important Implementation Details
152
+
153
+ ### Training Script Pattern
154
+ Both `train.py` and `test.py` redirect stdout to log files in `stdouts/` and `results/` directories. Output is not visible in console by default.
155
+
156
+ ### Feature Extraction
157
+ In `functions/module.py`, the global `feature_extractor_module` function is set before training. During training/validation steps, features are extracted with `torch.no_grad()` to prevent gradient computation through the frozen backbone.
158
+
159
+ ### Metrics and Thresholds
160
+ - **GenImage/DRCT**: Fixed threshold of 0.5 for binary classification
161
+ - **UnivFD**: Threshold determined from validation set for optimal accuracy
162
+
163
+ ### Cross-Dataset Testing
164
+ `test.py` includes cross-dataset evaluation (e.g., trained on GenImage, tested on DRCT) to measure generalization.
165
+
166
+ ## Prior Methods (`prior_methods/`)
167
+
168
+ Comparison implementations of baseline detectors:
169
+ - CLIP-based classifiers (various architectures)
170
+ - DRCT (Detecting and Recovering Content Transformations)
171
+
172
+ These use similar training patterns but different feature extractors. Organized in parallel structure to main codebase.
173
+
174
+ ## Results and Analysis
175
+
176
+ - `results/`: CSV files with per-model, per-dataset metrics
177
+ - `analysis/plots/`: Generated visualizations (polar plots, t-SNE, robustness curves)
178
+ - Log files track training progress and test results
179
+
180
+ ## Modifying for New Experiments
181
+
182
+ 1. **Add new feature extractor**: Create wrapper in `features/`, add to `get_model()` in `functions/networks.py`
183
+ 2. **Add new dataset**: Update `defaults.py` with source lists, add getter function in `functions/utils.py`
184
+ 3. **Change training settings**: Modify settings list in `train.py` (dataset, loss, augmentation level)
185
+ 4. **Test new distortions**: Add preprocessing settings in `test.py` preprocess_settings_list
defaults.py CHANGED
@@ -21,7 +21,18 @@ elif os.path.exists("/mnt/LIVELAB_NAS2/krishna/Perceptual-Classifiers"):
21
  main_feature_ckpts_dir = "/mnt/LIVELAB_NAS2/krishna/Perceptual-Classifiers/feature_extractor_checkpoints"
22
  main_prior_checkpoints_dir = "/mnt/LIVELAB_NAS2/krishna/Perceptual-Classifiers/prior_methods_checkpoints"
23
  else:
24
- assert False, "Invalid Dataset Directory"
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
 
 
21
  main_feature_ckpts_dir = "/mnt/LIVELAB_NAS2/krishna/Perceptual-Classifiers/feature_extractor_checkpoints"
22
  main_prior_checkpoints_dir = "/mnt/LIVELAB_NAS2/krishna/Perceptual-Classifiers/prior_methods_checkpoints"
23
  else:
24
+ # Local setup - use directories relative to this file
25
+ _base_dir = os.path.dirname(os.path.abspath(__file__))
26
+ main_dataset_dir = os.path.join(_base_dir, "datasets")
27
+ main_checkpoints_dir = os.path.join(_base_dir, "checkpoints")
28
+ main_feature_ckpts_dir = os.path.join(_base_dir, "feature_extractor_checkpoints")
29
+ main_prior_checkpoints_dir = os.path.join(_base_dir, "prior_methods_checkpoints")
30
+
31
+ # Create directories if they don't exist
32
+ os.makedirs(main_dataset_dir, exist_ok=True)
33
+ os.makedirs(main_checkpoints_dir, exist_ok=True)
34
+ os.makedirs(main_feature_ckpts_dir, exist_ok=True)
35
+ os.makedirs(main_prior_checkpoints_dir, exist_ok=True)
36
 
37
 
38
 
functions/run_on_images_fn.py CHANGED
@@ -273,7 +273,9 @@ def run_on_images(feature_extractor, classifier, config, test_real_images_paths,
273
  # Global Variables: (feature_extractor)
274
  global feature_extractor_module
275
  feature_extractor_module = feature_extractor
276
- feature_extractor_module.to("cuda")
 
 
277
  feature_extractor_module.eval()
278
  for params in feature_extractor_module.parameters():
279
  params.requires_grad = False
@@ -285,8 +287,12 @@ def run_on_images(feature_extractor, classifier, config, test_real_images_paths,
285
  Model = Model_LightningModule(classifier, config)
286
 
287
  # PyTorch Lightning Trainer
 
 
 
 
288
  trainer = pl.Trainer(
289
- **config["trainer"],
290
  callbacks=[best_checkpoint_callback, utils.LitProgressBar()],
291
  precision=32
292
  )
 
273
  # Global Variables: (feature_extractor)
274
  global feature_extractor_module
275
  feature_extractor_module = feature_extractor
276
+ # Use CPU for Mac compatibility (change to "cuda" if you have NVIDIA GPU)
277
+ device = "cpu"
278
+ feature_extractor_module.to(device)
279
  feature_extractor_module.eval()
280
  for params in feature_extractor_module.parameters():
281
  params.requires_grad = False
 
287
  Model = Model_LightningModule(classifier, config)
288
 
289
  # PyTorch Lightning Trainer
290
+ # Override accelerator and devices for Mac compatibility
291
+ trainer_config = config["trainer"].copy()
292
+ trainer_config["accelerator"] = "cpu" # Use "cuda" for NVIDIA GPU, "mps" for Apple Silicon GPU
293
+ trainer_config["devices"] = 1 # CPU uses integer, GPU uses list like [0]
294
  trainer = pl.Trainer(
295
+ **trainer_config,
296
  callbacks=[best_checkpoint_callback, utils.LitProgressBar()],
297
  precision=32
298
  )
test_all_models.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test all available models on the same image
3
+ """
4
+ import os
5
+ import sys
6
+
7
+ # Available models - test all 5 IQA-based models
8
+ models = ['contrique', 'hyperiqa', 'tres', 'reiqa', 'arniqa']
9
+
10
+ # Test images directory
11
+ test_images_dir = "new_images_to_test"
12
+
13
+ # Get all images from the directory
14
+ import glob
15
+ image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
16
+ test_images = []
17
+ for ext in image_extensions:
18
+ test_images.extend(glob.glob(os.path.join(test_images_dir, ext)))
19
+
20
+ if not test_images:
21
+ print(f"Error: No images found in {test_images_dir}/")
22
+ sys.exit(1)
23
+
24
+ print(f"Found {len(test_images)} image(s) in {test_images_dir}/")
25
+ print("=" * 80)
26
+
27
+ # Import libraries once
28
+ sys.path.insert(0, '.')
29
+ from yaml import safe_load
30
+ from functions.loss_optimizers_metrics import *
31
+ from functions.run_on_images_fn import run_on_images
32
+ import functions.utils as utils
33
+ import functions.networks as networks
34
+ import defaults
35
+ import warnings
36
+ warnings.filterwarnings("ignore")
37
+
38
+ all_results = {}
39
+
40
+ # Test each model
41
+ for model_idx, model_name in enumerate(models, 1):
42
+ print(f"\n{'='*80}")
43
+ print(f"[{model_idx}/{len(models)}] Testing model: {model_name.upper()}")
44
+ print("="*80)
45
+
46
+ try:
47
+ config_path = f"configs/{model_name}.yaml"
48
+ config = safe_load(open(config_path, "r"))
49
+
50
+ # Override settings
51
+ config["dataset"]["dataset_type"] = "GenImage"
52
+ config["checkpoints"]["resume_dirname"] = "GenImage/extensive/MarginContrastiveLoss_CrossEntropy"
53
+ config["checkpoints"]["resume_filename"] = "best_model.ckpt"
54
+ config["checkpoints"]["checkpoint_dirname"] = "extensive/MarginContrastiveLoss_CrossEntropy"
55
+ config["checkpoints"]["checkpoint_filename"] = "best_model.ckpt"
56
+
57
+ # Training settings (for testing)
58
+ config["train_settings"]["train"] = False
59
+ config["train_loss_fn"]["name"] = "CrossEntropy"
60
+ config["val_loss_fn"]["name"] = "CrossEntropy"
61
+
62
+ # Model setup
63
+ device = "cpu"
64
+ feature_extractor = networks.get_model(model_name=model_name, device=device)
65
+
66
+ # Classifier
67
+ config["classifier"]["hidden_layers"] = [1024]
68
+ classifier = networks.Classifier_Arch2(
69
+ input_dim=config["classifier"]["input_dim"],
70
+ hidden_layers=config["classifier"]["hidden_layers"]
71
+ )
72
+
73
+ # Preprocessing settings
74
+ preprocess_settings = {
75
+ "model_name": model_name,
76
+ "selected_transforms_name": "test",
77
+ "probability": -1,
78
+ "gaussian_blur_range": None,
79
+ "jpeg_compression_qfs": None,
80
+ "input_image_dimensions": (224, 224),
81
+ "resize": None
82
+ }
83
+
84
+ print(f"✓ {model_name.upper()} model loaded successfully\n")
85
+
86
+ results = []
87
+
88
+ # Test each image with this model
89
+ for idx, test_image in enumerate(test_images, 1):
90
+ image_name = os.path.basename(test_image)
91
+ print(f" [{idx}/{len(test_images)}] Testing: {image_name}")
92
+
93
+ # Test images
94
+ test_real_images_paths = [test_image]
95
+ test_fake_images_paths = []
96
+
97
+ try:
98
+ test_set_metrics, best_threshold, y_pred, y_true = run_on_images(
99
+ feature_extractor=feature_extractor,
100
+ classifier=classifier,
101
+ config=config,
102
+ test_real_images_paths=test_real_images_paths,
103
+ test_fake_images_paths=test_fake_images_paths,
104
+ preprocess_settings=preprocess_settings,
105
+ best_threshold=0.5,
106
+ verbose=False
107
+ )
108
+
109
+ score = y_pred[0] if len(y_pred) > 0 else None
110
+ prediction = "AI-Generated" if score and score > 0.5 else "Real"
111
+ confidence = abs(score - 0.5) * 200 if score else 0
112
+
113
+ results.append({
114
+ 'image': image_name,
115
+ 'score': score,
116
+ 'prediction': prediction,
117
+ 'confidence': confidence
118
+ })
119
+
120
+ print(f" ✓ Score: {score:.4f} → {prediction} ({confidence:.1f}% confidence)")
121
+
122
+ except Exception as e:
123
+ print(f" ✗ Error: {e}")
124
+ results.append({
125
+ 'image': image_name,
126
+ 'score': None,
127
+ 'prediction': 'Error',
128
+ 'confidence': 0
129
+ })
130
+
131
+ all_results[model_name] = results
132
+
133
+ except Exception as e:
134
+ print(f"✗ Failed to load {model_name.upper()} model: {e}")
135
+ all_results[model_name] = None
136
+
137
+ # Final Summary
138
+ print("\n" + "="*80)
139
+ print("FINAL SUMMARY - ALL MODELS")
140
+ print("="*80)
141
+
142
+ for model_name, results in all_results.items():
143
+ if results is None:
144
+ print(f"\n{model_name.upper()}: Failed to load")
145
+ continue
146
+
147
+ print(f"\n{model_name.upper()}:")
148
+ print("-"*80)
149
+ print(f"{'Image':<50} {'Score':<10} {'Prediction':<15} {'Confidence':<12}")
150
+ print("-"*80)
151
+
152
+ for r in results:
153
+ score_str = f"{r['score']:.4f}" if r['score'] is not None else "N/A"
154
+ conf_str = f"{r['confidence']:.1f}%" if r['score'] is not None else "N/A"
155
+ img_name = r['image'][:47] + "..." if len(r['image']) > 50 else r['image']
156
+ print(f"{img_name:<50} {score_str:<10} {r['prediction']:<15} {conf_str:<12}")
157
+
158
+ # Statistics
159
+ valid_predictions = [r for r in results if r['score'] is not None]
160
+ if valid_predictions:
161
+ avg_score = sum(r['score'] for r in valid_predictions) / len(valid_predictions)
162
+ ai_count = sum(1 for r in valid_predictions if r['score'] > 0.5)
163
+ real_count = len(valid_predictions) - ai_count
164
+ avg_confidence = sum(r['confidence'] for r in valid_predictions) / len(valid_predictions)
165
+
166
+ print("-"*80)
167
+ print(f"Average Score: {avg_score:.4f} | AI: {ai_count} | Real: {real_count} | Avg Confidence: {avg_confidence:.1f}%")
168
+
169
+ print("\n" + "="*80)
test_on_images.py CHANGED
@@ -15,18 +15,25 @@ import functions.utils as utils
15
  import functions.networks as networks
16
  import defaults
17
 
18
- dir_path = "/home/krishna/Perceptual-Classifiers-Working/images/True=Real_Pred=Fake"
19
- test_real_images_files = os.listdir(dir_path)
 
 
20
  test_real_images_paths = []
21
- for f in test_real_images_files:
22
- test_real_images_paths.append(
23
- os.path.join(
24
- dir_path, f
25
- )
26
- )
27
 
28
  test_fake_images_paths = []
29
 
 
 
 
 
 
 
 
 
 
30
  # Calling Main function
31
  if __name__ == '__main__':
32
  # -----------------------------------------------------------------
@@ -125,8 +132,9 @@ if __name__ == '__main__':
125
  f_model_name = config["dataset"]["f_model_name"]
126
 
127
 
128
- # Model
129
- feature_extractor = networks.get_model(model_name=config["dataset"]["model_name"], device="cuda")
 
130
 
131
 
132
  # Classifier
 
15
  import functions.networks as networks
16
  import defaults
17
 
18
+ # Get all images from new_images_to_test folder
19
+ import glob
20
+ test_images_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "new_images_to_test")
21
+ image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
22
  test_real_images_paths = []
23
+ for ext in image_extensions:
24
+ test_real_images_paths.extend([os.path.abspath(p) for p in glob.glob(os.path.join(test_images_dir, ext))])
 
 
 
 
25
 
26
  test_fake_images_paths = []
27
 
28
+ if not test_real_images_paths:
29
+ print(f"Error: No images found in {test_images_dir}/")
30
+ sys.exit(1)
31
+
32
+ print(f"Found {len(test_real_images_paths)} image(s) to test:")
33
+ for img in test_real_images_paths:
34
+ print(f" - {os.path.basename(img)}")
35
+ print()
36
+
37
  # Calling Main function
38
  if __name__ == '__main__':
39
  # -----------------------------------------------------------------
 
132
  f_model_name = config["dataset"]["f_model_name"]
133
 
134
 
135
+ # Model - use CPU for Mac (MPS not fully supported by all models)
136
+ device = "cpu" # Change to "cuda" if you have NVIDIA GPU
137
+ feature_extractor = networks.get_model(model_name=config["dataset"]["model_name"], device=device)
138
 
139
 
140
  # Classifier