File tree Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -37,6 +37,7 @@ def __init__(
37
37
monkey_patch_torch_reductions ()
38
38
self ._device_mesh_cpu = device_mesh_cpu
39
39
self ._tp_rank = device_mesh_cpu .get_local_rank ()
40
+ self ._rank = device_mesh_cpu .get_rank ()
40
41
self ._tp_size = device_mesh_cpu .size ()
41
42
tp_size_per_node = self ._tp_size // nnodes
42
43
node_rank = self ._tp_rank // tp_size_per_node
@@ -114,7 +115,7 @@ def generate(
114
115
# Most naive implementation, can extract tensor and send via gloo if too slow
115
116
[output ] = broadcast_pyobj (
116
117
data = [output ],
117
- rank = self ._tp_rank ,
118
+ rank = self ._rank ,
118
119
dist_group = self ._device_mesh_cpu .get_group (),
119
120
src = self ._device_mesh_cpu .mesh [0 ].item (),
120
121
force_cpu_device = False ,
@@ -157,7 +158,7 @@ def update_weights_from_tensor(
157
158
)
158
159
159
160
if self ._tp_rank == 0 :
160
- self ._engine .tokenizer_manager . flush_cache ()
161
+ self ._engine .flush_cache ()
161
162
162
163
def release_memory_occupation (self ):
163
164
if self ._tp_rank == 0 :
Original file line number Diff line number Diff line change @@ -897,7 +897,10 @@ def broadcast_pyobj(
897
897
src : int = 0 ,
898
898
force_cpu_device : bool = True ,
899
899
):
900
- """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
900
+ """Broadcast inputs from src rank to all other ranks with torch.dist backend.
901
+ The `rank` here refers to the source rank on the global process group (regardless
902
+ of the `dist_group` argument).
903
+ """
901
904
device = torch .device (
902
905
"cuda" if torch .cuda .is_available () and not force_cpu_device else "cpu"
903
906
)
You can’t perform that action at this time.
0 commit comments