iljung1106 committed

Commit c61411c · Parent(s): 07f1b5a

Disabled loading CUDA on main process
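The net effect of the change is that importing the training module (or starting the web UI) no longer initializes CUDA in the main process, which ZeroGPU Spaces forbid. A minimal smoke test for that property, not part of this commit and assuming a CUDA build of PyTorch, could be:

    import torch

    import train_style_ddp  # previously probed CUDA (is_available / is_bf16_supported) at import time

    # torch.cuda.is_initialized() reports whether a CUDA context was created in this
    # process; on a ZeroGPU Space the main process must keep this False.
    assert not torch.cuda.is_initialized(), "CUDA was initialized at import time"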
app/model_io.py CHANGED
@@ -100,7 +100,10 @@ def embed_triview(
         views[k] = vb
         masks[k] = torch.ones(1, dtype=torch.bool, device=lm.device)
 
-    with torch.no_grad(), torch.amp.autocast("cuda", dtype=getattr(__import__("train_style_ddp"), "amp_dtype", torch.float16), enabled=(lm.device.type == "cuda")):
+    # Use lazy dtype detection to avoid CUDA init at import time (ZeroGPU compatibility)
+    import train_style_ddp as _ts
+    _dtype = _ts._get_amp_dtype() if hasattr(_ts, "_get_amp_dtype") else torch.float16
+    with torch.no_grad(), torch.amp.autocast("cuda", dtype=_dtype, enabled=(lm.device.type == "cuda")):
         z, _, _ = lm.model(views, masks)
         z = torch.nn.functional.normalize(z.float(), dim=1)
     return z.squeeze(0).detach().cpu()
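The old one-liner read the autocast dtype from the eagerly computed module-level amp_dtype; the new code asks the lazy helper and falls back to float16 when an older copy of train_style_ddp lacks it. Note that on a CPU device (the forced path on Spaces, see webui_gradio.py below) the context is entered with enabled=False, so the dtype is never applied and CUDA is never touched. A small standalone sketch of that behavior, independent of this repo:

    import torch

    x = torch.randn(2, 4)          # CPU tensor
    lin = torch.nn.Linear(4, 3)    # CPU module

    # With enabled=False the "cuda" autocast context is a no-op, so this runs on a
    # CPU-only machine (or a ZeroGPU main process) without initializing CUDA.
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16, enabled=False):
        y = lin(x)

    print(y.dtype)  # torch.float32 -- nothing was autocast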
scripts/train_style_ddp.py CHANGED
@@ -74,10 +74,24 @@ torch.backends.cudnn.benchmark = True
 if hasattr(torch, "set_float32_matmul_precision"):
     torch.set_float32_matmul_precision("high")
 
-if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
-    amp_dtype = torch.bfloat16
-else:
-    amp_dtype = torch.float16
+# Lazy amp_dtype detection to avoid CUDA init at import time (required for HF Spaces ZeroGPU)
+_amp_dtype_cache = None
+
+def _get_amp_dtype():
+    global _amp_dtype_cache
+    if _amp_dtype_cache is None:
+        try:
+            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+                _amp_dtype_cache = torch.bfloat16
+            else:
+                _amp_dtype_cache = torch.float16
+        except Exception:
+            _amp_dtype_cache = torch.float16
+    return _amp_dtype_cache
+
+# For backwards compatibility, amp_dtype is accessed via property-like usage
+# but we keep a module-level name that can be imported (defaults to float16, updated on first GPU use)
+amp_dtype = torch.float16  # safe default; actual dtype picked at runtime via _get_amp_dtype()
 
 # --- PIL safety/verbosity tweaks ---
 ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -1056,7 +1070,7 @@ def ddp_train_worker(rank: int, world_size: int):
             }
             masks = {k: v.to(device, non_blocking=True) for k,v in batch["masks"].items()}
 
-            with torch.amp.autocast('cuda', dtype=amp_dtype):
+            with torch.amp.autocast('cuda', dtype=_get_amp_dtype()):
                 z_fused, z_views_dict, W = model(views, masks)
 
             Z_all, Y_all, G_all = [], [], []
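One caveat with the backwards-compatibility shim above: the module-level amp_dtype is now a static torch.float16 default and, at least in the hunks shown here, nothing ever updates it, so code that still imports the name will not see bfloat16 on capable GPUs. Callers that care about the actual dtype should go through _get_amp_dtype(), as the training loop now does. A hedged usage sketch:

    import torch
    import train_style_ddp as ts

    # Static default kept only so old imports keep working.
    print(ts.amp_dtype)          # torch.float16, regardless of hardware

    # Probes CUDA the first time it is called and caches the result,
    # so the import itself stays CUDA-free.
    print(ts._get_amp_dtype())   # torch.bfloat16 on bf16-capable GPUs, else torch.float16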
train_style_ddp.py CHANGED
@@ -74,10 +74,24 @@ torch.backends.cudnn.benchmark = True
 if hasattr(torch, "set_float32_matmul_precision"):
     torch.set_float32_matmul_precision("high")
 
-if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
-    amp_dtype = torch.bfloat16
-else:
-    amp_dtype = torch.float16
+# Lazy amp_dtype detection to avoid CUDA init at import time (required for HF Spaces ZeroGPU)
+_amp_dtype_cache = None
+
+def _get_amp_dtype():
+    global _amp_dtype_cache
+    if _amp_dtype_cache is None:
+        try:
+            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+                _amp_dtype_cache = torch.bfloat16
+            else:
+                _amp_dtype_cache = torch.float16
+        except Exception:
+            _amp_dtype_cache = torch.float16
+    return _amp_dtype_cache
+
+# For backwards compatibility, amp_dtype is accessed via property-like usage
+# but we keep a module-level name that can be imported (defaults to float16, updated on first GPU use)
+amp_dtype = torch.float16  # safe default; actual dtype picked at runtime via _get_amp_dtype()
 
 # --- PIL safety/verbosity tweaks ---
 ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -1056,7 +1070,7 @@ def ddp_train_worker(rank: int, world_size: int):
             }
             masks = {k: v.to(device, non_blocking=True) for k,v in batch["masks"].items()}
 
-            with torch.amp.autocast('cuda', dtype=amp_dtype):
+            with torch.amp.autocast('cuda', dtype=_get_amp_dtype()):
                 z_fused, z_views_dict, W = model(views, masks)
 
             Z_all, Y_all, G_all = [], [], []
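The repository-root copy of the file receives the identical change as scripts/train_style_ddp.py above. The explicit global cache is one way to write the probe-once pattern; a hypothetical equivalent using functools.lru_cache (not what the commit uses) would be:

    import functools

    import torch


    @functools.lru_cache(maxsize=1)
    def get_amp_dtype() -> torch.dtype:
        """Pick the autocast dtype the first time it is needed, never at import time."""
        try:
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                return torch.bfloat16
        except Exception:
            pass
        return torch.float16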
webui_gradio.py CHANGED
@@ -15,6 +15,9 @@ try:
 except Exception:  # noqa: BLE001
     spaces = None
 
+# Detect if running on HF Spaces (ZeroGPU requires special handling)
+_ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
+
 def _patch_fastapi_starlette_middleware_unpack() -> None:
     """
     Work around FastAPI/Starlette version mismatches where Starlette's Middleware
@@ -230,6 +233,11 @@ def load_all(ckpt_path: str, proto_path: str, device: str) -> str:
         return "❌ No checkpoint selected."
     if not proto_path:
         return "❌ No prototype DB selected."
+
+    # Force CPU on HF Spaces (ZeroGPU doesn't allow CUDA init in main process)
+    if _ON_SPACES:
+        device = "cpu"
+
     try:
         lm = load_style_model(ckpt_path, device=device)
         db = load_prototype_db(proto_path, try_dataset_dir=str(ROOT / "dataset"))
@@ -459,12 +467,12 @@ if __name__ == "__main__":
     _patch_fastapi_starlette_middleware_unpack()
 
     try:
-        _launch_compat(demo, server_name=args.host, server_port=args.port, show_api=False, share=args.share)
+        _launch_compat(demo, server_name=args.host, server_port=args.port, show_api=False, share=args.share, ssr_mode=False)
     except ValueError as e:
         # Some environments block localhost checks; fall back to share link.
         msg = str(e)
         if "localhost is not accessible" in msg and not args.share:
-            _launch_compat(demo, server_name=args.host, server_port=args.port, show_api=False, share=True)
+            _launch_compat(demo, server_name=args.host, server_port=args.port, show_api=False, share=True, ssr_mode=False)
         else:
             raise
 
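For completeness: on ZeroGPU the main process stays CUDA-free (hence the forced device = "cpu" above), and GPU work is normally delegated to functions wrapped with the spaces.GPU decorator that this file already imports defensively at the top. A minimal sketch of that pattern, with an illustrative function name not taken from this repo:

    try:
        import spaces  # provided on HF Spaces; exposes the ZeroGPU decorator
    except Exception:  # noqa: BLE001
        spaces = None


    def embed_on_gpu(image):
        # Inside a spaces.GPU-wrapped call a GPU is attached, so moving tensors or the
        # model to CUDA here is allowed even though the main process never initialized it.
        ...


    if spaces is not None:
        embed_on_gpu = spaces.GPU(embed_on_gpu)

The added ssr_mode=False simply opts out of Gradio's server-side rendering at launch time, a launch() option available in recent Gradio releases.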