Add custom op declaration for all_reduce #3473

Open · wants to merge 2 commits into base: develop
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
@@ -530,7 +530,7 @@ paddle::Tensor FusedHadamardQuantFp8Func(
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
                               paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);

-void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
+void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, int64_t _fa,
                int64_t reg_buffer, int64_t reg_buffer_sz_bytes);

void dispose(int64_t _fa);
11 changes: 10 additions & 1 deletion custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu
@@ -49,7 +49,7 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
 * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
 * copied into _reg_buffer.
 */
-void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
+void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa,
                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  auto stream = inp.stream();
@@ -163,3 +163,12 @@ fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
void free_shared_buffer(fptr_t buffer) {
  CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
+
+
+PD_BUILD_STATIC_OP(all_reduce)
+    .Inputs({"inp",
+             "out"})
+    .Outputs({"new_out"})
+    .Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
Comment on lines +169 to +172

Member: These names feel a bit odd.

Contributor (Author): True, but that is what it was originally called 🤔

+    .SetInplaceMap({{"out", "new_out"}})
+    .SetKernelFn(PD_KERNEL(all_reduce));
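
For orientation, a minimal sketch of how the op registered above is expected to be called from Python once the extension is built. The import location and the pointer arguments are assumptions for illustration; only the argument order (tensors first, then int64 attributes) and the "out"/"new_out" in-place mapping come from this diff.

import paddle

# Assumed import location for the compiled custom op; adjust to the actual
# ops module produced by your build.
from fastdeploy.model_executor.ops.gpu import all_reduce


def run_custom_all_reduce(fa_ptr: int, reg_buffer_ptr: int, reg_buffer_sz: int,
                          inp: paddle.Tensor) -> paddle.Tensor:
    out = paddle.empty_like(inp)
    # Tensor inputs first, then the int64 attributes, matching the new C++
    # signature all_reduce(inp, out, _fa, _reg_buffer, reg_buffer_sz_bytes).
    # Passing 0, 0 for the buffer arguments means inp is already IPC-registered;
    # otherwise the kernel first copies inp into the registered staging buffer.
    all_reduce(inp, out, fa_ptr, reg_buffer_ptr, reg_buffer_sz)
    return out  # out is written in place and surfaced as "new_out"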
33 changes: 14 additions & 19 deletions fastdeploy/distributed/communication.py
@@ -42,22 +42,17 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
    _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)


-try:
-
-    @paddle.jit.marker.unified
-    def tensor_model_parallel_all_reduce(
-        input_: paddle.Tensor,
-    ) -> paddle.Tensor:
-        """All-reduce the input tensor across model parallel group."""
-        global _TP_AR
-        if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
-            _TP_AR.custom_all_reduce(input_)
-        elif paddle.in_dynamic_mode():
-            hcg = fleet.get_hybrid_communicate_group()
-            mp_group = hcg.get_model_parallel_group()
-            dist.all_reduce(input_, group=mp_group)
-        else:
-            dist.all_reduce(input_)
-
-except:
-    tensor_model_parallel_all_reduce = None
+@paddle.jit.marker.unified

Collaborator: The try was previously added for RL compatibility; why was it removed?

Member: Can you explain why defining a function could possibly raise an error here? To my limited understanding, there is no problem with this; if there is, please share the error. I believe there must be a more appropriate way to handle it.

Contributor (Author) @DrRyanHuang, Aug 20, 2025: Was the try/except added only because older Paddle versions do not have paddle.jit.marker.unified?

Or it could be handled like this:

if hasattr(paddle.jit, "marker") and hasattr(paddle.jit.marker, "unified"):
    mark_as_unified = paddle.jit.marker.unified
else:
    # no-op fallback for PaddlePaddle 3.1 and earlier
    mark_as_unified = lambda fn: fn
tensor_model_parallel_all_reduce = mark_as_unified(tensor_model_parallel_all_reduce)

What Paddle version is RL development currently based on?

Collaborator: The RL training code (which uses the fleet-branch Paddle, a fairly old version) imports the FD model definitions, so this needs to be restored to its previous form.

+def tensor_model_parallel_all_reduce(
+    input_: paddle.Tensor,
+) -> paddle.Tensor:
+    """All-reduce the input tensor across model parallel group."""
+    global _TP_AR
+    if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
+        _TP_AR.custom_all_reduce(input_)
+    elif paddle.in_dynamic_mode():
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        dist.all_reduce(input_, group=mp_group)
+    else:
+        dist.all_reduce(input_)
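
A minimal usage sketch of the restored entry point, assuming distributed initialization via fleet has already happened; shapes and dtype are illustrative only.

import paddle
from fastdeploy.distributed.communication import (
    tensor_model_parallel_all_reduce,
    use_custom_allreduce,
)

# Opt into the custom all-reduce pool (size matches the default above).
use_custom_allreduce(custom_all_reduce_max_bytes=8192 * 1024)

hidden = paddle.randn([2, 4096], dtype="bfloat16")
# Reduces hidden across the model-parallel group in place; falls back to
# dist.all_reduce whenever _TP_AR.should_custom_ar declines the tensor.
tensor_model_parallel_all_reduce(hidden)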
4 changes: 2 additions & 2 deletions fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
@@ -158,9 +158,9 @@ def all_reduce(
        if out is None:
            out = paddle.empty_like(inp)
        if registered:
-            all_reduce(self._ptr, inp, out, 0, 0)
+            all_reduce(inp, out, self._ptr, 0, 0)
        else:
-            all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size)
+            all_reduce(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size)
        return out

    def start_capture(self):
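
For context, a hedged sketch of exercising the wrapper above directly; the group setup mirrors use_custom_allreduce in communication.py, while the package-level import and the keyword call are assumptions for illustration only.

import paddle
from paddle.distributed import fleet

# Assumed export path; the class lives in
# fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py.
from fastdeploy.distributed.custom_all_reduce import CustomAllreduce

mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
ar = CustomAllreduce(mp_group, 8192 * 1024)  # group and max buffer size

x = paddle.randn([16, 4096], dtype="float16")
if ar.should_custom_ar(x):
    # registered=False: x is staged through the IPC-registered buffer
    # (self.buffer_ptrs[self.rank]) before the custom op runs.
    y = ar.all_reduce(x, registered=False)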
@@ -89,6 +89,9 @@ class MLAAttentionMetadata(AttentionMetadata):
    kv_signal_metadata: Optional[paddle.Tensor] = None
    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)

+    max_enc_len_this_time: Optional[paddle.Tensor] = None
+    max_dec_len_this_time: Optional[paddle.Tensor] = None


class MLAAttentionBackend(AttentionBackend):
    """