Spaces:
Runtime error
Runtime error
| import os | |
| import subprocess | |
| import sys | |
| import shutil | |
# ============ Configuration ============
# Upstream Git repository that is cloned during the install step.
REPO_URL = "https://github.com/facebookresearch/sam-3d-objects.git"
# Local checkout directory; it is also put on sys.path further down the file.
REPO_DIR = "/home/user/app/sam-3d-objects"
| # ============ Install Dependencies & Setup ============ | |
def patch_pyproject_toml(repo_dir=None):
    """
    Remove the quoted ``bpy`` requirement from the repo's pyproject.toml.

    The entry is matched with any (or no) version specifier and with either
    quote style, so the patch keeps working if the upstream pin changes.

    Args:
        repo_dir: Directory containing ``pyproject.toml``. Defaults to the
            module-level ``REPO_DIR``.

    Returns:
        True if the file was rewritten, False if the file is missing or it
        contained no ``bpy`` entry.
    """
    import re

    if repo_dir is None:
        repo_dir = REPO_DIR

    print("Patching pyproject.toml to remove strict bpy dependency...")
    pyproject_path = os.path.join(repo_dir, "pyproject.toml")

    try:
        with open(pyproject_path, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        print(f"Warning: {pyproject_path} not found. Skipping patch.")
        return False

    # A quoted requirement that is exactly `bpy` plus an optional version
    # specifier, followed by an optional trailing comma. Names that merely
    # start with "bpy" (e.g. "bpy-foo") do not match.
    bpy_entry = re.compile(r"""(["'])bpy\s*(?:[=<>!~]=?[^"']*)?\1\s*,?""")
    new_content, removed = bpy_entry.subn("", content)

    if removed == 0:
        # Do not claim success (or touch the file) when nothing matched.
        print("No bpy dependency entry found in pyproject.toml. Nothing to patch.")
        return False

    with open(pyproject_path, "w", encoding="utf-8") as f:
        f.write(new_content)
    print("Patch applied successfully.")
    return True
def _pip_install(pkg, env, extra_args=()):
    """Install one requirement with pip; return True on success, False on failure."""
    banner = "----------------------------------------"
    label = "Installing (no-build-isolation)" if "--no-build-isolation" in extra_args else "Installing"
    print(banner)
    print(f"{label}: {pkg}")
    print(banner)
    cmd = [sys.executable, "-m", "pip", "install", *extra_args, pkg]
    try:
        subprocess.run(cmd, env=env, check=True)
    except subprocess.CalledProcessError:
        # Best effort by design: one broken wheel must not abort the whole setup.
        print(f"❌ Failed to install {pkg}. Continuing to next package...")
        return False
    return True


def install_dependencies():
    """
    Install the runtime dependencies one by one, then clone and install the repo.

    Steps, in order:
      1. Upgrade pip.
      2. ``pip install`` every entry of the package list individually; a
         failing package is reported and skipped.
      3. Install the packages that must see torch at build time with
         ``--no-build-isolation``.
      4. Clone ``REPO_URL`` into ``REPO_DIR`` if it is not there yet, change
         the working directory to it, strip the bpy pin from pyproject.toml
         and install the repo in editable mode without dependencies.
      5. Run the repo's ``patching/hydra`` script if it exists.

    Returns:
        The list of requirement strings whose installation failed (empty when
        everything installed).

    Raises:
        subprocess.CalledProcessError: if the pip upgrade, the clone, the
            editable install or the Hydra patch script fails.
    """
    print("Starting manual installation sequence...")

    # 1. Basic PIP Upgrade
    env = os.environ.copy()
    # pip accepts several space-separated URLs in this variable.
    env["PIP_EXTRA_INDEX_URL"] = "https://pypi.ngc.nvidia.com https://download.pytorch.org/whl/cu121"
    env["PIP_FIND_LINKS"] = "https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html"
    env["CUDA_HOME"] = "/usr/local/cuda"
    env["MAX_JOBS"] = "4"
    subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "pip"], env=env, check=True)

    # 2. User Defined Package List
    packages = [
        # MUST INSTALL FIRST: PyTorch (required for building flash_attn, pytorch3d, etc.)
        "torch", "torchvision",
        # Core ML & Hydra
        "hydra-core", "hydra-submitit-launcher", "omegaconf", "numpy", "einops",
        "einops-exts", "timm", "diffusers", "transformers", "accelerate", "safetensors",
        # Testing & Dev Tools
        "pytest", "pipdeptree", "findpydeps", "lovely_tensors", "autoflake",
        "black", "flake8", "usort",
        # Visualization & UI
        "seaborn", "gradio", "tensorboard", "wandb", "polyscope",
        # 3D & Graphics
        "open3d", "pyrender", "point-cloud-utils",
        "pymeshfix", "xatlas", "panda3d-gltf", "fvcore", "roma", "smplx",
        "OpenEXR", "imath",
        # Video & Audio
        "av", "decord", "librosa",
        # CUDA & GPU
        "cuda-python", "nvidia-cuda-nvcc-cu12", "nvidia-pyindex", "spconv-cu121",
        "xformers", "torchaudio",
        # ML Optimization
        "auto_gptq", "bitsandbytes", "peft", "optimum", "optree", "lightning",
        "sentence-transformers",
        # Data & Serialization
        "h5py", "fastavro", "jsonlines", "jsonpickle", "orjson", "simplejson",
        "webdataset",
        # Image Processing
        "opencv-python", "scikit-image", "pycocotools", "ftfy",
        # Web & Networking
        "Flask", "Werkzeug", "hdfs", "httplib2", "PySocks", "gdown",
        # Utilities
        "astor", "async-timeout", "colorama", "deprecation", "easydict",
        "exceptiongroup", "fasteners", "loguru", "objsize", "randomname",
        "rootutils", "Rtree", "tomli",
        # JSON Schema & URI
        "fqdn", "isoduration", "jsonpointer", "uri-template", "webcolors",
        # Graph & Docs
        "igraph", "pydot", "pdoc3",
        # Jupyter & Misc
        "jupyter", "dataclasses", "crcmod", "conda-pack", "pip-system-certs",
        "python-pycg", "pymongo", "sagemaker", "mosaicml-streaming", "bpy",
        # Git installs (pytorch3d requires torch, so it's near the end)
        "git+https://github.com/nerfstudio-project/gsplat.git",
        "git+https://github.com/facebookresearch/pytorch3d.git",
        "git+https://github.com/microsoft/MoGe.git",
        "utils3d",
    ]

    # Packages that require special handling - install with --no-build-isolation
    # so they can find torch during build
    special_packages = [
        "flash_attn",
        "kaolin",
    ]

    failed = [pkg for pkg in packages if not _pip_install(pkg, env)]

    # Install packages that need --no-build-isolation (require torch at build time)
    failed.extend(
        pkg
        for pkg in special_packages
        if not _pip_install(pkg, env, extra_args=("--no-build-isolation",))
    )

    # Surface the skipped packages in one place instead of leaving them
    # scattered through the pip output.
    if failed:
        print(f"Packages that failed to install ({len(failed)}): {', '.join(failed)}")

    # 3. Clone & Install Main Repo (SAM 3D Objects)
    if not os.path.exists(REPO_DIR):
        print(f"Cloning repository to {REPO_DIR}...")
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

    # The editable install below uses "." and therefore relies on this chdir.
    os.chdir(REPO_DIR)
    patch_pyproject_toml()

    print("Installing sam-3d-objects in editable mode...")
    subprocess.run([sys.executable, "-m", "pip", "install", "--no-deps", "-e", "."], env=env, check=True)

    # 4. Apply Hydra Patch
    patch_script = os.path.join(REPO_DIR, "patching", "hydra")
    if os.path.exists(patch_script):
        print("Applying Hydra patch...")
        subprocess.run(["chmod", "+x", patch_script], check=True)
        subprocess.run([patch_script], check=True)

    return failed
