Skip to content

Commit 1fe197b

Browse files
committed
fix mapping
1 parent a84a98b commit 1fe197b

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

fastdeploy/rl/rollout_model.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(
8686
super(BaseRLModel, self).__init__()
8787
self.infer_to_train_mapping = {}
8888
self.fd_config = None
89+
self._mappings_built = False
8990

9091
@classmethod
9192
def name(cls) -> str:
@@ -142,6 +143,12 @@ def name(self) -> str:
142143

143144
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
144145
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
146+
if self._mappings_built:
147+
return self.infer_to_train_mapping
148+
149+
self.infer_to_train_mapping = {}
150+
self._mappings_built = True
151+
145152
# Prepare placeholders
146153
place_holders = ["weight"]
147154

@@ -215,6 +222,11 @@ def name(self) -> str:
215222

216223
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
217224
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
225+
if self._mappings_built:
226+
return self.infer_to_train_mapping
227+
228+
self.infer_to_train_mapping = {}
229+
self._mappings_built = True
218230
# Prepare placeholders
219231
place_holders = ["weight"]
220232

@@ -316,6 +328,11 @@ def name(self) -> str:
316328

317329
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
318330
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
331+
if self._mappings_built:
332+
return self.infer_to_train_mapping
333+
334+
self.infer_to_train_mapping = {}
335+
self._mappings_built = True
319336
# Prepare placeholders
320337
place_holders = ["weight"]
321338

@@ -360,6 +377,11 @@ def name(self) -> str:
360377

361378
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
362379
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
380+
if self._mappings_built:
381+
return self.infer_to_train_mapping
382+
383+
self.infer_to_train_mapping = {}
384+
self._mappings_built = True
363385
# Prepare placeholders
364386
place_holders = ["weight"]
365387

@@ -429,6 +451,11 @@ def name(self) -> str:
429451
return "Qwen3ForCausalLMRL"
430452

431453
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
454+
if self._mappings_built:
455+
return self.infer_to_train_mapping
456+
457+
self.infer_to_train_mapping = {}
458+
self._mappings_built = True
432459
# Prepare placeholders
433460
place_holders = ["weight"]
434461

0 commit comments

Comments
 (0)