from __future__ import annotations
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
# ZeroGPU on HF Spaces: CUDA must not be initialized in the main process
_ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
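# Sketch of the usual ZeroGPU pattern (assumption: the HF `spaces` package is
# installed; the function name below is illustrative, not part of this file):
#
#     import spaces
#
#     @spaces.GPU
#     def _gpu_infer(batch):
#         ...  # the GPU is attached only inside decorated functions
#
# Everything at module scope, including this loader, must stay CPU-only.
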
@dataclass(frozen=True)
class LoadedModel:
    """A loaded model plus the deterministic eval transforms for its stage."""

    model: torch.nn.Module
    device: torch.device
    stage_i: int    # 1-based training stage read from the checkpoint
    embed_dim: int  # width of the fused embedding
    T_w: object     # val transform for the whole-image view
    T_f: object     # val transform for the face crop
    T_e: object     # val transform for the eyes crop

def _pick_device(device: str) -> torch.device:
    """Resolve a device string; anything other than "cpu" (e.g. "auto")
    falls through to CUDA when it is usable."""
    if device.strip().lower() == "cpu":
        return torch.device("cpu")
    if _ON_SPACES:
        # ZeroGPU: the main process may not initialize CUDA.
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def load_style_model(
    ckpt_path: str | Path,
    *,
    device: str = "auto",
) -> LoadedModel:
    """
    Load `train_style_ddp.TriViewStyleNet` from a checkpoint saved by
    `train_style_ddp.py`, and return the model together with the deterministic
    val transforms for the stage recorded in the checkpoint.
    """
    import train_style_ddp as ts

    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        raise FileNotFoundError(str(ckpt_path))

    # On Spaces, always use CPU (ZeroGPU forbids CUDA in the main process).
    if _ON_SPACES:
        dev = torch.device("cpu")
    else:
        dev = _pick_device(device)  # "auto" falls through to CUDA when available

    ck = torch.load(str(ckpt_path), map_location="cpu")
    meta = ck.get("meta", {}) if isinstance(ck, dict) else {}

    # Clamp the 1-based stage index to the configured range.
    stage_i = int(meta.get("stage", 1))
    stage_i = max(1, min(stage_i, len(ts.cfg.stages)))
    stage = ts.cfg.stages[stage_i - 1]
    T_w, T_f, T_e = ts.make_val_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"])

    model = ts.TriViewStyleNet(out_dim=ts.cfg.embed_dim, mix_p=ts.cfg.mixstyle_p, share_backbone=True)
    # Accept both {"model": state_dict, ...} checkpoints and bare state dicts;
    # strict=False tolerates harmless key mismatches (e.g. from DDP wrappers).
    state = ck["model"] if isinstance(ck, dict) and "model" in ck else ck
    model.load_state_dict(state, strict=False)
    model.eval()
    model = model.to(dev)
    try:
        # Best effort: channels_last can speed up conv inference; skip if unsupported.
        model = model.to(memory_format=torch.channels_last)
    except Exception:
        pass

    return LoadedModel(
        model=model,
        device=dev,
        stage_i=stage_i,
        embed_dim=int(ts.cfg.embed_dim),
        T_w=T_w,
        T_f=T_f,
        T_e=T_e,
    )

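# Usage sketch (assumptions: the checkpoint path is illustrative, and the
# transforms behave like torchvision transforms that accept a PIL image):
#
#     lm = load_style_model("checkpoints/last.pt", device="auto")
#     print(lm.stage_i, lm.embed_dim)  # stage recorded in the ckpt, embed width
#     whole = lm.T_w(pil_image)        # deterministic val transform -> CHW tensor
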
def embed_triview(
    lm: LoadedModel,
    *,
    whole: Optional[torch.Tensor],
    face: Optional[torch.Tensor],
    eyes: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Compute a single fused embedding for a tri-view sample.
    Each view tensor must be CHW and already normalized; it is batched here.
    Missing views may be None.
    """
    if whole is None and face is None and eyes is None:
        raise ValueError("At least one of whole/face/eyes must be provided.")

    # Build batched views plus per-view presence masks for the model.
    views = {}
    masks = {}
    for k, v in (("whole", whole), ("face", face), ("eyes", eyes)):
        if v is None:
            views[k] = None
            masks[k] = torch.zeros(1, dtype=torch.bool, device=lm.device)
        else:
            vb = v.unsqueeze(0).to(lm.device)  # CHW -> 1xCHW
            views[k] = vb
            masks[k] = torch.ones(1, dtype=torch.bool, device=lm.device)

    # Lazy dtype detection avoids CUDA init at import time (ZeroGPU compatibility).
    import train_style_ddp as _ts
    _dtype = _ts._get_amp_dtype() if hasattr(_ts, "_get_amp_dtype") else torch.float16

    # On CPU (and hence on Spaces) skip autocast entirely to avoid touching CUDA.
    use_amp = (lm.device.type == "cuda") and not _ON_SPACES
    if use_amp:
        with torch.no_grad(), torch.amp.autocast("cuda", dtype=_dtype, enabled=True):
            z, _, _ = lm.model(views, masks)
    else:
        with torch.no_grad():
            z, _, _ = lm.model(views, masks)

    # L2-normalize in fp32 and return a CPU vector of shape (embed_dim,).
    z = torch.nn.functional.normalize(z.float(), dim=1)
    return z.squeeze(0).detach().cpu()

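
if __name__ == "__main__":
    # Minimal smoke test; a sketch only. The default checkpoint path and the
    # 3x224x224 dummy shape are assumptions for illustration -- the real view
    # sizes come from ts.cfg.stages via the stage stored in the checkpoint.
    import sys

    lm = load_style_model(sys.argv[1] if len(sys.argv) > 1 else "checkpoints/last.pt")
    dummy = torch.randn(3, 224, 224)  # stand-in for one normalized whole-image view
    z = embed_triview(lm, whole=dummy, face=None, eyes=None)
    print("embedding shape:", tuple(z.shape), "| L2 norm:", float(z.norm()))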