File size: 3,971 Bytes
81d425d
 
059ebcb
81d425d
 
 
 
 
 
059ebcb
 
 
81d425d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
059ebcb
 
 
81d425d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
059ebcb
 
 
 
81d425d
059ebcb
 
81d425d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c61411c
 
 
059ebcb
 
 
 
 
 
 
 
81d425d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import torch

# ZeroGPU on HF Spaces: CUDA must not be initialized in the main process.
# Presence is detected via the SPACE_ID / HF_SPACE environment variables;
# every device-selection path below forces CPU when this flag is set.
_ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))


@dataclass(frozen=True)
class LoadedModel:
    """Immutable bundle of a loaded TriViewStyleNet and its inference companions."""

    # Network in eval mode, already moved to `device` (see load_style_model).
    model: torch.nn.Module
    # Device the model lives on; inputs are moved here before the forward pass.
    device: torch.device
    # 1-based training-stage index recovered (and clamped) from the checkpoint meta.
    stage_i: int
    # Embedding dimensionality, taken from train_style_ddp.cfg.embed_dim.
    embed_dim: int
    # Deterministic validation transforms for the whole / face / eyes views,
    # produced by train_style_ddp.make_val_transforms for the checkpoint stage.
    T_w: object
    T_f: object
    T_e: object


def _pick_device(device: str) -> torch.device:
    if device.strip().lower() == "cpu":
        return torch.device("cpu")
    if _ON_SPACES:
        # ZeroGPU: can't init CUDA in main process
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def load_style_model(
    ckpt_path: str | Path,
    *,
    device: str = "auto",
) -> LoadedModel:
    """
    Loads `train_style_ddp.TriViewStyleNet` from a checkpoint saved by `train_style_ddp.py`.
    Returns the model and deterministic val transforms based on the checkpoint stage.

    Args:
        ckpt_path: Path to the checkpoint — either a dict with "model"/"meta"
            keys as saved by the trainer, or a raw state_dict.
        device: "cpu", "auto", or any other string treated as a CUDA request
            (see `_pick_device`). Ignored on HF Spaces, where CPU is forced.

    Raises:
        FileNotFoundError: If `ckpt_path` does not exist.
    """
    import train_style_ddp as ts

    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        raise FileNotFoundError(str(ckpt_path))

    # On Spaces, always use CPU (ZeroGPU forbids CUDA in main process).
    # Otherwise _pick_device already handles "auto"/"cpu"/other uniformly:
    # anything but "cpu" resolves to CUDA when available, else CPU.
    dev = torch.device("cpu") if _ON_SPACES else _pick_device(device)

    # weights_only=False: this is a trusted local checkpoint that carries a
    # non-tensor "meta" dict. Pin the pre-2.6 behavior explicitly so loading
    # keeps working after torch flipped the default to weights_only=True.
    ck = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    meta = ck.get("meta", {}) if isinstance(ck, dict) else {}
    # Clamp the stored stage index into the configured 1-based range.
    stage_i = int(meta.get("stage", 1))
    stage_i = max(1, min(stage_i, len(ts.cfg.stages)))
    stage = ts.cfg.stages[stage_i - 1]

    T_w, T_f, T_e = ts.make_val_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"])

    model = ts.TriViewStyleNet(out_dim=ts.cfg.embed_dim, mix_p=ts.cfg.mixstyle_p, share_backbone=True)
    # Accept both wrapped ({"model": state_dict, ...}) and raw state_dict files.
    state = ck["model"] if isinstance(ck, dict) and "model" in ck else ck
    model.load_state_dict(state, strict=False)
    model.eval()
    model = model.to(dev)
    try:
        # channels_last is a cheap perf win where supported; best-effort only.
        model = model.to(memory_format=torch.channels_last)
    except Exception:
        pass

    return LoadedModel(
        model=model,
        device=dev,
        stage_i=stage_i,
        embed_dim=int(ts.cfg.embed_dim),
        T_w=T_w,
        T_f=T_f,
        T_e=T_e,
    )


def embed_triview(
    lm: LoadedModel,
    *,
    whole: Optional[torch.Tensor],
    face: Optional[torch.Tensor],
    eyes: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Computes a single fused embedding for a triview sample.

    Each provided view tensor must be CHW (already normalized); it is batched
    to 1xCHW and moved to the model's device. Missing views may be None, but
    at least one view is required. Returns an L2-normalized 1-D CPU tensor.
    """
    named_views = (("whole", whole), ("face", face), ("eyes", eyes))
    if all(t is None for _, t in named_views):
        raise ValueError("At least one of whole/face/eyes must be provided.")

    views = {}
    masks = {}
    for name, tensor in named_views:
        present = tensor is not None
        views[name] = tensor.unsqueeze(0).to(lm.device) if present else None
        masks[name] = torch.tensor([present], dtype=torch.bool, device=lm.device)

    # Lazy dtype lookup keeps CUDA uninitialized at import time (ZeroGPU).
    import train_style_ddp as _ts
    amp_dtype = _ts._get_amp_dtype() if hasattr(_ts, "_get_amp_dtype") else torch.float16

    with torch.no_grad():
        # Autocast only on real CUDA outside Spaces; elsewhere, never touch CUDA.
        if lm.device.type == "cuda" and not _ON_SPACES:
            with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=True):
                z, _, _ = lm.model(views, masks)
        else:
            z, _, _ = lm.model(views, masks)

    z = torch.nn.functional.normalize(z.float(), dim=1)
    return z.squeeze(0).detach().cpu()