@@ -86,6 +86,7 @@ def __init__(
86
86
super (BaseRLModel , self ).__init__ ()
87
87
self .infer_to_train_mapping = {}
88
88
self .fd_config = None
89
+ self ._mappings_built = False
89
90
90
91
@classmethod
91
92
def name (cls ) -> str :
@@ -142,6 +143,12 @@ def name(self) -> str:
142
143
143
144
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
144
145
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
146
+ if self ._mappings_built :
147
+ return self .infer_to_train_mapping
148
+
149
+ self .infer_to_train_mapping = {}
150
+ self ._mappings_built = True
151
+
145
152
# Prepare placeholders
146
153
place_holders = ["weight" ]
147
154
@@ -215,6 +222,11 @@ def name(self) -> str:
215
222
216
223
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
217
224
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
225
+ if self ._mappings_built :
226
+ return self .infer_to_train_mapping
227
+
228
+ self .infer_to_train_mapping = {}
229
+ self ._mappings_built = True
218
230
# Prepare placeholders
219
231
place_holders = ["weight" ]
220
232
@@ -316,6 +328,11 @@ def name(self) -> str:
316
328
317
329
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
318
330
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
331
+ if self ._mappings_built :
332
+ return self .infer_to_train_mapping
333
+
334
+ self .infer_to_train_mapping = {}
335
+ self ._mappings_built = True
319
336
# Prepare placeholders
320
337
place_holders = ["weight" ]
321
338
@@ -360,6 +377,11 @@ def name(self) -> str:
360
377
361
378
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
362
379
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
380
+ if self ._mappings_built :
381
+ return self .infer_to_train_mapping
382
+
383
+ self .infer_to_train_mapping = {}
384
+ self ._mappings_built = True
363
385
# Prepare placeholders
364
386
place_holders = ["weight" ]
365
387
@@ -429,6 +451,11 @@ def name(self) -> str:
429
451
return "Qwen3ForCausalLMRL"
430
452
431
453
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
454
+ if self ._mappings_built :
455
+ return self .infer_to_train_mapping
456
+
457
+ self .infer_to_train_mapping = {}
458
+ self ._mappings_built = True
432
459
# Prepare placeholders
433
460
place_holders = ["weight" ]
434
461
0 commit comments