File size: 2,167 Bytes
4f81869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Upgrade an existing prototype DB .pt file to include artist names (label_names).

This is useful for older prototype files that only store:
  - centers: [N, D]
  - labels:  [N]

We infer label_names from `dataset/` folder (sorted artist directories), matching
`train_style_ddp.TriViewDataset` label assignment.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch


def infer_label_names(dataset_dir: Path) -> list[str]:
    """Return the sorted names of the immediate sub-directories of *dataset_dir*.

    Each sub-directory is treated as one artist; plain files are ignored.
    Per the module docstring, the sorted order is meant to match the label
    assignment used at training time.

    Raises:
        FileNotFoundError: if *dataset_dir* does not exist.
        RuntimeError: if *dataset_dir* contains no sub-directories.
    """
    if not dataset_dir.exists():
        raise FileNotFoundError(f"dataset dir not found: {dataset_dir}")

    artist_names = [entry.name for entry in dataset_dir.iterdir() if entry.is_dir()]
    artist_names.sort()

    if len(artist_names) == 0:
        raise RuntimeError(f"No artist folders found under: {dataset_dir}")
    return artist_names


def main() -> None:
    """Command-line entry point: add ``label_names`` to a prototype DB ``.pt`` file.

    Flags:
        --in           Input .pt prototype file (required).
        --out          Output .pt; when omitted, the input file is overwritten.
        --dataset-dir  Dataset root whose sub-directory names become the
                       label names (default: "dataset").

    Behaviour:
        * The loaded object must be a dict containing "centers" and "labels".
        * If a non-empty list is already stored under "label_names", nothing
          is inferred; the object is only written when --out differs from --in.
        * Otherwise the names returned by ``infer_label_names`` are stored
          under "label_names" and the dict is saved.

    Raises:
        ValueError: if the loaded object is not a dict with "centers" and
            "labels".
        FileNotFoundError, RuntimeError: propagated from ``infer_label_names``.
    """
    parser = argparse.ArgumentParser(description="Add label_names to an existing prototype DB .pt")
    parser.add_argument("--in", dest="in_path", required=True, help="Input .pt prototype file")
    parser.add_argument("--out", dest="out_path", default=None, help="Output .pt (default: overwrite input)")
    parser.add_argument("--dataset-dir", type=str, default="dataset", help="Dataset root to infer artist names from")
    args = parser.parse_args()

    in_path = Path(args.in_path)
    out_path = Path(args.out_path) if args.out_path else in_path
    dataset_dir = Path(args.dataset_dir)

    # NOTE: torch.load unpickles the file; only run this on prototype files
    # from a trusted source.
    obj = torch.load(str(in_path), map_location="cpu")
    if not isinstance(obj, dict) or "centers" not in obj or "labels" not in obj:
        raise ValueError("Unsupported prototype file format (expected dict with centers+labels).")

    existing = obj.get("label_names")
    if isinstance(existing, list) and existing:
        print("label_names already present; nothing to do.")
        if out_path == in_path:
            # Nothing changed and no separate destination was requested.
            return
    else:
        obj["label_names"] = infer_label_names(dataset_dir)

    # Single save path for both cases, so the output directory is created even
    # when an already-upgraded file is merely copied to a new location.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(obj, str(out_path))
    print("saved:", out_path)


# Run the upgrade only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()