Allow obtaining cuda stream handle from PyTorch stream when launching kernel (#297)

Use `cuda_stream` attribute of a torch stream if the stream is not an
instance of the cupy stream.
aashaka authored May 4, 2024
1 parent 6226556 commit 0650371
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/mscclpp/utils.py
@@ -50,7 +50,9 @@ def launch_kernel(
             ],
             dtype=np.uint64,
         )
-        cuda_stream = stream.ptr if stream else 0
+        cuda_stream = 0
+        if stream:
+            cuda_stream = stream.ptr if isinstance(stream, cp.cuda.Stream) else stream.cuda_stream
         cp.cuda.driver.launchKernel(
             self._kernel, nblocks, 1, 1, nthreads, 1, 1, shared, cuda_stream, 0, config.ctypes.data
         )
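
The handle-selection logic in this change can be read in isolation. The sketch below is illustrative only and is not part of the commit: `resolve_cuda_stream`, `StandInCupyStream`, and `StandInTorchStream` are hypothetical names, and the two stand-in classes only mimic the `ptr` and `cuda_stream` attributes that the diff relies on, so the example runs without a GPU, CuPy, or PyTorch installed.

```python
class StandInCupyStream:
    """Hypothetical stand-in exposing `ptr`, like a CuPy stream in the diff."""

    def __init__(self, ptr):
        self.ptr = ptr


class StandInTorchStream:
    """Hypothetical stand-in exposing `cuda_stream`, like a torch stream in the diff."""

    def __init__(self, cuda_stream):
        self.cuda_stream = cuda_stream


def resolve_cuda_stream(stream, cupy_stream_type):
    """Return the raw stream handle to pass to the kernel launch.

    Mirrors the branch structure of the diff: a falsy stream maps to 0
    (the default stream), a CuPy-style stream yields `ptr`, and anything
    else is assumed to expose `cuda_stream`.
    """
    if not stream:
        return 0
    if isinstance(stream, cupy_stream_type):
        return stream.ptr
    return stream.cuda_stream


# Illustrative handle values; real handles come from the CUDA runtime.
print(resolve_cuda_stream(None, StandInCupyStream))
print(hex(resolve_cuda_stream(StandInCupyStream(0x1234), StandInCupyStream)))
print(hex(resolve_cuda_stream(StandInTorchStream(0x5678), StandInCupyStream)))
```

In the actual code path, `cupy_stream_type` corresponds to `cp.cuda.Stream`, and a caller could presumably pass something like `torch.cuda.current_stream()` as `stream`. Any other object that lacks a `cuda_stream` attribute would raise `AttributeError` in the final branch.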