# Run installation at import time: the third-party imports further down
# (gradio, torch, omegaconf, hydra, ...) are among the packages it installs.
install_dependencies()

# Add repo to Python path
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# Set environment variables required for runtime
os.environ.update(
    {
        "CUDA_HOME": "/usr/local/cuda",
        "LIDRA_SKIP_INIT": "true",
        "PYTORCH3D_NO_CUDA_CHECK": "1",
    }
)
| # ============ Imports ============ | |
| import spaces | |
| import builtins | |
| from typing import Optional, List, Callable | |
| from copy import deepcopy | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import math | |
| from omegaconf import OmegaConf, DictConfig, ListConfig | |
| from hydra.utils import instantiate, get_method | |
# Lazy imports placeholder: state filled in on first use by the loader functions.
_sam3d_imported = False
_pipeline = None


# ============ Security / Config Filters ============
def _is_whitelisted_package(target):
    """Return True when the first dotted component of *target* is an allowed package."""
    top_level, _, _ = target.partition(".")
    return top_level in {"sam3d_objects", "torch", "torchvision", "moge"}


def _resolves_to_forbidden_callable(target):
    """Return True when *target* resolves (via hydra's get_method) to a blocked callable."""
    # Resolve first, then build the set, matching the evaluation order of the
    # original `get_method(target) in {...}` expression.
    resolved = get_method(target)
    forbidden = {
        builtins.exec, builtins.eval, builtins.__import__,
        os.kill, os.system, os.putenv, os.remove, os.removedirs,
        os.rmdir, os.fchdir, os.setuid, os.fork, os.forkpty,
        os.killpg, os.rename, os.renames, os.truncate, os.replace,
        os.unlink, os.fchmod, os.fchown, os.chmod, os.chown,
        os.chroot, os.lchown, os.getcwd, os.chdir,
        shutil.rmtree, shutil.move, shutil.chown,
        subprocess.Popen, builtins.help,
    }
    return resolved in forbidden


# A target is accepted when at least one whitelist filter matches...
WHITELIST_FILTERS = [
    _is_whitelisted_package,
]
# ...and no blacklist filter matches.
BLACKLIST_FILTERS = [
    _resolves_to_forbidden_callable,
]
def check_target(target: str, whitelist_filters: List[Callable], blacklist_filters: List[Callable]):
    """
    Validate a single ``_target_`` string.

    The target passes when at least one whitelist filter accepts it and no
    blacklist filter flags it. Blacklist filters are only consulted for
    targets that already passed the whitelist.

    Raises:
        RuntimeError: if the target is not allowed.
    """
    whitelisted = any(accept(target) for accept in whitelist_filters)
    if whitelisted and not any(reject(target) for reject in blacklist_filters):
        return
    raise RuntimeError(f"target '{target}' is not allowed")
def check_hydra_safety(config: DictConfig, whitelist_filters: List[Callable], blacklist_filters: List[Callable]):
    """
    Walk a config tree and validate every ``_target_`` entry found in it.

    Dict and list nodes are visited iteratively with an explicit stack; other
    values are ignored. Each ``_target_`` is passed to ``check_target``, which
    raises ``RuntimeError`` for a disallowed target.
    """
    pending = [config]
    while pending:
        current = pending.pop()
        if isinstance(current, ListConfig):
            pending.extend(list(current))
            continue
        if not isinstance(current, DictConfig):
            # Scalars and anything else carry no nested targets.
            continue
        pending.extend(list(current.values()))
        if "_target_" in current:
            check_target(current["_target_"], whitelist_filters, blacklist_filters)
| # ============ Lazy Loading & Model Logic ============ | |
def lazy_import_sam3d():
    """
    Import the SAM 3D modules on first call and publish them as module globals.

    Later calls are no-ops once the imports have succeeded. On ImportError the
    error is printed, ``pipdeptree`` is run to dump the installed packages, and
    the exception is re-raised.
    """
    global _sam3d_imported
    global utils3d, sam3d_objects, InferencePipelinePointMap, render_utils, SceneVisualizer
    global quaternion_multiply, quaternion_invert

    if _sam3d_imported:
        return

    try:
        # The names are declared global above, so each import binds the
        # module-level name directly.
        import utils3d
        import sam3d_objects
        from sam3d_objects.pipeline.inference_pipeline_pointmap import InferencePipelinePointMap
        from sam3d_objects.model.backbone.tdfy_dit.utils import render_utils
        from sam3d_objects.utils.visualization import SceneVisualizer
        from pytorch3d.transforms import quaternion_multiply, quaternion_invert
    except ImportError as e:
        print(f"Failed to import SAM 3D modules: {e}")
        print("Ensure the installation step completed successfully.")
        subprocess.run([sys.executable, "-m", "pipdeptree"])
        raise

    _sam3d_imported = True
def load_pipeline(config_file: str):
    """
    Return the cached inference pipeline, building it on the first call.

    On first use the SAM 3D modules are imported, the config is loaded from
    *config_file*, three fields are overridden (rendering engine, model
    compilation, workspace directory), every ``_target_`` is safety-checked,
    and the config is instantiated. The result is stored in the module-level
    ``_pipeline`` and returned on every later call regardless of *config_file*.
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    lazy_import_sam3d()
    cfg = OmegaConf.load(config_file)
    cfg.rendering_engine = "pytorch3d"
    cfg.compile_model = False
    cfg.workspace_dir = os.path.dirname(config_file)
    # Refuse to instantiate anything outside the allowed targets.
    check_hydra_safety(cfg, WHITELIST_FILTERS, BLACKLIST_FILTERS)
    _pipeline = instantiate(cfg)
    return _pipeline
def merge_mask_to_rgba(image, mask):
    """
    Append *mask* to *image* as a fourth (alpha) channel.

    The mask is cast to uint8 and multiplied by 255, then concatenated after
    the first three channels of *image* along the last axis.
    Assumes image is (H, W, C>=3) and mask is a boolean (H, W) array; TODO confirm against callers.
    """
    alpha = (mask.astype(np.uint8) * 255)[..., None]
    rgb = image[..., :3]
    return np.concatenate((rgb, alpha), axis=-1)
def run_inference(image: np.ndarray, mask: np.ndarray, config_file: str, seed: Optional[int] = None, pointmap=None) -> dict:
    """
    Run the SAM 3D pipeline on one image/mask pair and return its result.

    The pipeline is obtained from ``load_pipeline`` (built on first use), moved
    to CUDA when it exposes a ``to`` method, and called with the RGBA image
    produced by ``merge_mask_to_rgba``.

    NOTE(review): the previous docstring called this a "GPU-decorated" ZeroGPU
    function, but no decorator is applied to it in this file. Confirm whether
    a ``spaces.GPU`` decorator is required for the deployment target.
    """
    global _pipeline
    _pipeline = load_pipeline(config_file)
    if hasattr(_pipeline, 'to'):
        _pipeline.to('cuda')

    # Fixed run options; only the seed and the pointmap vary per call.
    run_options = {
        "stage1_only": False,
        "with_mesh_postprocess": False,
        "with_texture_baking": False,
        "with_layout_postprocess": True,
        "use_vertex_color": True,
        "stage1_inference_steps": None,
        "pointmap": pointmap,
    }
    rgba_image = merge_mask_to_rgba(image, mask)
    return _pipeline.run(rgba_image, None, seed, **run_options)
# ============ Gradio Interface ============
# Inference config read from the cloned repository checkout.
CONFIG_FILE = os.path.join(REPO_DIR, "configs/inference.yaml")
def process_image(input_image, input_mask, seed):
    """
    Gradio callback: run inference and return ``(ply_path, status_message)``.

    Args:
        input_image: Image from the UI (the component is configured with
            ``type="pil"``), or None when nothing was uploaded.
        input_mask: Mask image from the UI; it is converted to grayscale and
            thresholded at 127 to obtain a boolean mask.
        seed: Seed from the UI; None or an empty value means "no seed".

    Returns:
        A tuple of the path to the written ``.ply`` file (or None) and a
        human-readable status string. Inference errors are caught, printed
        with their traceback, and reported through the status string.
    """
    import tempfile
    import traceback

    if input_image is None:
        return None, "Please provide an input image"
    if input_mask is None:
        return None, "Please provide an object mask"

    image = np.array(input_image)
    mask = np.array(input_mask.convert("L")) > 127
    # A plain truthiness test would silently turn the valid seed 0 into None.
    seed_val = None if seed in (None, "") else int(seed)

    try:
        result = run_inference(image, mask, CONFIG_FILE, seed_val)
        if "gaussian" in result and result["gaussian"]:
            # Use a unique file per request so concurrent users do not
            # overwrite each other's output.
            with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as tmp:
                ply_path = tmp.name
            result["gaussian"][0].save_ply(ply_path)
            return ply_path, "✅ Inference complete!"
        return None, "⚠️ No 3D output generated"
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
# Two-column UI: inputs and the run button on the left, the 3D viewer and
# status text on the right.
with gr.Blocks(title="SAM 3D Objects", theme=gr.themes.Soft()) as demo:
    # Built from explicit lines so the rendered Markdown does not depend on
    # how this source file is indented.
    gr.Markdown(
        "\n"
        "# 🎨 SAM 3D Objects - Single Image to 3D\n"
        "Upload an image and a mask to generate a 3D Gaussian Splat model.\n"
        "**Note:** First inference may take longer due to model loading.\n"
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="pil")
            input_mask = gr.Image(label="Object Mask (white = object)", type="pil")
            seed = gr.Number(label="Seed (optional)", value=42, precision=0)
            run_btn = gr.Button("🚀 Generate 3D", variant="primary", size="lg")
        with gr.Column(scale=1):
            output_model = gr.Model3D(label="3D Output", clear_color=[0.1, 0.1, 0.1, 1.0])
            status = gr.Textbox(label="Status", interactive=False)
    run_btn.click(fn=process_image, inputs=[input_image, input_mask, seed], outputs=[output_model, status])

if __name__ == "__main__":
    demo.launch()