Spaces:
Runtime error
Runtime error
moved resizing into 3D model code, out of depth generation, to clean up architecture
Browse files
app.py
CHANGED
|
@@ -60,7 +60,7 @@ def generate_3d_model(depth, image_path, focallength_px):
|
|
| 60 |
|
| 61 |
Args:
|
| 62 |
depth (np.ndarray): 2D array representing depth in meters.
|
| 63 |
-
image_path (str): Path to the
|
| 64 |
focallength_px (float): Focal length in pixels.
|
| 65 |
|
| 66 |
Returns:
|
|
@@ -68,8 +68,16 @@ def generate_3d_model(depth, image_path, focallength_px):
|
|
| 68 |
"""
|
| 69 |
# Load the RGB image and convert to a NumPy array
|
| 70 |
image = np.array(Image.open(image_path))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
height, width = depth.shape
|
| 72 |
|
|
|
|
|
|
|
|
|
|
| 73 |
# Compute camera intrinsic parameters
|
| 74 |
fx = fy = focallength_px # Assuming square pixels and fx = fy
|
| 75 |
cx, cy = width / 2, height / 2 # Principal point at the image center
|
|
@@ -126,17 +134,13 @@ def predict_depth(input_image):
|
|
| 126 |
# Preprocess the image for depth prediction
|
| 127 |
result = depth_pro.load_rgb(temp_file)
|
| 128 |
|
| 129 |
-
# Add error checking for the result tuple
|
| 130 |
if len(result) < 2:
|
| 131 |
raise ValueError(f"Unexpected result from load_rgb: {result}")
|
| 132 |
|
| 133 |
-
image
|
| 134 |
-
f_px = result[-1] # Extract focal length
|
| 135 |
-
|
| 136 |
print(f"Extracted focal length: {f_px}")
|
| 137 |
|
| 138 |
-
image = transform(image)
|
| 139 |
-
image = image.to(device) # Move the image tensor to the selected device
|
| 140 |
|
| 141 |
# Run the depth prediction model
|
| 142 |
prediction = model.infer(image, f_px=f_px)
|
|
@@ -151,33 +155,13 @@ def predict_depth(input_image):
|
|
| 151 |
if depth.ndim != 2:
|
| 152 |
depth = depth.squeeze()
|
| 153 |
|
| 154 |
-
|
| 155 |
-
print(f"Original depth shape: {depth.shape}")
|
| 156 |
-
print(f"Original image shape: {image.shape}")
|
| 157 |
-
|
| 158 |
-
# Resize depth to match image dimensions
|
| 159 |
-
image_height, image_width = image.shape[2], image.shape[3]
|
| 160 |
-
depth = cv2.resize(depth, (image_width, image_height), interpolation=cv2.INTER_LINEAR)
|
| 161 |
-
|
| 162 |
-
print(f"Resized depth shape: {depth.shape}")
|
| 163 |
-
print(f"Final image shape: {image.shape}")
|
| 164 |
-
|
| 165 |
-
# No downsampling
|
| 166 |
-
downscale_factor = 1
|
| 167 |
-
|
| 168 |
-
# Convert image tensor to CPU and NumPy
|
| 169 |
-
image_np = image.cpu().detach().numpy()[0].transpose(1, 2, 0)
|
| 170 |
-
|
| 171 |
-
# No normalization of depth map as it is already in meters
|
| 172 |
-
depth_min = np.min(depth)
|
| 173 |
-
depth_max = np.max(depth)
|
| 174 |
-
depth_normalized = depth # Depth remains in meters
|
| 175 |
|
| 176 |
# Create a color map for visualization using matplotlib
|
| 177 |
plt.figure(figsize=(10, 10))
|
| 178 |
-
plt.imshow(
|
| 179 |
plt.colorbar(label='Depth [m]')
|
| 180 |
-
plt.title(f'Predicted Depth Map - Min: {
|
| 181 |
plt.axis('off') # Hide axis for a cleaner image
|
| 182 |
|
| 183 |
# Save the depth map visualization to a file
|
|
@@ -208,8 +192,9 @@ def get_last_commit_timestamp():
|
|
| 208 |
try:
|
| 209 |
timestamp = subprocess.check_output(['git', 'log', '-1', '--format=%cd', '--date=iso']).decode('utf-8').strip()
|
| 210 |
return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
| 211 |
-
except Exception:
|
| 212 |
-
|
|
|
|
| 213 |
|
| 214 |
# Create the Gradio interface with appropriate input and output components.
|
| 215 |
last_updated = get_last_commit_timestamp()
|
|
|
|
| 60 |
|
| 61 |
Args:
|
| 62 |
depth (np.ndarray): 2D array representing depth in meters.
|
| 63 |
+
image_path (str): Path to the RGB image.
|
| 64 |
focallength_px (float): Focal length in pixels.
|
| 65 |
|
| 66 |
Returns:
|
|
|
|
| 68 |
"""
|
| 69 |
# Load the RGB image and convert to a NumPy array
|
| 70 |
image = np.array(Image.open(image_path))
|
| 71 |
+
|
| 72 |
+
# Resize depth to match image dimensions if necessary
|
| 73 |
+
if depth.shape != image.shape[:2]:
|
| 74 |
+
depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)
|
| 75 |
+
|
| 76 |
height, width = depth.shape
|
| 77 |
|
| 78 |
+
print(f"3D model generation - Depth shape: {depth.shape}")
|
| 79 |
+
print(f"3D model generation - Image shape: {image.shape}")
|
| 80 |
+
|
| 81 |
# Compute camera intrinsic parameters
|
| 82 |
fx = fy = focallength_px # Assuming square pixels and fx = fy
|
| 83 |
cx, cy = width / 2, height / 2 # Principal point at the image center
|
|
|
|
| 134 |
# Preprocess the image for depth prediction
|
| 135 |
result = depth_pro.load_rgb(temp_file)
|
| 136 |
|
|
|
|
| 137 |
if len(result) < 2:
|
| 138 |
raise ValueError(f"Unexpected result from load_rgb: {result}")
|
| 139 |
|
| 140 |
+
image, _, _, _, f_px = result
|
|
|
|
|
|
|
| 141 |
print(f"Extracted focal length: {f_px}")
|
| 142 |
|
| 143 |
+
image = transform(image).to(device)
|
|
|
|
| 144 |
|
| 145 |
# Run the depth prediction model
|
| 146 |
prediction = model.infer(image, f_px=f_px)
|
|
|
|
| 155 |
if depth.ndim != 2:
|
| 156 |
depth = depth.squeeze()
|
| 157 |
|
| 158 |
+
print(f"Depth map shape: {depth.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
# Create a color map for visualization using matplotlib
|
| 161 |
plt.figure(figsize=(10, 10))
|
| 162 |
+
plt.imshow(depth, cmap='gist_rainbow')
|
| 163 |
plt.colorbar(label='Depth [m]')
|
| 164 |
+
plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
|
| 165 |
plt.axis('off') # Hide axis for a cleaner image
|
| 166 |
|
| 167 |
# Save the depth map visualization to a file
|
|
|
|
| 192 |
try:
|
| 193 |
timestamp = subprocess.check_output(['git', 'log', '-1', '--format=%cd', '--date=iso']).decode('utf-8').strip()
|
| 194 |
return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
| 195 |
+
except Exception as e:
|
| 196 |
+
print(f"{str(e)}")
|
| 197 |
+
return str(e)
|
| 198 |
|
| 199 |
# Create the Gradio interface with appropriate input and output components.
|
| 200 |
last_updated = get_last_commit_timestamp()
|