diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index d49392f4c3d..6c2740d7c57 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -12,7 +12,7 @@ # limitations under the License. # ============================================================================== import os -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union import torch import torch.distributed as dist @@ -124,7 +124,7 @@ def generate( def update_weights_from_tensor( self, - named_tensors: List[Tuple[str, torch.Tensor]], + named_tensors: Iterable[Tuple[str, torch.Tensor]], load_format: Optional[str] = None, ): # Most naive implementation, can optimize a lot if it is bottleneck @@ -153,9 +153,12 @@ def update_weights_from_tensor( ) ], load_format=load_format, - flush_cache=tensor_index == len(named_tensors) - 1, + flush_cache=False, ) + if self._tp_rank == 0: + self._engine.tokenizer_manager.flush_cache() + def release_memory_occupation(self): if self._tp_rank == 0: self._engine.release_memory_occupation()