Skip to content

Commit 90faf90

Browse files
[verl] Modify the update_weights func to align with verl's resharding (sgl-project#5345)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
1 parent 177320a commit 90faf90

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

python/sglang/srt/entrypoints/verl_engine.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# limitations under the License.
1313
# ==============================================================================
1414
import os
15-
from typing import Dict, List, Literal, Optional, Tuple, Union
15+
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
1616

1717
import torch
1818
import torch.distributed as dist
@@ -124,7 +124,7 @@ def generate(
124124

125125
def update_weights_from_tensor(
126126
self,
127-
named_tensors: List[Tuple[str, torch.Tensor]],
127+
named_tensors: Iterable[Tuple[str, torch.Tensor]],
128128
load_format: Optional[str] = None,
129129
):
130130
# Most naive implementation, can optimize a lot if it is bottleneck
@@ -153,9 +153,12 @@ def update_weights_from_tensor(
153153
)
154154
],
155155
load_format=load_format,
156-
flush_cache=tensor_index == len(named_tensors) - 1,
156+
flush_cache=False,
157157
)
158158

159+
if self._tp_rank == 0:
160+
self._engine.tokenizer_manager.flush_cache()
161+
159162
def release_memory_occupation(self):
160163
if self._tp_rank == 0:
161164
self._engine.release_memory_occupation()

0 commit comments

Comments
 (0)