Commit cd8e337

add custom op declaration
1 parent fef447e commit cd8e337

5 files changed: +30 −23 lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 1 addition & 1 deletion
@@ -530,7 +530,7 @@ paddle::Tensor FusedHadamardQuantFp8Func(
 int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
                                paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
 
-void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
+void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, int64_t _fa,
                 int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
 
 void dispose(int64_t _fa);

custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu

Lines changed: 10 additions & 1 deletion
@@ -49,7 +49,7 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
  * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
  * copied into _reg_buffer.
  */
-void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
+void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa,
                 fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
   auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
   auto stream = inp.stream();
@@ -163,3 +163,12 @@ fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
 void free_shared_buffer(fptr_t buffer) {
   CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
 }
+
+
+PD_BUILD_STATIC_OP(all_reduce)
+    .Inputs({"inp", "out"})
+    .Outputs({"new_out"})
+    .Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
+    .SetInplaceMap({{"out", "new_out"}})
+    .SetKernelFn(PD_KERNEL(all_reduce));
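
For orientation, here is a minimal Python sketch (not part of this commit) of how the registered all_reduce op is meant to be called: the tensor Inputs ("inp", "out") come first, followed by the three int64 Attrs, which matches the argument order the Python call sites below switch to. The wrapper name and the way the handles are obtained are illustrative assumptions only.

    # Illustrative sketch, not from this commit. `custom_all_reduce_op` stands in for
    # the compiled op produced by PD_BUILD_STATIC_OP(all_reduce); `fa`, `reg_buffer`,
    # and `reg_buffer_sz_bytes` are placeholder handles (e.g. as returned by
    # init_custom_all_reduce / a registered IPC buffer).
    import paddle

    def run_custom_all_reduce(custom_all_reduce_op, inp, fa, reg_buffer, reg_buffer_sz_bytes):
        out = paddle.empty_like(inp)
        # Tensor Inputs ("inp", "out") first, then the int64 Attrs
        # ("_fa", "_reg_buffer", "reg_buffer_sz_bytes"); "out" is updated in place
        # through the {"out": "new_out"} inplace mapping.
        custom_all_reduce_op(inp, out, fa, reg_buffer, reg_buffer_sz_bytes)
        return out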

fastdeploy/distributed/communication.py

Lines changed: 14 additions & 19 deletions
@@ -42,22 +42,17 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
     _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
 
 
-try:
-
-    @paddle.jit.marker.unified
-    def tensor_model_parallel_all_reduce(
-        input_: paddle.Tensor,
-    ) -> paddle.Tensor:
-        """All-reduce the input tensor across model parallel group."""
-        global _TP_AR
-        if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
-            _TP_AR.custom_all_reduce(input_)
-        elif paddle.in_dynamic_mode():
-            hcg = fleet.get_hybrid_communicate_group()
-            mp_group = hcg.get_model_parallel_group()
-            dist.all_reduce(input_, group=mp_group)
-        else:
-            dist.all_reduce(input_)
-
-except:
-    tensor_model_parallel_all_reduce = None
+@paddle.jit.marker.unified
+def tensor_model_parallel_all_reduce(
+    input_: paddle.Tensor,
+) -> paddle.Tensor:
+    """All-reduce the input tensor across model parallel group."""
+    global _TP_AR
+    if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
+        _TP_AR.custom_all_reduce(input_)
+    elif paddle.in_dynamic_mode():
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        dist.all_reduce(input_, group=mp_group)
+    else:
+        dist.all_reduce(input_)
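
As a quick usage sketch (again, not part of this commit), the now unconditionally defined helper can be called directly; the tensor shape and the surrounding distributed initialization below are assumptions:

    # Hypothetical usage. Assumes fleet/paddle.distributed has been initialized with a
    # model-parallel group and, optionally, use_custom_allreduce() has been called.
    import paddle
    from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce

    hidden_states = paddle.randn([8, 4096], dtype="float32")
    # Reduces hidden_states in place across the model-parallel group: the custom
    # all-reduce path is used when _TP_AR accepts the tensor, otherwise it falls back
    # to paddle.distributed.all_reduce.
    tensor_model_parallel_all_reduce(hidden_states)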

fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py

Lines changed: 2 additions & 2 deletions
@@ -158,9 +158,9 @@ def all_reduce(
         if out is None:
             out = paddle.empty_like(inp)
         if registered:
-            all_reduce(self._ptr, inp, out, 0, 0)
+            all_reduce(inp, out, self._ptr, 0, 0)
         else:
-            all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size)
+            all_reduce(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size)
         return out
 
     def start_capture(self):

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 3 additions & 0 deletions
@@ -89,6 +89,9 @@ class MLAAttentionMetadata(AttentionMetadata):
     kv_signal_metadata: Optional[paddle.Tensor] = None
     kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
+    max_enc_len_this_time: Optional[paddle.Tensor] = None
+    max_dec_len_this_time: Optional[paddle.Tensor] = None
+
 
 class MLAAttentionBackend(AttentionBackend):
     """
