"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import unittest

import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

from fastdeploy.distributed.custom_all_reduce import CustomAllreduce
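
# NOTE: mp_degree=2 means this test needs two devices. A typical launch,
# assuming the file is saved as test_custom_all_reduce.py, would be:
#     python -m paddle.distributed.launch --gpus "0,1" test_custom_all_reduce.py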


class Test(unittest.TestCase):
    def setUp(self):
        """
        Initialize the test environment: set the random seed and bring up
        a hybrid-parallel fleet with a two-way model-parallel group.
        """
        paddle.seed(2025)
        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": 2,
            "pp_degree": 1,
            "sharding_degree": 1,
        }
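
        # mp_degree=2 builds a two-rank model-parallel group; test_case
        # exercises CustomAllreduce across exactly that group.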
        fleet.init(is_collective=True, strategy=strategy)

    def test_case(self):
        """
        Check that CustomAllreduce produces the same result as
        paddle.distributed.all_reduce across a range of tensor shapes.
        """

        mns = [[1, 2048], [2, 4096], [20, 4096], [128, 4096], [256, 4096], [256, 8192]]

        hcg = fleet.get_hybrid_communicate_group()
        model_parallel_group = hcg.get_model_parallel_group()
        fa = CustomAllreduce(model_parallel_group)
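
        # Reduce one copy of each tensor with the custom kernel and another
        # with Paddle's stock dist.all_reduce, then compare the two on rank 0.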
        for m, n in mns:
            data_ar = paddle.rand([m, n], dtype="bfloat16")
            data_paddle = data_ar.clone()
            if fa.should_custom_ar(data_ar):
                fa.custom_all_reduce(data_ar)
                dist.all_reduce(data_paddle)
                if dist.get_rank() == 0:
                    np.testing.assert_allclose(
                        data_ar.numpy(),
                        data_paddle.numpy(),
                        rtol=1e-04,
                        atol=1e-04,
                    )


if __name__ == "__main__":
    unittest.main()