iljung1106 commited on
Commit
b04d768
·
1 Parent(s): 64e3b6d

fixed using gpu on main problem

Browse files
Files changed (2) hide show
  1. model_io.py +128 -0
  2. webui_gradio.py +1 -1
model_io.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+
10
+ # ZeroGPU on HF Spaces: CUDA must not be initialized in the main process
11
+ _ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
12
+
13
+
14
+ @dataclass(frozen=True)
15
+ class LoadedModel:
16
+ model: torch.nn.Module
17
+ device: torch.device
18
+ stage_i: int
19
+ embed_dim: int
20
+ T_w: object
21
+ T_f: object
22
+ T_e: object
23
+
24
+
25
+ def _pick_device(device: str) -> torch.device:
26
+ if device.strip().lower() == "cpu":
27
+ return torch.device("cpu")
28
+ if _ON_SPACES:
29
+ # ZeroGPU: can't init CUDA in main process
30
+ return torch.device("cpu")
31
+ if torch.cuda.is_available():
32
+ return torch.device("cuda")
33
+ return torch.device("cpu")
34
+
35
+
36
+ def load_style_model(
37
+ ckpt_path: str | Path,
38
+ *,
39
+ device: str = "auto",
40
+ ) -> LoadedModel:
41
+ """
42
+ Loads `train_style_ddp.TriViewStyleNet` from a checkpoint saved by `train_style_ddp.py`.
43
+ Returns the model and deterministic val transforms based on the checkpoint stage.
44
+ """
45
+ import train_style_ddp as ts
46
+
47
+ ckpt_path = Path(ckpt_path)
48
+ if not ckpt_path.exists():
49
+ raise FileNotFoundError(str(ckpt_path))
50
+
51
+ # On Spaces, always use CPU (ZeroGPU forbids CUDA in main process)
52
+ if _ON_SPACES:
53
+ dev = torch.device("cpu")
54
+ elif device == "auto":
55
+ dev = _pick_device("cuda" if torch.cuda.is_available() else "cpu")
56
+ else:
57
+ dev = _pick_device(device)
58
+
59
+ ck = torch.load(str(ckpt_path), map_location="cpu")
60
+ meta = ck.get("meta", {}) if isinstance(ck, dict) else {}
61
+ stage_i = int(meta.get("stage", 1))
62
+ stage_i = max(1, min(stage_i, len(ts.cfg.stages)))
63
+ stage = ts.cfg.stages[stage_i - 1]
64
+
65
+ T_w, T_f, T_e = ts.make_val_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"])
66
+
67
+ model = ts.TriViewStyleNet(out_dim=ts.cfg.embed_dim, mix_p=ts.cfg.mixstyle_p, share_backbone=True)
68
+ state = ck["model"] if isinstance(ck, dict) and "model" in ck else ck
69
+ model.load_state_dict(state, strict=False)
70
+ model.eval()
71
+ model = model.to(dev)
72
+ try:
73
+ model = model.to(memory_format=torch.channels_last)
74
+ except Exception:
75
+ pass
76
+
77
+ return LoadedModel(
78
+ model=model,
79
+ device=dev,
80
+ stage_i=stage_i,
81
+ embed_dim=int(ts.cfg.embed_dim),
82
+ T_w=T_w,
83
+ T_f=T_f,
84
+ T_e=T_e,
85
+ )
86
+
87
+
88
+ def embed_triview(
89
+ lm: LoadedModel,
90
+ *,
91
+ whole: Optional[torch.Tensor],
92
+ face: Optional[torch.Tensor],
93
+ eyes: Optional[torch.Tensor],
94
+ ) -> torch.Tensor:
95
+ """
96
+ Computes a single fused embedding for a triview sample.
97
+ Each view tensor must be CHW (already normalized) and will be batched.
98
+ Missing views can be None.
99
+ """
100
+ if whole is None and face is None and eyes is None:
101
+ raise ValueError("At least one of whole/face/eyes must be provided.")
102
+
103
+ views = {}
104
+ masks = {}
105
+ for k, v in (("whole", whole), ("face", face), ("eyes", eyes)):
106
+ if v is None:
107
+ views[k] = None
108
+ masks[k] = torch.zeros(1, dtype=torch.bool, device=lm.device)
109
+ else:
110
+ vb = v.unsqueeze(0).to(lm.device)
111
+ views[k] = vb
112
+ masks[k] = torch.ones(1, dtype=torch.bool, device=lm.device)
113
+
114
+ # Use lazy dtype detection to avoid CUDA init at import time (ZeroGPU compatibility)
115
+ import train_style_ddp as _ts
116
+ _dtype = _ts._get_amp_dtype() if hasattr(_ts, "_get_amp_dtype") else torch.float16
117
+ # On CPU or Spaces, skip autocast entirely to avoid touching CUDA
118
+ use_amp = (lm.device.type == "cuda") and not _ON_SPACES
119
+ if use_amp:
120
+ with torch.no_grad(), torch.amp.autocast("cuda", dtype=_dtype, enabled=True):
121
+ z, _, _ = lm.model(views, masks)
122
+ else:
123
+ with torch.no_grad():
124
+ z, _, _ = lm.model(views, masks)
125
+ z = torch.nn.functional.normalize(z.float(), dim=1)
126
+ return z.squeeze(0).detach().cpu()
127
+
128
+
webui_gradio.py CHANGED
@@ -258,7 +258,7 @@ def load_all(ckpt_path: str, proto_path: str, device: str) -> str:
258
  yolo_dir=ROOT / "yolov5_anime",
259
  weights=ROOT / "yolov5x_anime.pt",
260
  cascade=ROOT / "anime-eyes-cascade.xml",
261
- yolo_device=("0" if torch.cuda.is_available() else "cpu"),
262
  )
263
  APP_STATE.extractor = AnimeFaceEyeExtractor(cfg)
264
  except Exception:
 
258
  yolo_dir=ROOT / "yolov5_anime",
259
  weights=ROOT / "yolov5x_anime.pt",
260
  cascade=ROOT / "anime-eyes-cascade.xml",
261
+ yolo_device="cpu" if _ON_SPACES else ("0" if torch.cuda.is_available() else "cpu"),
262
  )
263
  APP_STATE.extractor = AnimeFaceEyeExtractor(cfg)
264
  except Exception: