from __future__ import annotations

import os
import shutil
import time
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import trimesh
import yaml
from easydict import EasyDict as edict
from PIL import Image
from scipy.spatial.transform import Rotation

import erayzer_core  # noqa: F401  # ensures vendored modules register themselves

@dataclass(frozen=True)
class EngineKey:
config_path: str
ckpt_path: str
device: str
def _ensure_file(path: str, label: str) -> None:
if not os.path.isfile(path):
raise FileNotFoundError(f"Missing {label}: {path}")
def _load_config(path: str) -> edict:
with open(path, "r", encoding="utf-8") as handle:
data = yaml.safe_load(handle)
return edict(data)
def add_scene_cam(scene, c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03):
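    """Add a camera frustum marker (a four-sided cone mesh) to a trimesh scene.

    c2w is a 4x4 camera-to-world matrix and edge_color an RGB triple for the
    marker color. Frustum dimensions are derived from image, imsize, or focal
    when provided, with fallback defaults otherwise.
    """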
OPENGL = np.array([
[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, 1]
])
if image is not None:
H, W, THREE = image.shape
assert THREE == 3
if image.dtype != np.uint8:
image = np.uint8(255*image)
elif imsize is not None:
W, H = imsize
elif focal is not None:
H = W = focal / 1.1
else:
H = W = 1
if focal is None:
focal = min(H, W) * 1.1 # default value
elif isinstance(focal, np.ndarray):
focal = focal[0]
# create fake camera
height = focal * screen_width / H
width = screen_width * 0.5**0.5
rot45 = np.eye(4)
rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
rot45[2, 3] = -height # set the tip of the cone = optical center
aspect_ratio = np.eye(4)
aspect_ratio[0, 0] = W/H
transform = c2w @ OPENGL @ aspect_ratio @ rot45
cam = trimesh.creation.cone(width, height, sections=4)
# this is the camera mesh
rot2 = np.eye(4)
rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(4)).as_matrix()
vertices = cam.vertices
vertices_offset = 0.9 * cam.vertices
vertices = np.r_[vertices, vertices_offset, geotrf(rot2, cam.vertices)]
vertices = geotrf(transform, vertices)
faces = []
for face in cam.faces:
if 0 in face:
continue
a, b, c = face
a2, b2, c2 = face + len(cam.vertices)
# add 3 pseudo-edges
faces.append((a, b, b2))
faces.append((a, a2, c))
faces.append((c2, b, c))
faces.append((a, b2, a2))
faces.append((a2, c, c2))
faces.append((c2, b2, b))
# no culling
faces += [(c, b, a) for a, b, c in faces]
    for i, face in enumerate(cam.faces):
if 0 in face:
continue
if i == 1 or i == 5:
a, b, c = face
faces.append((a, b, c))
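    # Negate y and z so the camera marker matches the flip applied to the
    # exported point cloud.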
vertices[:, [1, 2]] *= -1
cam = trimesh.Trimesh(vertices=vertices, faces=faces)
cam.visual.face_colors[:, :3] = edge_color
scene.add_geometry(cam)
def geotrf(Trf, pts, ncol=None, norm=False):
""" Apply a geometric transformation to a list of 3-D points.
H: 3x3 or 4x4 projection matrix (typically a Homography)
p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
ncol: int. number of columns of the result (2 or 3)
norm: float. if != 0, the resut is projected on the z=norm plane.
Returns an array of projected 2d points.
"""
assert Trf.ndim >= 2
if isinstance(Trf, np.ndarray):
pts = np.asarray(pts)
elif isinstance(Trf, torch.Tensor):
pts = torch.as_tensor(pts, dtype=Trf.dtype)
# adapt shape if necessary
output_reshape = pts.shape[:-1]
ncol = ncol or pts.shape[-1]
# optimized code
if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
Trf.ndim == 3 and pts.ndim == 4):
d = pts.shape[3]
if Trf.shape[-1] == d:
pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
elif Trf.shape[-1] == d+1:
pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
else:
raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
else:
if Trf.ndim >= 3:
n = Trf.ndim-2
assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
if pts.ndim > Trf.ndim:
# Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
elif pts.ndim == 2:
# Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
pts = pts[:, None, :]
if pts.shape[-1]+1 == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
elif pts.shape[-1] == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf
else:
pts = Trf @ pts.T
if pts.ndim >= 2:
pts = pts.swapaxes(-1, -2)
if norm:
pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
if norm != 1:
pts *= norm
res = pts[..., :ncol].reshape(*output_reshape, ncol)
return res
class ERayZerEngine:
"""Thin wrapper around the E-RayZer model for single-scene inference."""
def __init__(self, config_path: str, ckpt_path: str, device: str, output_root: str) -> None:
_ensure_file(config_path, "config")
_ensure_file(ckpt_path, "checkpoint")
os.makedirs(output_root, exist_ok=True)
self.output_root = output_root
self.device_name = device or "auto"
self.device = torch.device(self.device_name if self.device_name != "auto" else self._default_device())
self.config = _load_config(config_path)
self.ckpt_path = ckpt_path
self._prepare_config()
self.model = self._load_model()
self.model.eval()
training = self.config.training
tokenizer = self.config.model.image_tokenizer
self.image_size = int(tokenizer.image_size)
self.num_views = int(training.num_views)
self.num_input_views = int(training.num_input_views)
self.num_target_views = int(training.num_target_views)
def _central_crop(img: Image.Image) -> Image.Image:
shorter_side = min(img.size)
return TF.center_crop(img, shorter_side)
self.transform = T.Compose(
[
T.Lambda(_central_crop),
T.Resize((self.image_size, self.image_size), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
T.ToTensor(),
]
)
amp_dtype = str(training.get("amp_dtype", "fp16")).lower()
self.amp_dtype = torch.bfloat16 if amp_dtype == "bf16" else torch.float16
self.amp_enabled = bool(training.get("use_amp", True)) and self.device.type == "cuda"
def _prepare_config(self) -> None:
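        """Force the training config into single-scene inference mode.

        Batch size, workers, and view shuffling are pinned so one request maps
        to one deterministic forward pass, and the checkpoint path is wired in
        as a forced resume target.
        """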
cfg = self.config
cfg.inference = True
cfg.evaluation = False
cfg.create_visual = True
training = cfg.training
training.batch_size_per_gpu = 1
training.num_workers = 0
training.prefetch_factor = training.get("prefetch_factor", 2)
training.random_inputs = False
training.random_shuffle = False
training.force_resume_ckpt = True
training.resume_ckpt = self.ckpt_path
training.view_selector = edict(training.get("view_selector", {}))
training.view_selector.type = training.view_selector.get("type", "even_I_B")
training.view_selector.use_curriculum = False
cfg.inference_view_selector_type = cfg.get("inference_view_selector_type", training.view_selector.type)
def _load_model(self) -> torch.nn.Module:
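        """Import the model class named in the config and load the checkpoint.

        The state dict is loaded with strict=False so that key mismatches are
        reported as counts rather than raising.
        """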
module_name, class_name = self.config.model.class_name.rsplit(".", 1)
        ModelClass = getattr(__import__(module_name, fromlist=[class_name]), class_name)
model = ModelClass(self.config).to(self.device)
checkpoint = torch.load(self.ckpt_path, map_location=self.device)
state_dict = checkpoint.get("model", checkpoint)
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys:
print(f"[ERayZerEngine] Missing keys: {len(incompatible.missing_keys)}")
if incompatible.unexpected_keys:
print(f"[ERayZerEngine] Unexpected keys: {len(incompatible.unexpected_keys)}")
print("[ERayZerEngine] Model loaded successfully.")
return model
@staticmethod
def _default_device() -> str:
return "cuda:0" if torch.cuda.is_available() else "cpu"
def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
array = tensor.permute(1, 2, 0).cpu().numpy()
array = (array * 255.0).round().astype("uint8")
return Image.fromarray(array)
def _prepare_batch(self, image_files: Sequence[str]) -> Dict[str, torch.Tensor]:
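        """Load, center-crop, and resize the images into a (1, V, 3, H, W) batch."""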
        if len(image_files) != self.num_views:
            print(
                f"Warning: expected {self.num_views} views, but got {len(image_files)}; "
                "the batch is forwarded as-is, which may affect output quality."
            )
tensors: List[torch.Tensor] = []
for path in sorted(image_files, key=os.path.basename):
img = Image.open(path).convert("RGB")
tensors.append(self.transform(img))
images = torch.stack(tensors, dim=0).unsqueeze(0)
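        # Normalized pinhole intrinsics [fx, fy, cx, cy]: unit focal length and
        # principal point at the image center, replicated for every view.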
intrinsics = torch.tensor(
[[[1.0, 1.0, 0.5, 0.5]] * self.num_views], dtype=torch.float32
)
return {"image": images, "fxfycxcy": intrinsics}
def _move_to_device(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return {
key: value.to(self.device, non_blocking=self.device.type == "cuda") if torch.is_tensor(value) else value
for key, value in batch.items()
}
def run(
self, image_files: Sequence[str]
) -> Tuple[List[str], str, str, Optional[str], Optional[str]]:
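        """Run single-scene inference on the given image files.

        Returns (gallery_paths, archive_path, log_text, glb_path, video_path).
        """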
batch = self._prepare_batch(image_files)
batch_gpu = self._move_to_device(batch)
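        # Use autocast only on CUDA; on CPU fall back to a no-op context so the
        # same code path runs without mixed precision.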
autocast_ctx = (
torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.amp_enabled)
if self.device.type == "cuda"
else nullcontext()
)
with torch.no_grad():
with autocast_ctx:
result = self.model(batch_gpu)
run_dir, glb_path, video_path = self._export_outputs(result)
gallery_paths = sorted(
[os.path.join(run_dir, name) for name in os.listdir(run_dir) if name.startswith("pred_view_")]
)
archive = shutil.make_archive(run_dir, "zip", run_dir)
log = (
f"Saved {len(gallery_paths)} predicted views and Gaussian assets to {run_dir}.\n"
f"Archive: {archive}"
)
return gallery_paths, archive, log, glb_path, video_path
def _export_outputs(self, result) -> Tuple[str, Optional[str], Optional[str]]:
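        """Write predicted views, a point-cloud GLB with camera markers, and a
        rendered video into a timestamped run directory.

        Returns (run_dir, glb_path, video_path); the last two are None when the
        model result lacks the corresponding fields.
        """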
timestamp = time.strftime("%Y%m%d-%H%M%S")
run_dir = os.path.join(self.output_root, timestamp)
os.makedirs(run_dir, exist_ok=True)
glb_path: Optional[str] = None
video_path: Optional[str] = None
if getattr(result, "render") is not None:
render_tensor = result.render.detach().cpu().clamp(0, 1)
for idx, frame in enumerate(render_tensor[0]):
img = self._tensor_to_pil(frame)
img.save(os.path.join(run_dir, f"pred_view_{idx:02d}.png"))
        if getattr(result, "pixelalign_xyz", None) is not None:
glb_path = os.path.join(run_dir, "point_cloud.glb")
scene = trimesh.Scene()
xyzs = result.pixelalign_xyz[0].detach().cpu().permute(0, 2, 3, 1).reshape(-1, 3).numpy()
xyzs[:, [1, 2]] *= -1
rgbs = (result.image[0].detach().cpu().permute(0, 2, 3, 1).reshape(-1, 3) * 255.0).round().numpy().astype(np.uint8)
point_cloud = trimesh.points.PointCloud(vertices=xyzs, colors=rgbs)
scene.add_geometry(point_cloud)
c2ws = result.c2w[0].detach().cpu().numpy()
num_images = c2ws.shape[0]
cmap = plt.get_cmap("hsv")
for i, c2w in enumerate(c2ws):
color_rgb = (np.array(cmap(i / num_images))[:3] * 255).astype(int)
add_scene_cam(
scene=scene,
c2w=c2w,
edge_color=color_rgb,
image=None,
focal=None,
imsize=(256, 256),
screen_width=0.1
)
scene.export(glb_path)
if getattr(result, "render_video") is not None:
frames_dir = os.path.join(run_dir, "render_video_frames")
os.makedirs(frames_dir, exist_ok=True)
frames = result.render_video[0].detach().cpu().clamp(0, 1)
for idx, frame in enumerate(frames):
img = self._tensor_to_pil(frame)
img.save(os.path.join(frames_dir, f"frame_{idx:03d}.png"))
frames_np = (frames.permute(0, 2, 3, 1).numpy() * 255.0).round().astype(np.uint8)
video_path = os.path.join(run_dir, "render_video.mp4")
imageio.mimwrite(video_path, frames_np, fps=24)
return run_dir, glb_path, video_path
_ENGINE_CACHE: Dict[EngineKey, ERayZerEngine] = {}
def get_engine(config_path: str, ckpt_path: str, device: str, output_root: str) -> ERayZerEngine:
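    """Return a cached engine for the (config, checkpoint, device) triple.

    Note: output_root is not part of the cache key, so a cached engine keeps
    the output_root it was first constructed with.
    """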
key = EngineKey(config_path, ckpt_path, device or "auto")
if key not in _ENGINE_CACHE:
_ENGINE_CACHE[key] = ERayZerEngine(config_path, ckpt_path, device, output_root)
return _ENGINE_CACHE[key]
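
# Example usage (hypothetical paths; adjust to your deployment):
#   import glob
#   engine = get_engine("configs/erayzer.yaml", "checkpoints/erayzer.pt", "auto", "outputs")
#   gallery, archive, log, glb, video = engine.run(sorted(glob.glob("examples/*.png")))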