Skip to content
This repository was archived by the owner on May 5, 2025. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion zoedepth/models/depth_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,37 @@ def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, out
return out_tensor.squeeze().cpu()
else:
raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'")


@torch.no_grad()
def infer_cv(self, cv_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, torch.Tensor]:
    """
    Inference interface for the model for a NumPy image loaded via OpenCV (BGR format)

    The image is converted from BGR to RGB, cast to float32, divided by 255,
    rearranged from channel-last to channel-first, given a batch dimension and
    passed to ``self.infer`` on ``self.device``.

    Args:
        cv_img (np.ndarray): input image in BGR format as loaded by OpenCV, indexed as
            (row, column, channel). Pixel values are divided by 255 regardless of the
            input dtype, so they are expected on a 0-255 scale. The array is never
            modified; a copy is made internally.
        pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
        with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True.
        output_type (str, optional): output type. Supported values are 'numpy' and 'tensor'. Defaults to "numpy".
        **kwargs: forwarded unchanged to ``self.infer``.

    Returns:
        Union[np.ndarray, torch.Tensor]: the output of ``self.infer`` with singleton
        dimensions squeezed out and moved to the CPU, as a NumPy array when
        ``output_type`` is 'numpy' or as a tensor when it is 'tensor'.

    Raises:
        ValueError: if ``output_type`` is neither 'numpy' nor 'tensor'. This is checked
            before any inference is run.
    """
    # Fail fast: reject an unsupported output type before paying for a forward pass.
    if output_type not in ("numpy", "tensor"):
        raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy' and 'tensor'")

    # Convert BGR to RGB. Reversing the channel axis yields a negative-stride view,
    # which torch.from_numpy cannot wrap, so materialise it as a contiguous float32
    # array (this matters when the input is already float32 and no cast-copy occurs).
    cv_img_rgb = np.ascontiguousarray(cv_img[:, :, ::-1], dtype=np.float32)

    # Scale out of place so the caller's array can never be written through a view.
    cv_img_rgb = cv_img_rgb / 255.0

    # Convert the image to a PyTorch tensor (channel-first), add a batch dimension, and transfer to device
    x = torch.from_numpy(cv_img_rgb).permute(2, 0, 1).unsqueeze(0).to(self.device)

    # Perform inference
    out_tensor = self.infer(
        x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs
    )

    # Convert the output tensor to the requested output type
    result = out_tensor.squeeze().cpu()
    if output_type == "numpy":
        return result.numpy()
    return result