Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion shap_e/diffusion/gaussian_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,7 @@ def __call__(self, x, ts, **kwargs):

def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
UPDATED to include PyTorch MPS compatibility for Apple Silicon.
Extract values from a 1-D numpy array for a batch of indices.

:param arr: the 1-D numpy array.
Expand All @@ -1065,7 +1066,16 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
res = None
is_mps = hasattr(timesteps, "device") and timesteps.device.type == "mps"

if is_mps:
# Create the tensor on the CPU and perform indexing there
cpu_tensor = th.from_numpy(arr.copy().astype(np.float32))[timesteps.cpu()].clone().contiguous().float()
# Then transfer the result to the target device (e.g., MPS)
res = cpu_tensor.to(device=timesteps.device)
else:
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
Expand Down
19 changes: 14 additions & 5 deletions shap_e/examples/encode_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "65b721ee",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -15,16 +16,22 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "3088e7fd",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
"device = torch.device(\n",
" 'mps' if torch.backends.mps.is_available()\n",
" else 'cuda' if torch.cuda.is_available()\n",
" else 'cpu'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48ecf7fd",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -33,7 +40,8 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "cc779df5",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -54,6 +62,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d0a488c0",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -85,7 +94,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.13.2"
}
},
"nbformat": 4,
Expand Down
16 changes: 14 additions & 2 deletions shap_e/examples/sample_image_to_3d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
"device = torch.device(\n",
" 'mps' if torch.backends.mps.is_available()\n",
" else 'cuda' if torch.cuda.is_available()\n",
" else 'cpu'\n",
")"
]
},
{
Expand Down Expand Up @@ -83,6 +87,14 @@
" images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
" display(gif_widget(images))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e8ee2b6-3534-4c89-bd88-eb8f24ca1496",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -101,7 +113,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.13.2"
}
},
"nbformat": 4,
Expand Down
16 changes: 14 additions & 2 deletions shap_e/examples/sample_text_to_3d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
"device = torch.device(\n",
" 'mps' if torch.backends.mps.is_available()\n",
" else 'cuda' if torch.cuda.is_available()\n",
" else 'cpu'\n",
")"
]
},
{
Expand Down Expand Up @@ -98,6 +102,14 @@
" with open(f'example_mesh_{i}.obj', 'w') as f:\n",
" t.write_obj(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1752a20a-e83d-4450-88c9-640b7e25bb7c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -116,7 +128,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.13.2"
}
},
"nbformat": 4,
Expand Down