Spaces · Running on Zero

iljung1106 committed · Commit c61411c · 1 Parent(s): 07f1b5a

Disabled loading CUDA on main process

Browse files:
- app/model_io.py  +4 -1
- scripts/train_style_ddp.py  +19 -5
- train_style_ddp.py  +19 -5
- webui_gradio.py  +10 -2
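
The commit message and the in-diff comments all revolve around one constraint: on Hugging Face ZeroGPU Spaces, the main (Gradio) process must not initialize CUDA; GPU work only runs inside functions handed off to ZeroGPU. The snippet below is a minimal illustrative sketch of that pattern, not code from this repository; it assumes the standard `spaces` package (which webui_gradio.py already imports behind a try/except) and a hypothetical `gpu_inference` function.

import torch
import spaces  # available on HF Spaces; webui_gradio.py guards this import with try/except

@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call
def gpu_inference(x: torch.Tensor) -> torch.Tensor:
    # CUDA is first touched here, in the ZeroGPU worker, never at import time
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
        return (x.cuda() * 2.0).cpu()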
app/model_io.py CHANGED

@@ -100,7 +100,10 @@ def embed_triview(
         views[k] = vb
         masks[k] = torch.ones(1, dtype=torch.bool, device=lm.device)

-
+    # Use lazy dtype detection to avoid CUDA init at import time (ZeroGPU compatibility)
+    import train_style_ddp as _ts
+    _dtype = _ts._get_amp_dtype() if hasattr(_ts, "_get_amp_dtype") else torch.float16
+    with torch.no_grad(), torch.amp.autocast("cuda", dtype=_dtype, enabled=(lm.device.type == "cuda")):
         z, _, _ = lm.model(views, masks)
     z = torch.nn.functional.normalize(z.float(), dim=1)
     return z.squeeze(0).detach().cpu()
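
The `enabled=` guard above matters because the same embedding code now has to run on CPU in the Spaces main process. A standalone sketch of that guard, using a dummy `torch.nn.Linear` in place of the repo's `lm.model` (an assumption for illustration only):

import torch

model = torch.nn.Linear(8, 4)                    # stand-in for lm.model
device = next(model.parameters()).device

use_cuda = device.type == "cuda"
dtype = torch.bfloat16 if (use_cuda and torch.cuda.is_bf16_supported()) else torch.float16

x = torch.randn(1, 8, device=device)
# autocast is requested for "cuda" but disabled on CPU, so no CUDA context is ever created
with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype, enabled=use_cuda):
    z = model(x)
z = torch.nn.functional.normalize(z.float(), dim=1)  # same post-processing as embed_triview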
scripts/train_style_ddp.py CHANGED

@@ -74,10 +74,24 @@ torch.backends.cudnn.benchmark = True
 if hasattr(torch, "set_float32_matmul_precision"):
     torch.set_float32_matmul_precision("high")

-
-
-
-
+# Lazy amp_dtype detection to avoid CUDA init at import time (required for HF Spaces ZeroGPU)
+_amp_dtype_cache = None
+
+def _get_amp_dtype():
+    global _amp_dtype_cache
+    if _amp_dtype_cache is None:
+        try:
+            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+                _amp_dtype_cache = torch.bfloat16
+            else:
+                _amp_dtype_cache = torch.float16
+        except Exception:
+            _amp_dtype_cache = torch.float16
+    return _amp_dtype_cache
+
+# For backwards compatibility, amp_dtype is accessed via property-like usage
+# but we keep a module-level name that can be imported (defaults to float16, updated on first GPU use)
+amp_dtype = torch.float16  # safe default; actual dtype picked at runtime via _get_amp_dtype()

 # --- PIL safety/verbosity tweaks ---
 ImageFile.LOAD_TRUNCATED_IMAGES = True

@@ -1056,7 +1070,7 @@ def ddp_train_worker(rank: int, world_size: int):
             }
             masks = {k: v.to(device, non_blocking=True) for k,v in batch["masks"].items()}

-            with torch.amp.autocast('cuda', dtype=
+            with torch.amp.autocast('cuda', dtype=_get_amp_dtype()):
                 z_fused, z_views_dict, W = model(views, masks)

             Z_all, Y_all, G_all = [], [], []
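
Taken together, the two hunks move the dtype decision from import time to call time: the module defines `_get_amp_dtype()`, and the training loop queries it only inside the autocast context. A rough sketch of how a caller consumes it (the `run_step` helper is hypothetical; the import mirrors what app/model_io.py does):

import torch
import train_style_ddp as ts   # importing the module no longer touches torch.cuda.*

def run_step(model, views, masks):
    # dtype is resolved on first call and cached, inside whichever process owns the GPU
    with torch.amp.autocast("cuda", dtype=ts._get_amp_dtype()):
        return model(views, masks)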
train_style_ddp.py CHANGED

@@ -74,10 +74,24 @@ torch.backends.cudnn.benchmark = True
 if hasattr(torch, "set_float32_matmul_precision"):
     torch.set_float32_matmul_precision("high")

-
-
-
-
+# Lazy amp_dtype detection to avoid CUDA init at import time (required for HF Spaces ZeroGPU)
+_amp_dtype_cache = None
+
+def _get_amp_dtype():
+    global _amp_dtype_cache
+    if _amp_dtype_cache is None:
+        try:
+            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+                _amp_dtype_cache = torch.bfloat16
+            else:
+                _amp_dtype_cache = torch.float16
+        except Exception:
+            _amp_dtype_cache = torch.float16
+    return _amp_dtype_cache
+
+# For backwards compatibility, amp_dtype is accessed via property-like usage
+# but we keep a module-level name that can be imported (defaults to float16, updated on first GPU use)
+amp_dtype = torch.float16  # safe default; actual dtype picked at runtime via _get_amp_dtype()

 # --- PIL safety/verbosity tweaks ---
 ImageFile.LOAD_TRUNCATED_IMAGES = True

@@ -1056,7 +1070,7 @@ def ddp_train_worker(rank: int, world_size: int):
             }
             masks = {k: v.to(device, non_blocking=True) for k,v in batch["masks"].items()}

-            with torch.amp.autocast('cuda', dtype=
+            with torch.amp.autocast('cuda', dtype=_get_amp_dtype()):
                 z_fused, z_views_dict, W = model(views, masks)

             Z_all, Y_all, G_all = [], [], []
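
The backwards-compatibility comment in both copies of the file keeps `amp_dtype` as a plain module attribute frozen at `torch.float16`; a plain attribute cannot update itself on first GPU use. If resolved-on-access behavior were ever wanted, one option (not used in this commit, shown only as a design alternative) is a PEP 562 module-level `__getattr__`:

import torch

_amp_dtype_cache = None

def _get_amp_dtype():
    global _amp_dtype_cache
    if _amp_dtype_cache is None:
        try:
            bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        except Exception:
            bf16 = False
        _amp_dtype_cache = torch.bfloat16 if bf16 else torch.float16
    return _amp_dtype_cache

def __getattr__(name):             # PEP 562: only consulted for names not defined in the module
    if name == "amp_dtype":
        return _get_amp_dtype()    # resolved lazily, still no CUDA work at import time
    raise AttributeError(name)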
webui_gradio.py CHANGED

@@ -15,6 +15,9 @@ try:
 except Exception: # noqa: BLE001
     spaces = None

+# Detect if running on HF Spaces (ZeroGPU requires special handling)
+_ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
+
 def _patch_fastapi_starlette_middleware_unpack() -> None:
     """
     Work around FastAPI/Starlette version mismatches where Starlette's Middleware

@@ -230,6 +233,11 @@ def load_all(ckpt_path: str, proto_path: str, device: str) -> str:
         return "❌ No checkpoint selected."
     if not proto_path:
         return "❌ No prototype DB selected."
+
+    # Force CPU on HF Spaces (ZeroGPU doesn't allow CUDA init in main process)
+    if _ON_SPACES:
+        device = "cpu"
+
     try:
         lm = load_style_model(ckpt_path, device=device)
         db = load_prototype_db(proto_path, try_dataset_dir=str(ROOT / "dataset"))

@@ -459,12 +467,12 @@ if __name__ == "__main__":
     _patch_fastapi_starlette_middleware_unpack()

     try:
-        _launch_compat(demo, server_name=args.host, server_port=args.port, show_api=False, share=args.share)
+        _launch_compat(demo, server_name=args.host, server_port=args.port, show_api=False, share=args.share, ssr_mode=False)
     except ValueError as e:
         # Some environments block localhost checks; fall back to share link.
         msg = str(e)
         if "localhost is not accessible" in msg and not args.share:
-            _launch_compat(demo, server_name=args.host, server_port=args.port, show_api=False, share=True)
+            _launch_compat(demo, server_name=args.host, server_port=args.port, show_api=False, share=True, ssr_mode=False)
         else:
             raise
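
For reference, the three webui_gradio.py changes compose into a small pattern that can be sketched in isolation: env-var detection of Spaces, forcing the main process to CPU, and passing `ssr_mode=False` (a Gradio 5 launch option that disables server-side rendering). The demo app below is hypothetical, not this repo's UI, and `load_all` here is a stub:

import os
import gradio as gr

# Same detection the diff adds: SPACE_ID is set inside HF Spaces containers (HF_SPACE is checked as well)
_ON_SPACES = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))

def load_all(ckpt_path: str, device: str) -> str:
    if _ON_SPACES:
        device = "cpu"   # ZeroGPU: never initialize CUDA in the main process
    return f"would load {ckpt_path} on {device}"

demo = gr.Interface(fn=load_all, inputs=["text", "text"], outputs="text")

if __name__ == "__main__":
    demo.launch(show_api=False, ssr_mode=False)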