Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,197 Bytes
81d425d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
@dataclass
class PrototypeDB:
    """
    In-memory collection of labelled prototype embeddings.

    Row ``i`` of ``centers`` is one prototype vector and ``labels[i]`` is its
    integer label id; ``label_names[id]`` is the display name of that id.
    """

    centers: torch.Tensor  # [N,D] float32
    labels: torch.Tensor  # [N] int64
    label_names: List[str]  # id -> name
    source_path: Optional[Path] = None  # default target of save(); updated by save()

    @property
    def dim(self) -> int:
        """Embedding dimension D, read from the second axis of ``centers``."""
        return int(self.centers.shape[1])

    def id_to_name(self, idx: int) -> str:
        """Return the name registered for ``idx``, or ``str(idx)`` if out of range."""
        if 0 <= idx < len(self.label_names):
            return self.label_names[idx]
        return str(idx)

    def ensure_label_id(self, name: str) -> int:
        """
        Return the id of ``name`` (whitespace-stripped), registering it if new.

        Raises ValueError when the stripped name is empty.
        """
        name = str(name).strip()
        if not name:
            raise ValueError("Empty label name.")
        try:
            return int(self.label_names.index(name))
        except ValueError:
            # Unknown name: it gets the next free id (its index in the list).
            self.label_names.append(name)
            return len(self.label_names) - 1

    def add_center(self, label_name: str, center: torch.Tensor) -> int:
        """
        Append one L2-normalized prototype for ``label_name`` and return its label id.

        Raises ValueError when ``center`` is not 1D, when its length differs
        from the dimension of the stored prototypes, or when ``label_name`` is
        empty. On any of these errors the DB is left unmodified.
        """
        if center.ndim != 1:
            raise ValueError("center must be 1D embedding vector.")
        # Check the dimension BEFORE ensure_label_id(): that call may append a
        # new name, and a later torch.cat failure would leave a label that has
        # no prototype behind it.
        if self.centers.ndim == 2 and center.shape[0] != self.centers.shape[1]:
            raise ValueError(
                f"center has dimension {center.shape[0]}, "
                f"expected {self.centers.shape[1]}."
            )
        row = torch.nn.functional.normalize(center.float(), dim=0).view(1, -1)
        # torch.cat requires all inputs on one device; follow the stored tensors.
        row = row.to(self.centers.device)
        lid = self.ensure_label_id(label_name)
        self.centers = torch.cat([self.centers, row], dim=0)
        new_label = torch.tensor([lid], dtype=torch.long, device=self.labels.device)
        self.labels = torch.cat([self.labels, new_label], dim=0)
        return lid

    def save(self, path: Optional[str | Path] = None) -> Path:
        """
        Write the DB with ``torch.save`` and return the path written.

        ``path`` defaults to ``source_path``; ValueError is raised when neither
        is set. Tensors are detached and moved to CPU before saving, parent
        directories are created as needed, and ``source_path`` is updated to
        the path written.
        """
        out = Path(path) if path is not None else self.source_path
        if out is None:
            raise ValueError("No output path specified for saving prototype DB.")
        out.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            dict(
                centers=self.centers.detach().cpu(),
                labels=self.labels.detach().cpu(),
                label_names=list(self.label_names),
            ),
            str(out),
        )
        self.source_path = out
        return out
def _infer_label_names_from_dataset(dataset_root: Path) -> Optional[List[str]]:
    """
    Return the sorted names of the sub-directories of ``dataset_root``.

    Returns None when ``dataset_root`` is not an existing directory or has no
    sub-directories.
    """
    # `train_style_ddp.TriViewDataset` assigns IDs based on sorted directory names under dataset/<artist>.
    # is_dir() rather than exists(): a regular file at this path would make
    # iterdir() raise NotADirectoryError instead of falling back to None.
    if not dataset_root.is_dir():
        return None
    artists = sorted(p.name for p in dataset_root.iterdir() if p.is_dir())
    return artists or None
