diff --git a/README.md b/README.md index 446dae32b..5fa1cb173 100644 --- a/README.md +++ b/README.md @@ -42,15 +42,36 @@ torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True) # Trigge ```python import torch +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + repo = "isl-org/ZoeDepth" # Zoe_N -model_zoe_n = torch.hub.load(repo, "ZoeD_N", pretrained=True) +model_zoe_n = torch.hub.load(repo, "ZoeD_N", pretrained=False) +pretrained_dict = torch.hub.load_state_dict_from_url('https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt', map_location=DEVICE) +model_zoe_n.load_state_dict(pretrained_dict['model'], strict=False) +for b in model_zoe_n.core.core.pretrained.model.blocks: + b.drop_path = torch.nn.Identity() + +zoe = model_zoe_n.to(DEVICE) # Zoe_K -model_zoe_k = torch.hub.load(repo, "ZoeD_K", pretrained=True) +model_zoe_k = torch.hub.load(repo, "ZoeD_K", pretrained=False) +pretrained_dict = torch.hub.load_state_dict_from_url('https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt', map_location=DEVICE) +model_zoe_k.load_state_dict(pretrained_dict['model'], strict=False) +for b in model_zoe_k.core.core.pretrained.model.blocks: + b.drop_path = torch.nn.Identity() + +zoe = model_zoe_k.to(DEVICE) # Zoe_NK -model_zoe_nk = torch.hub.load(repo, "ZoeD_NK", pretrained=True) +model_zoe_nk = torch.hub.load(repo, "ZoeD_NK", pretrained=False) +pretrained_dict = torch.hub.load_state_dict_from_url('https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt', map_location=DEVICE) +model_zoe_nk.load_state_dict(pretrained_dict['model'], strict=False) +for b in model_zoe_nk.core.core.pretrained.model.blocks: + b.drop_path = torch.nn.Identity() + +zoe = model_zoe_nk.to(DEVICE) + ``` ### Using local copy Clone this repo: @@ -87,8 +108,8 @@ model_zoe_nk = build_model(conf) ### Using ZoeD models to predict depth ```python ##### sample prediction -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -zoe = 
model_zoe_n.to(DEVICE) + + # Local file diff --git a/ui/app.py b/ui/app.py index e32721854..45951b72d 100644 --- a/ui/app.py +++ b/ui/app.py @@ -43,7 +43,18 @@ """ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' -model = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True).to(DEVICE).eval() + +# https://github.com/isl-org/ZoeDepth/issues/82#issuecomment-1779799540 +repo = "isl-org/ZoeDepth" + +model_zoe_n = torch.hub.load(repo, "ZoeD_N", pretrained=False) +pretrained_dict = torch.hub.load_state_dict_from_url('https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt', map_location=DEVICE) +model_zoe_n.load_state_dict(pretrained_dict['model'], strict=False) +for b in model_zoe_n.core.core.pretrained.model.blocks: + b.drop_path = torch.nn.Identity() + +model = model_zoe_n +zoe = model_zoe_n.to(DEVICE) title = "# ZoeDepth" description = """Official demo for **ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth**. @@ -63,4 +74,4 @@ create_pano_to_3d_demo(model) if __name__ == '__main__': - demo.queue().launch() \ No newline at end of file + demo.queue().launch() diff --git a/ui/gradio_depth_pred.py b/ui/gradio_depth_pred.py index fb875451e..2b2226f49 100644 --- a/ui/gradio_depth_pred.py +++ b/ui/gradio_depth_pred.py @@ -34,7 +34,7 @@ def predict_depth(model, image): def create_demo(model): gr.Markdown("### Depth Prediction demo") with gr.Row(): - input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input').style(height="auto") + input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input') depth_image = gr.Image(label="Depth Map", elem_id='img-display-output') raw_file = gr.File(label="16-bit raw depth, multiplier:256") submit = gr.Button("Submit") @@ -49,4 +49,4 @@ def on_submit(image): submit.click(on_submit, inputs=[input_image], outputs=[depth_image, raw_file]) # examples = gr.Examples(examples=["examples/person_1.jpeg", "examples/person_2.jpeg", 
"examples/person-leaves.png", "examples/living-room.jpeg"], - # inputs=[input_image]) \ No newline at end of file + # inputs=[input_image]) diff --git a/zoedepth/models/base_models/midas.py b/zoedepth/models/base_models/midas.py index e26f85895..36f6c3794 100644 --- a/zoedepth/models/base_models/midas.py +++ b/zoedepth/models/base_models/midas.py @@ -170,7 +170,7 @@ def get_size(self, width, height): def __call__(self, x): width, height = self.get_size(*x.shape[-2:][::-1]) - return nn.functional.interpolate(x, (height, width), mode='bilinear', align_corners=True) + return nn.functional.interpolate(x, (int(height), int(width)), mode='bilinear', align_corners=True) class PrepForMidas(object): def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True):