Skip to content

Commit 0d2e1b9

Browse files
committed
Address comments
1 parent 629e091 commit 0d2e1b9

File tree

4 files changed

+12
-9
lines changed

4 files changed

+12
-9
lines changed

cuda_core/cuda/core/experimental/_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ def create_event(self, options: Optional[EventOptions] = None) -> Event:
12281228
Newly created event object.
12291229
12301230
"""
1231-
return Event._init(options)
1231+
return Event._init(self._id, options)
12321232

12331233
@precondition(_check_context_initialized)
12341234
def allocate(self, size, stream=None) -> Buffer:

cuda_core/cuda/core/experimental/_event.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from typing import TYPE_CHECKING, Optional
1010

1111
from cuda.core.experimental._context import Context
12-
from cuda.core.experimental._device import Device
1312
from cuda.core.experimental._utils.cuda_utils import (
1413
CUDAError,
1514
check_or_create_options,
@@ -22,6 +21,7 @@
2221

2322
if TYPE_CHECKING:
2423
import cuda.bindings
24+
from cuda.core.experimental._device import Device
2525

2626

2727
@dataclass
@@ -96,7 +96,7 @@ def __new__(self, *args, **kwargs):
9696
__slots__ = ("__weakref__", "_mnff", "_timing_disabled", "_busy_waited")
9797

9898
@classmethod
99-
def _init(cls, options: Optional[EventOptions] = None):
99+
def _init(cls, device_id: int, ctx_handle=None, options: Optional[EventOptions] = None):
100100
self = super().__new__(cls)
101101
self._mnff = Event._MembersNeededForFinalize(self, None)
102102

@@ -113,8 +113,11 @@ def _init(cls, options: Optional[EventOptions] = None):
113113
if options.support_ipc:
114114
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103")
115115
self._mnff.handle = handle_return(driver.cuEventCreate(flags))
116-
self._device_id = int(handle_return(driver.cuCtxGetDevice()))
117-
self._ctx_handle = handle_return(driver.cuStreamGetCtx(self._mnff.handle))
116+
self._device_id = device_id
117+
if ctx_handle is not None:
118+
self._ctx_handle = ctx_handle
119+
else:
120+
self._ctx_handle = handle_return(driver.cuCtxGetCurrent())
118121
return self
119122

120123
def close(self):

cuda_core/cuda/core/experimental/_stream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def record(self, event: Event = None, options: EventOptions = None) -> Event:
244244
# on the stream. Event flags such as disabling timing, nonblocking,
245245
# and CU_EVENT_RECORD_EXTERNAL, can be set in EventOptions.
246246
if event is None:
247-
event = Event._init(options)
247+
event = Event._init(self._device_id, self._ctx_handle, options)
248248
assert_type(event, Event)
249249
handle_return(driver.cuEventRecord(event.handle, self._mnff.handle))
250250
return event

cuda_core/tests/test_event.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,9 @@ def test_error_timing_incomplete():
177177

178178

179179
def test_event_device(init_cuda):
180-
event = Device().create_event(options=EventOptions())
181-
device = event.device
182-
assert isinstance(device, Device)
180+
device = Device()
181+
event = device.create_event(options=EventOptions())
182+
assert event.device is device
183183

184184

185185
def test_event_context(init_cuda):

0 commit comments

Comments
 (0)