@@ -192,13 +192,17 @@ def init_health_status(self) -> None:
         )
         self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1

+        if self.parallel_config.local_data_parallel_id == 0:
+            current_suffix = self.parallel_config.engine_pid
+        else:
+            current_suffix = self.parallel_config.engine_worker_queue_port
         # init worker_healthy_live_signal
         workers_alive = np.zeros(shape=[min(array_size, self.parallel_config.tensor_parallel_size)], dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(
             name="worker_healthy_live_signal",
             array=workers_alive,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_worker_queue_port,
+            suffix=current_suffix,
             create=False,
         )
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
@@ -210,7 +214,7 @@ def init_health_status(self) -> None:
             name="model_weights_status",
             array=workers_model_weights,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_worker_queue_port,
+            suffix=current_suffix,
             create=False,
         )

@@ -220,7 +224,7 @@ def init_health_status(self) -> None:
             name="exist_task_signal",
             array=workers_exist_task,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_worker_queue_port,
+            suffix=current_suffix,
             create=False,
         )

@@ -230,7 +234,7 @@ def init_health_status(self) -> None:
             name="exist_swapped_task_signal",
             array=workers_swapped_task,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_worker_queue_port,
+            suffix=current_suffix,
             create=False,
         )

@@ -240,7 +244,7 @@ def init_health_status(self) -> None:
             name="exist_prefill_task_signal",
             array=exist_prefill_task_signal_data,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_worker_queue_port,
+            suffix=current_suffix,
             create=False,
         )

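The hunks above replace the fixed engine_worker_queue_port suffix of every IPC signal in init_health_status with a per-group current_suffix. A minimal sketch of that selection, assuming a ParallelConfig-like object with the same field names as in the diff (the helper function itself is hypothetical):

# Sketch only: mirrors the branch added in the diff. Data-parallel group 0
# keys its signals by the engine pid; other groups use their own
# engine_worker_queue_port.
def pick_signal_suffix(parallel_config):
    if parallel_config.local_data_parallel_id == 0:
        return parallel_config.engine_pid
    return parallel_config.engine_worker_queue_port
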
@@ -643,12 +647,14 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:

        num_experts_per_rank = num_experts // args.expert_parallel_size
        num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
+        max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
+        parallel_config.local_data_parallel_id = expert_parallel_rank % max_chips_per_node

        parallel_config.expert_parallel_rank = expert_parallel_rank
        parallel_config.num_experts_per_rank = num_experts_per_rank
        parallel_config.num_experts_start_offset = num_experts_start_offset
        parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
-            parallel_config.expert_parallel_rank
+            parallel_config.local_data_parallel_id
        ]

    load_config = LoadConfig(vars(args))
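For reference, the new lines in initialize_fd_config fold the expert-parallel rank onto a node-local data-parallel id, which then indexes the per-group engine_worker_queue_port list. A small, illustrative reproduction of the arithmetic (the 16/8 chip counts come from the diff; the standalone function is an assumption for demonstration):

# Illustrative only: same arithmetic as the diff, outside the FastDeploy config.
def local_dp_id(expert_parallel_rank: int, is_iluvatar: bool) -> int:
    max_chips_per_node = 16 if is_iluvatar else 8  # Iluvatar nodes: 16 chips; others: 8
    return expert_parallel_rank % max_chips_per_node

# e.g. on a non-Iluvatar node, expert-parallel rank 10 maps to local id 2.
assert local_dp_id(10, is_iluvatar=False) == 2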