Add custom op declaration for all_reduce #3473


Open · wants to merge 2 commits into develop

Conversation

@DrRyanHuang (Contributor) commented Aug 19, 2025

After enabling SOT dynamic-to-static inference for DeepSeek V3, the following graph break occurs:

[Translate InlineFn 3] (line 163) CALL_FUNCTION 5, stack is [BuiltinVariable(all_reduce, object_28183), 
		ConstantVariable(468079520, object_28197, _TP_AR._ptr), 
		TensorVariable([1, 7168], bfloat16, False, var_619, None, object_27861, input_), 
		TensorVariable([1, 7168], bfloat16, False, var_619, None, object_27861, input_), 
		ConstantVariable(44216352768, object_28203, _TP_AR.buffer_ptrs[2]), 
		ConstantVariable(8388608, object_28121, _TP_AR.max_size)
]
[BreakGraph] call function Break graph:
    File "/workspace/FastDeploy/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py", line 207, in custom_all_reduce encountered breakgraph error caused by
    File "/workspace/FastDeploy/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py", line 146, in all_reduce encountered breakgraph error caused by
    Not support builtin function: PyCapsule.all_reduce with args: Args(ConstantVariable, TensorVariable, TensorVariable, ConstantVariable, ConstantVariable)
start subgraph compile and execution.

Investigation shows that the current all_reduce is only a cpp_extensions function (the PyCapsule indicates a C++ extension), not a Paddle custom operator, so a custom operator declaration needs to be added for all_reduce:

PD_BUILD_STATIC_OP(all_reduce)
    .Inputs({"inp",
             "out"})
    .Outputs({"new_out"})
    .Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
    .SetInplaceMap({{"out", "new_out"}})
    .SetKernelFn(PD_KERNEL(all_reduce));

In addition, since custom operators require Tensor parameters to come first and other attr parameters to follow, the parameter order of all_reduce was also adjusted.
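For illustration, here is a hedged Python-level sketch of the reordered signature (parameter names mirror the declaration above; the real implementation is FastDeploy's C++ kernel, and this stub is only an assumption about its Python-side shape):

import paddle

# Hypothetical stub mirroring the op declaration: Tensor inputs
# ("inp", "out") first, int64 attrs ("_fa", "_reg_buffer",
# "reg_buffer_sz_bytes") after, as Paddle custom ops require.
def all_reduce(
    inp: paddle.Tensor,
    out: paddle.Tensor,
    _fa: int,                  # CustomAllreduce handle (_TP_AR._ptr in the log above)
    _reg_buffer: int,          # registered buffer address (_TP_AR.buffer_ptrs[...])
    reg_buffer_sz_bytes: int,  # buffer size (_TP_AR.max_size)
) -> paddle.Tensor:
    # Per SetInplaceMap({{"out", "new_out"}}), the kernel writes into
    # `out` in place and returns it as "new_out".
    ...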


Separately, SOT's dataclass support currently has an issue: setting attributes directly like this does not take effect:

        metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
        metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
[Translate InlineFn 3] (line 451) LOAD_FAST    metadata, stack is [PaddleApiVariable(flash_attention_v3_varlen, object_11204, forward_meta.attn_backend.flash_attn_func), TensorVariable([1534, 16, 192], bfloat16, False, var_193, None, object_10794, q), TensorVariable([1534, 16, 192], bfloat16, False, var_194, None, object_10799, k), TensorVariable([1534, 16, 192], bfloat16, False, var_195, None, object_10804, v), TensorVariable([SymbolicInt(9)], int32, True, var_209, None, object_11189, forward_meta.cu_seqlens_q), TensorVariable([SymbolicInt(9)], int32, True, var_212, None, object_11209, forward_meta.cu_seqlens_k)]
[Translate InlineFn 3] (line 451) LOAD_ATTR    max_enc_len_this_time, stack is [PaddleApiVariable(flash_attention_v3_varlen, object_11204, forward_meta.attn_backend.flash_attn_func), TensorVariable([1534, 16, 192], bfloat16, False, var_193, None, object_10794, q), TensorVariable([1534, 16, 192], bfloat16, False, var_194, None, object_10799, k), TensorVariable([1534, 16, 192], bfloat16, False, var_195, None, object_10804, v), TensorVariable([SymbolicInt(9)], int32, True, var_209, None, object_11189, forward_meta.cu_seqlens_q), TensorVariable([SymbolicInt(9)], int32, True, var_212, None, object_11209, forward_meta.cu_seqlens_k), DataClassInstanceVariable(object_11048, forward_meta.attn_backend.attention_metadata)]
Traceback (most recent call last):
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py", line 407, in start_translate
    new_custom_code, guard_fn = simulator.transform(frame)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 2125, in transform
    self.run()
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 628, in run
    is_stop = self.step(cur_instr)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 674, in step
    return getattr(self, opname)(instr)  # run single step.
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 312, in wrapper
    return call_fn(self, instr)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 1436, in CALL_METHOD
    self.stack.push(method(*args))
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py", line 158, in __call__
    return self.call_function(*args, **kwargs)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py", line 334, in call_function
    raise e
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py", line 328, in call_function
    output = inline_executor.inline_call()
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py", line 100, in inline_call
    self.run()
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 628, in run
    is_stop = self.step(cur_instr)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 674, in step
    return getattr(self, opname)(instr)  # run single step.
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 312, in wrapper
    return call_fn(self, instr)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 1436, in CALL_METHOD
    self.stack.push(method(*args))
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py", line 158, in __call__
    return self.call_function(*args, **kwargs)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py", line 334, in call_function
    raise e
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py", line 328, in call_function
    output = inline_executor.inline_call()
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py", line 100, in inline_call
    self.run()
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 628, in run
    is_stop = self.step(cur_instr)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 674, in step
    return getattr(self, opname)(instr)  # run single step.
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 312, in wrapper
    return call_fn(self, instr)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py", line 934, in LOAD_ATTR
    BuiltinVariable(
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py", line 158, in __call__
    return self.call_function(*args, **kwargs)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py", line 968, in call_function
    return handler(*args, **kwargs)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py", line 604, in <lambda>
    lambda var, name, default=None: var.getattr(
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py", line 2462, in getattr
    return super().getattr(name, default)
  File "/workspace/Paddle/build/python/paddle/jit/sot/opcode_translator/executor/variables/base.py", line 545, in getattr
    raise HasNoAttributeError(
paddle.jit.sot.utils.exceptions.HasNoAttributeError: DataClassInstanceVariable DataClassInstanceVariable(object_11048, forward_meta.attn_backend.attention_metadata) has no attribute max_enc_len_this_time

So the fields are declared in advance in the dataclass; SOT will fix this issue later:

    max_enc_len_this_time: Optional[paddle.Tensor] = None
    max_dec_len_this_time: Optional[paddle.Tensor] = None
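
A minimal self-contained sketch of the workaround (AttentionMetadata here is a hypothetical stand-in for the real metadata dataclass; the two field declarations are the actual change):

from dataclasses import dataclass
from typing import Optional

import paddle

@dataclass
class AttentionMetadata:  # hypothetical stand-in for the real metadata class
    # Declaring the fields up front (with None defaults) lets SOT's
    # DataClassInstanceVariable resolve LOAD_ATTR on them instead of
    # raising HasNoAttributeError.
    max_enc_len_this_time: Optional[paddle.Tensor] = None
    max_dec_len_this_time: Optional[paddle.Tensor] = None

metadata = AttentionMetadata()
metadata.max_enc_len_this_time = paddle.to_tensor([128], dtype="int32")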

cc @SigureMo @zyfncg

paddle-bot (bot) commented Aug 19, 2025

Thanks for your contribution!

@SigureMo (Member) left a comment

LGTMeow 🐾

Comment on lines +169 to +172
.Inputs({"inp",
"out"})
.Outputs({"new_out"})
.Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
Member:

These names feel a bit odd.

Contributor (Author):

True, but that's what it was originally called 🤔

@Jiang-Jia-Jun requested a review from zhink August 20, 2025 02:57

except:
tensor_model_parallel_all_reduce = None
@paddle.jit.marker.unified
Collaborator:

The previous try was there to adapt for RL; why remove it?

Member:

Could you explain why def-ing a function could possibly raise an error here? To my limited understanding, this shouldn't cause any problem; if it does, please provide the error message. As I understand it, there must be a more appropriate way to handle this.

@DrRyanHuang (Contributor, Author) commented Aug 20, 2025

Was the try/except added because older Paddle versions don't have paddle.jit.marker.unified?

Or it could be done like this:

if hasattr(paddle.jit, "marker") and hasattr(paddle.jit.marker, "unified"):
    mark_as_unified = paddle.jit.marker.unified
else:
    # no-op fallback for PaddlePaddle 3.1 and earlier
    mark_as_unified = lambda fn: fn
tensor_model_parallel_all_reduce = mark_as_unified(tensor_model_parallel_all_reduce)
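
With that guard in place, the marker can equally be applied in decorator form at the definition site, as the current diff does (a sketch, assuming the guard above runs before the definition):

@mark_as_unified  # paddle.jit.marker.unified on newer Paddle, a no-op otherwise
def tensor_model_parallel_all_reduce(input_):
    ...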
