| |
import argparse
import json
from contextlib import nullcontext
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
|
|
|
|
| |
class LightweightCompressionNet(nn.Module):
    """Small CNN mapping a batch of RGB images to 4 scores in [0, 1].

    ``conv_blocks`` is seven unpadded Conv2d + GELU stages (3 -> 256 channels)
    followed by a global average pool; ``head`` is a 256 -> 32 -> 4 MLP ending
    in a sigmoid. With these kernels/strides each spatial side of the input
    must be at least 204 px, otherwise a convolution runs out of input.

    Layers are appended in a fixed conv/GELU order so the ``nn.Sequential``
    indices -- and therefore state_dict keys such as ``conv_blocks.0.weight``
    -- match existing checkpoints. Do not reorder or insert layers.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride) for each conv stage.
        stage_specs = (
            (3, 16, 4, 1),
            (16, 32, 4, 1),
            (32, 64, 4, 2),
            (64, 128, 4, 2),
            (128, 256, 4, 4),
            (256, 256, 4, 4),
            (256, 256, 3, 2),
        )
        stages = []
        for in_ch, out_ch, kernel, stride in stage_specs:
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=0),
                nn.GELU(),
            ]
        # Collapse whatever spatial extent remains to 1x1.
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.conv_blocks = nn.Sequential(*stages)

        hidden = 32
        self.head = nn.Sequential(
            nn.Linear(256, hidden),
            nn.GELU(),
            nn.Linear(hidden, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch ``x`` of shape (batch, 3, H, W); returns shape (batch, 4)."""
        pooled = self.conv_blocks(x)  # (batch, 256, 1, 1) after the adaptive pool
        batch_size = pooled.size(0)
        return self.head(pooled.view(batch_size, -1))