def load_prototype_db(path: str | Path, *, try_dataset_dir: str | Path = "dataset") -> PrototypeDB:
    """
    Load a prototype DB from a ``torch.save``d dict with "centers" and "labels".

    ``centers`` is converted to float32 and ``labels`` to int64. When the file
    has no usable "label_names" (a list of str), names are taken from the
    sorted sub-directory names of ``try_dataset_dir``; if that yields nothing,
    the names are the stringified ids ``"0" .. str(max label id)``.

    Raises FileNotFoundError when ``path`` does not exist and ValueError when
    the file content is not in the expected format.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(str(p))
    # SECURITY NOTE(review): torch.load is pickle-based; unless weights_only
    # loading is in effect it can execute arbitrary code from the file. Only
    # load prototype files from trusted sources.
    obj = torch.load(str(p), map_location="cpu")
    if not isinstance(obj, dict) or "centers" not in obj or "labels" not in obj:
        raise ValueError(f"Unsupported prototype file format: {p}")
    raw_centers = obj["centers"]
    raw_labels = obj["labels"]
    # Reject non-tensor payloads with the same format error instead of letting
    # .float()/.long() fail with an unrelated AttributeError.
    if not isinstance(raw_centers, torch.Tensor) or not isinstance(raw_labels, torch.Tensor):
        raise ValueError(f"Unsupported prototype file format: {p}")
    centers = raw_centers.float()
    labels = raw_labels.long()
    label_names = obj.get("label_names")
    if not isinstance(label_names, list) or not all(isinstance(x, str) for x in label_names):
        inferred = _infer_label_names_from_dataset(Path(try_dataset_dir))
        if inferred is None:
            # No names anywhere: fall back to the numeric ids as strings.
            max_id = int(labels.max().item()) if labels.numel() else -1
            label_names = [str(i) for i in range(max_id + 1)]
        else:
            label_names = inferred
    return PrototypeDB(centers=centers, labels=labels, label_names=label_names, source_path=p)
def topk_predictions(
    db: PrototypeDB,
    z: torch.Tensor,
    *,
    topk: int = 5,
) -> List[Tuple[str, float]]:
    """
    Rank the prototypes in `db` by cosine similarity to the 1D embedding `z`.

    Returns up to `topk` pairs (label_name, score), best score first. The same
    label can appear more than once when several of its prototypes rank.
    `topk` is clamped to at least 1 and at most the number of prototypes; an
    empty DB gives []. Raises ValueError when `z` is not 1D.
    """
    if z.ndim != 1:
        raise ValueError("z must be 1D.")

    l2_normalize = torch.nn.functional.normalize
    query = l2_normalize(z.float(), dim=0).view(1, -1)
    prototypes = l2_normalize(db.centers.float(), dim=1)
    scores = (query @ prototypes.t()).squeeze(0)  # [N]

    n_prototypes = scores.numel()
    if not n_prototypes:
        return []

    k = int(max(1, min(topk, n_prototypes)))
    top_scores, top_rows = torch.topk(scores, k=k)
    return [
        (db.id_to_name(int(db.labels[row].item())), float(score))
        for score, row in zip(top_scores.tolist(), top_rows.tolist())
    ]
def topk_predictions_unique_labels(
    db: PrototypeDB,
    z: torch.Tensor,
    *,
    topk: int = 5,
) -> List[Tuple[str, float]]:
    """
    Like topk_predictions(), but dedupes by label:
    if a label has multiple prototypes, only the highest score is kept.

    Returns up to `topk` pairs (label_name, score), best score first. `topk`
    is clamped to at least 1; an empty DB gives []. Raises ValueError when
    `z` is not 1D.
    """
    if z.ndim != 1:
        raise ValueError("z must be 1D.")
    Z = torch.nn.functional.normalize(z.float(), dim=0).view(1, -1)
    C = torch.nn.functional.normalize(db.centers.float(), dim=1)
    sim = (Z @ C.t()).squeeze(0)  # [N]
    if sim.numel() == 0:
        return []
    # Convert each tensor to a Python list once, instead of two .item() calls
    # (each a separate tensor-to-host read) per prototype.
    scores = sim.tolist()
    label_ids = db.labels.tolist()
    best_by_label: dict[int, float] = {}
    # iterate all prototypes once; keep max per label id
    for row, score in enumerate(scores):
        lid = int(label_ids[row])
        prev = best_by_label.get(lid)
        if prev is None or score > prev:
            best_by_label[lid] = score
    # sorted() is stable, so equal scores keep first-seen label order.
    ranked = sorted(best_by_label.items(), key=lambda kv: kv[1], reverse=True)
    ranked = ranked[: max(1, int(topk))]
    return [(db.id_to_name(lid), float(score)) for (lid, score) in ranked]
|