Commit a04be2d

is_tensor_stream_capturing instead of cudaStreamIsCapturing
1 parent fef447e commit a04be2d

File tree: 4 files changed, +36 -5 lines

custom_ops/gpu_ops/cpp_extensions.cc (4 additions, 0 deletions)

@@ -552,6 +552,8 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);
 
 void free_shared_buffer(int64_t buffer);
 
+bool is_tensor_stream_capturing(paddle::Tensor& input, int64_t _fa);
+
 // speculative decoding Kernel
 std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
     const paddle::Tensor& input_ids,

@@ -1103,6 +1105,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
 
+  m.def("is_tensor_stream_capturing", &is_tensor_stream_capturing, "check whether the tensor's stream is capturing");
+
   m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
 
   m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");

custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu (5 additions, 0 deletions)

@@ -163,3 +163,8 @@ fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
 void free_shared_buffer(fptr_t buffer) {
   CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
 }
+
+bool is_tensor_stream_capturing(paddle::Tensor& input, fptr_t _fa) {
+  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
+  return fa->is_tensor_stream_capturing(input);
+}
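
The fptr_t round trip above is the usual opaque-handle pattern: the Python side holds the CustomAllreduce object only as an integer, and each C-style entry point casts it back before dispatching. A self-contained sketch of that pattern, with hypothetical Reducer/create_reducer/reducer_is_capturing names:

#include <cstdint>

using fptr_t = int64_t;  // same opaque-handle typedef style as the commit

// Hypothetical stand-in for CustomAllreduce.
struct Reducer {
  bool capturing = false;
  bool is_capturing() const { return capturing; }
};

// Python receives the object only as an integer...
fptr_t create_reducer() { return reinterpret_cast<fptr_t>(new Reducer{}); }

// ...and each entry point casts it back before dispatching, mirroring the
// reinterpret_cast in is_tensor_stream_capturing above.
bool reducer_is_capturing(fptr_t handle) {
  return reinterpret_cast<Reducer*>(handle)->is_capturing();
}

void destroy_reducer(fptr_t handle) {
  delete reinterpret_cast<Reducer*>(handle);
}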

custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh (15 additions, 0 deletions)

@@ -441,6 +441,21 @@ class CustomAllreduce {
     graph_unreg_buffers_.clear();
   }
 
+  /**
+   * Returns true if the given Paddle GPU tensor's stream is actively capturing (cudaStreamCaptureStatusActive).
+   */
+  bool is_tensor_stream_capturing(paddle::Tensor& input)
+  {
+    auto stream = input.stream();
+    cudaStreamCaptureStatus status;
+    CUDACHECK(cudaStreamIsCapturing(stream, &status));
+    if (status == cudaStreamCaptureStatusActive) {
+      return true;
+    } else {
+      return false;
+    }
+  }
+
   /**
   * Performs allreduce, assuming input has already been registered.
   *
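
The new method reduces to a single CUDA runtime query. A standalone sketch (not from this commit) of how cudaStreamIsCapturing behaves across a capture's lifetime:

#include <cstdio>
#include <cuda_runtime.h>

static bool stream_is_capturing(cudaStream_t stream) {
  cudaStreamCaptureStatus status;
  // cudaStreamIsCapturing reports cudaStreamCaptureStatusActive only while
  // the stream is between cudaStreamBeginCapture and cudaStreamEndCapture.
  cudaStreamIsCapturing(stream, &status);
  return status == cudaStreamCaptureStatusActive;
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  printf("before capture: %d\n", stream_is_capturing(stream));  // 0

  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
  printf("during capture: %d\n", stream_is_capturing(stream));  // 1
  cudaStreamEndCapture(stream, &graph);

  printf("after capture: %d\n", stream_is_capturing(stream));   // 0
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  return 0;
}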

fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py (12 additions, 5 deletions)

@@ -31,6 +31,7 @@
     meta_size,
     register_buffer,
     register_graph_buffers,
+    is_tensor_stream_capturing,
 )
 
 try:

@@ -163,6 +164,15 @@ def all_reduce(
         all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size)
         return out
 
+    def iscapturing(
+        self,
+        input: paddle.Tensor,
+    ):
+        """
+        Check whether the tensor's stream is currently capturing.
+        """
+        return is_tensor_stream_capturing(input, self._ptr)
+
     def start_capture(self):
         """
         Set the CUDA graph flag to True.

@@ -207,11 +217,8 @@ def register_graph_buffers(self):
     def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]:
         """The main allreduce API that provides support for cuda graph."""
         if self.capturing:
-            lib = cuda_wrapper.CudaRTLibrary()
-            stream = paddle.device.current_stream()
-            stream_capturing = lib.cudaStreamIsCapturing(stream)
-            if stream_capturing.value == 1:
-                # 1 is cudaStreamCaptureStatusActive: The stream is capturing.
+            if self.iscapturing(input):
+                # The input tensor's stream is capturing.
                 return self.all_reduce(input, input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
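
The Python change swaps a ctypes call that queried paddle.device.current_stream() for a query of the input tensor's own stream via the new op. Capture status is tracked per stream, so the two can disagree; a minimal two-stream sketch (assumed setup, not from this commit) of that distinction:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaStream_t capturing_stream, other_stream;
  cudaStreamCreate(&capturing_stream);
  cudaStreamCreate(&other_stream);

  cudaGraph_t graph;
  cudaStreamBeginCapture(capturing_stream, cudaStreamCaptureModeThreadLocal);

  // Capture status is per stream: only the stream under capture reports
  // cudaStreamCaptureStatusActive, which is why the new check targets the
  // input tensor's stream rather than a globally "current" one.
  cudaStreamCaptureStatus a, b;
  cudaStreamIsCapturing(capturing_stream, &a);
  cudaStreamIsCapturing(other_stream, &b);
  printf("capturing_stream active: %d\n", a == cudaStreamCaptureStatusActive);  // 1
  printf("other_stream active:     %d\n", b == cudaStreamCaptureStatusActive);  // 0

  cudaStreamEndCapture(capturing_stream, &graph);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(capturing_stream);
  cudaStreamDestroy(other_stream);
  return 0;
}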