|
|
|
|
| |
class CompressionArtifactPredictor:
    """Run LightweightCompressionNet on a single PIL image.

    The network emits one sigmoid score per entry of ``compression_formats``
    (in that order). ``predict`` rescales each score to the matching
    ``quality_ranges`` interval and also reports ``artifact_level = 1 - score``.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        """Build the network, load its weights and set up preprocessing.

        Args:
            model_path: Checkpoint written with ``torch.save``; either a dict
                holding the weights under ``'model_state_dict'`` or a bare
                state_dict.
            device: Torch device string. A CUDA device falls back to CPU when
                CUDA is not available; any other device is used as given.

        Raises:
            RuntimeError: If the checkpoint's weights do not match the network.
        """
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            requested = torch.device("cpu")
        self.device = requested

        self.model = LightweightCompressionNet().to(self.device)
        self.model.eval()

        # weights_only=True refuses arbitrary pickled objects in the checkpoint.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        self.model.load_state_dict(checkpoint)

        # bfloat16 autocast only on CUDA devices that support it; torch rejects
        # bf16 autocast on older GPUs, and the original code disabled it on CPU.
        self._use_bf16 = self.device.type == "cuda" and torch.cuda.is_bf16_supported()

        # Unpadded strided convolutions need a minimum (height, width) to
        # produce at least one output pixel; checked up front in predict().
        self.min_input_size = self._min_input_size(self.model)

        # Only ToTensor -- no resize -- so the network sees the native resolution.
        self.preprocess = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.compression_formats = ['jpeg', 'webp', 'avif', 'jxl']
        self.quality_ranges = {
            'jpeg': (0, 100),
            'webp': (0, 100),
            'avif': (0, 100),
            'jxl': (0, 100),
        }

    def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
        """
        Predict compression quality/artifact levels for all formats.

        Args:
            image: PIL Image; converted to RGB first if it is in another mode.
                Must be at least ``min_input_size`` (height, width) pixels.

        Returns:
            Dictionary keyed by format name, each value holding
            ``normalized_score`` (sigmoid output in [0, 1]),
            ``predicted_quality`` (score mapped into ``quality_ranges[fmt]``)
            and ``artifact_level`` (``1 - normalized_score``).

        Raises:
            ValueError: If the image is smaller than ``min_input_size``.
        """
        if image.mode != "RGB":
            # ToTensor keeps the source channel count; the first conv needs 3.
            image = image.convert("RGB")

        width, height = image.size
        min_height, min_width = self.min_input_size
        if height < min_height or width < min_width:
            raise ValueError(
                f"Image is {width}x{height} px; the model needs at least "
                f"{min_width}x{min_height} px"
            )

        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        autocast = (
            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            if self._use_bf16
            else nullcontext()
        )
        with torch.no_grad(), autocast:
            # Upcast before leaving the device: the output may be bf16 under autocast.
            scores = self.model(img_tensor).squeeze(0).float().cpu().tolist()

        results = {}
        for fmt, normalized_score in zip(self.compression_formats, scores):
            results[fmt] = {
                'normalized_score': normalized_score,
                'predicted_quality': self._denormalize_quality(normalized_score, fmt),
                'artifact_level': 1.0 - normalized_score,
            }
        return results

    def _denormalize_quality(self, normalized: float, fmt: str) -> float:
        """Linearly map a [0, 1] score onto ``quality_ranges[fmt]``."""
        min_q, max_q = self.quality_ranges[fmt]
        return normalized * (max_q - min_q) + min_q

    def predict_format(self, image: Image.Image, format_name: str) -> float:
        """Return the predicted quality for one entry of ``compression_formats``.

        Raises:
            ValueError: If ``format_name`` is not a supported format.
        """
        if format_name not in self.compression_formats:
            raise ValueError(
                f"Unsupported format {format_name!r}. "
                f"Choose from: {self.compression_formats}"
            )
        return self.predict(image)[format_name]['predicted_quality']

    @staticmethod
    def _min_input_size(model: nn.Module) -> Tuple[int, int]:
        """Smallest (height, width) for which every Conv2d in *model* yields >= 1 pixel.

        Walks the convolutions backwards, inverting
        ``out = floor((in + 2*pad - dil*(k - 1) - 1) / stride) + 1`` for the
        minimal ``in``. Assumes the convolutions run one after another in
        registration order, as in ``LightweightCompressionNet.conv_blocks``.
        """
        need = [1, 1]  # minimal (height, width) at the current layer's output
        convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
        for conv in reversed(convs):
            if conv.padding == "same":
                continue  # output size equals input size
            for axis in (0, 1):
                pad = 0 if conv.padding == "valid" else conv.padding[axis]
                span = conv.dilation[axis] * (conv.kernel_size[axis] - 1) + 1
                need[axis] = max(1, (need[axis] - 1) * conv.stride[axis] + span - 2 * pad)
        return need[0], need[1]
|
|
|
|
| |
def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point: analyse one image and print the predictions.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: If the image path does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Predict compression quality and artifact levels of an image "
        "for JPEG, WebP, AVIF and JPEG XL."
    )
    parser.add_argument("image", type=Path, help="path to the image to analyse")
    parser.add_argument(
        "--checkpoint",
        default="checkpoints/model.pt",
        help="model checkpoint to load (default: %(default)s)",
    )
    parser.add_argument(
        "--device",
        default="cuda",
        help="torch device; CUDA falls back to CPU when unavailable (default: %(default)s)",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="print the raw predictions as JSON instead of a report",
    )
    args = parser.parse_args(argv)

    # Validate and load the image before the (slower) model load so a bad path
    # fails fast. The context manager releases the file handle Image.open keeps.
    image_path: Path = args.image
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    predictor = CompressionArtifactPredictor(args.checkpoint, device=args.device)

    if args.json:
        print(json.dumps(predictor.predict(image), indent=2))
        return

    print(f"\n🔍 Analyzing image: {image_path}")
    print(f"📐 Image size: {image.size[0]}x{image.size[1]}\n")

    results = predictor.predict(image)

    print("=" * 50)
    print("📊 COMPRESSION ARTIFACT ANALYSIS")
    print("=" * 50)

    for fmt, data in results.items():
        print(f"\n{fmt.upper():>4}:")
        print(f"  Predicted Quality: {data['predicted_quality']:>6.1f} / {predictor.quality_ranges[fmt][1]}")
        print(f"  Normalized Score:  {data['normalized_score']:>6.3f}")
        print(f"  Artifact Level:    {data['artifact_level']:>6.3f} (0.0=clean, 1.0=heavily compressed)")

    # Mean artifact level across all formats drives the one-line verdict.
    avg_artifact_level = sum(r['artifact_level'] for r in results.values()) / len(results)
    print(f"\n{'=' * 50}")
    print(f"Overall artifact level: {avg_artifact_level:.3f}")
    if avg_artifact_level < 0.2:
        print("✅ Image appears to have minimal compression artifacts")
    elif avg_artifact_level < 0.5:
        print("⚠️ Image shows moderate compression artifacts")
    else:
        print("❌ Image exhibits heavy compression artifacts")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()