Skip to content

Commit d407eb4

Browse files
authored
fix: support zero-sized arrays as input to StridedMemoryView.from_cuda_array_interface (#1397)
* test: add failing test case for zero-sized array input to `StridedMemoryView.from_cuda_array_interface` * fix: only pluck out device id if the data buffer pointer coming from CAI is valid * chore: set zero-sized array device to the current device * docs: add note about changes * docs: clarify
1 parent 8006fb6 commit d407eb4

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,10 +582,13 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
582582
buf.dl_tensor = NULL
583583
buf.ptr, buf.readonly = cai_data["data"]
584584
buf.is_device_accessible = True
585-
buf.device_id = handle_return(
586-
driver.cuPointerGetAttribute(
587-
driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
588-
buf.ptr))
585+
if buf.ptr != 0:
586+
buf.device_id = handle_return(
587+
driver.cuPointerGetAttribute(
588+
driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
589+
buf.ptr))
590+
else:
591+
buf.device_id = handle_return(driver.cuCtxGetDevice())
589592

590593
cdef intptr_t producer_s, consumer_s
591594
stream_ptr = int(stream_ptr)

cuda_core/docs/source/release/0.5.0-notes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ None.
6161
Fixes and enhancements
6262
----------------------
6363

64+
- Zero-size arrays are now supported as inputs when constructing ``StridedMemoryView``.
6465
- Most CUDA resources can be hashed now.
6566
- Python ``bool`` objects are now converted to C++ ``bool`` type when passed as kernel arguments (previously converted to ``int``).
6667
- Restored v0.3.x :class:`MemoryResource` behaviors and missing MR attributes for backward compatibility.

cuda_core/tests/test_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,23 @@ def test_view_sliced_external_negative_offset(stride_order, view_as):
413413
assert sliced_view._layout.itemsize == a_sliced.itemsize == layout.itemsize
414414
assert sliced_view.shape == a_sliced.shape
415415
assert sliced_view._layout.strides_in_bytes == a_sliced.strides
416+
417+
418+
@pytest.mark.parametrize(
419+
"api",
420+
[
421+
StridedMemoryView.from_dlpack,
422+
StridedMemoryView.from_cuda_array_interface,
423+
],
424+
)
425+
@pytest.mark.parametrize("shape", [(0,), (0, 0), (0, 0, 0)])
426+
@pytest.mark.parametrize("dtype", [np.int64, np.uint8, np.float64])
427+
def test_view_zero_size_array(api, shape, dtype):
428+
cp = pytest.importorskip("cupy")
429+
430+
x = cp.empty(shape, dtype=dtype)
431+
smv = api(x, stream_ptr=0)
432+
433+
assert smv.size == 0
434+
assert smv.shape == shape
435+
assert smv.dtype == np.dtype(dtype)

0 commit comments

Comments
 (0)