Add custom op declaration for all_reduce
#3473
base: develop
@@ -42,22 +42,17 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
     _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)


-try:
-
-    @paddle.jit.marker.unified
-    def tensor_model_parallel_all_reduce(
-        input_: paddle.Tensor,
-    ) -> paddle.Tensor:
-        """All-reduce the input tensor across model parallel group."""
-        global _TP_AR
-        if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
-            _TP_AR.custom_all_reduce(input_)
-        elif paddle.in_dynamic_mode():
-            hcg = fleet.get_hybrid_communicate_group()
-            mp_group = hcg.get_model_parallel_group()
-            dist.all_reduce(input_, group=mp_group)
-        else:
-            dist.all_reduce(input_)
-
-except:
-    tensor_model_parallel_all_reduce = None
+@paddle.jit.marker.unified
The previous try was added for RL compatibility. Why was it removed?

Could you explain why simply defining a function here could raise an error? To my limited understanding this should not fail; if it does, please share the traceback. I believe there must be a more appropriate way to handle this.

Is it because older Paddle versions don't have paddle.jit.marker.unified? Alternatively, it could be handled like this:

    if hasattr(paddle.jit, "marker") and hasattr(paddle.jit.marker, "unified"):
        mark_as_unified = paddle.jit.marker.unified
    else:
        # do-nothing fallback for PaddlePaddle 3.1 and earlier
        mark_as_unified = lambda fn: fn

    tensor_model_parallel_all_reduce = mark_as_unified(tensor_model_parallel_all_reduce)

Which Paddle version is the RL work currently developed against?

The RL training code (which uses the fleet branch of Paddle and is fairly old) imports the FD model-network code, so this needs to be restored to the way it was before.
+def tensor_model_parallel_all_reduce(
+    input_: paddle.Tensor,
+) -> paddle.Tensor:
+    """All-reduce the input tensor across model parallel group."""
+    global _TP_AR
+    if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
+        _TP_AR.custom_all_reduce(input_)
+    elif paddle.in_dynamic_mode():
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        dist.all_reduce(input_, group=mp_group)
+    else:
+        dist.all_reduce(input_)
These names feel a bit odd.

True, but that's what it was originally called 🤔
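For context, below is a minimal usage sketch of how the decorated all-reduce might be exercised under a hybrid-parallel setup. It is illustrative only and not part of this PR: the mp_degree value, the random placeholder tensor, and launching via paddle.distributed.launch are assumptions; the fleet initialization follows the standard PaddlePaddle collective API.

    # Illustrative sketch (not from this PR). Assumes the module defining
    # tensor_model_parallel_all_reduce has been imported and the script is
    # launched with `python -m paddle.distributed.launch` across the mp ranks.
    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
    fleet.init(is_collective=True, strategy=strategy)

    # Each rank holds a partial result; the call reduces it in place across the
    # model-parallel group (or via the custom all-reduce when it is enabled).
    partial_out = paddle.randn([4, 1024], dtype="float32")
    tensor_model_parallel_all_reduce(partial_out)

Whether the try/except is restored (or replaced by the hasattr guard suggested above) determines whether this call remains importable on the older fleet-branch Paddle used by the RL training code.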