Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,167 Bytes
4f81869 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Upgrade an existing prototype DB .pt file to include artist names (label_names).
This is useful for older prototype files that only store:
- centers: [N, D]
- labels: [N]
We infer label_names from the dataset folder (`--dataset-dir`, default `dataset/`)
by sorting its artist sub-directories, matching the label assignment used by
`train_style_ddp.TriViewDataset`.
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
def infer_label_names(dataset_dir: Path) -> list[str]:
    """Return the sorted names of the immediate sub-directories of *dataset_dir*.

    Every immediate sub-directory is treated as one artist; regular files are
    ignored. Names are ordered with Python's default string sort. Per the
    module docstring this order is meant to mirror the label assignment of
    ``train_style_ddp.TriViewDataset`` (not verifiable from this file).

    Args:
        dataset_dir: Root folder that contains one sub-directory per artist.

    Returns:
        The sub-directory names in sorted order.

    Raises:
        FileNotFoundError: If *dataset_dir* does not exist.
        NotADirectoryError: If *dataset_dir* exists but is not a directory.
        RuntimeError: If *dataset_dir* contains no sub-directories.
    """
    if not dataset_dir.exists():
        raise FileNotFoundError(f"dataset dir not found: {dataset_dir}")
    # Fail with a readable message instead of the raw OS error that
    # ``iterdir()`` would produce for a regular file.
    if not dataset_dir.is_dir():
        raise NotADirectoryError(f"dataset path is not a directory: {dataset_dir}")

    names = sorted(entry.name for entry in dataset_dir.iterdir() if entry.is_dir())
    if not names:
        raise RuntimeError(f"No artist folders found under: {dataset_dir}")
    return names
def main() -> None:
    """Command-line entry point: add ``label_names`` to a prototype DB file.

    Loads the ``.pt`` file given by ``--in``, which must be a dict containing
    ``centers`` and ``labels``. If it has no non-empty ``label_names`` list,
    one is inferred from the sub-directories of ``--dataset-dir`` and stored
    under that key. The result is written to ``--out`` (default: overwrite the
    input). If ``label_names`` is already present, the file is only copied
    when ``--out`` differs from ``--in``.

    Raises:
        ValueError: If the loaded object is not a dict with ``centers`` and
            ``labels``.
        FileNotFoundError, RuntimeError: Propagated from
            ``infer_label_names`` when the dataset folder is missing or empty.
    """
    parser = argparse.ArgumentParser(description="Add label_names to an existing prototype DB .pt")
    parser.add_argument("--in", dest="in_path", required=True, help="Input .pt prototype file")
    parser.add_argument("--out", dest="out_path", default=None, help="Output .pt (default: overwrite input)")
    parser.add_argument("--dataset-dir", type=str, default="dataset", help="Dataset root to infer artist names from")
    args = parser.parse_args()

    in_path = Path(args.in_path)
    out_path = Path(args.out_path) if args.out_path else in_path
    dataset_dir = Path(args.dataset_dir)

    # NOTE(review): depending on the installed torch version, torch.load may
    # fully unpickle the file; only run this on prototype files you trust.
    obj = torch.load(str(in_path), map_location="cpu")
    if not isinstance(obj, dict) or "centers" not in obj or "labels" not in obj:
        raise ValueError("Unsupported prototype file format (expected dict with centers+labels).")

    existing = obj.get("label_names")
    if isinstance(existing, list) and existing:
        print("label_names already present; nothing to do.")
        if out_path == in_path:
            # Nothing changed and the destination is the source: skip the write.
            return
    else:
        obj["label_names"] = infer_label_names(dataset_dir)

    # Shared tail for both branches, so copying an already-upgraded file into
    # a not-yet-existing folder works just like the normal upgrade path.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(obj, str(out_path))
    print("saved:", out_path)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|