From fb31562361a6de16d0681c38d2f934259b7318ef Mon Sep 17 00:00:00 2001 From: Timothy Feng Date: Tue, 19 Aug 2025 17:15:15 -0400 Subject: [PATCH] remove mrope position sync --- .../srt/model_executor/forward_batch_info.py | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index bceb0759efa..2f55b1c25b1 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -513,24 +513,23 @@ def _compute_mrope_positions( for batch_idx in range(batch_size): mm_input = batch.multimodal_inputs[batch_idx] if self.forward_mode.is_decode(): - mrope_position_deltas = ( - [0] - if mm_input is None - else flatten_nested_list(mm_input.mrope_position_delta.tolist()) - ) - next_input_positions = [] - for mrope_position_delta in mrope_position_deltas: - # batched deltas needs to be processed separately - # Convert list of lists to tensor with shape [3, seq_len] - next_input_positions += [ - MRotaryEmbedding.get_next_input_positions( - mrope_position_delta, - int(self.seq_lens[batch_idx]) - 1, - int(self.seq_lens[batch_idx]), - ) - ] # 3 * N - mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1) + if mm_input is None: + mrope_positions_list[batch_idx] = torch.full( + (3, 1), + self.seq_lens[batch_idx] - 1, + dtype=torch.int64, + device=model_runner.device, + ) + else: + mrope_position_deltas = ( + mm_input.mrope_position_delta + .flatten() + .to(model_runner.device, non_blocking=True) + ) + mrope_positions_list[batch_idx] = ( + mrope_position_deltas + self.seq_lens[batch_idx] - 1 + ).unsqueeze(0).repeat(3, 1) elif self.forward_mode.is_extend(): extend_seq_len, extend_prefix_len = ( batch.extend_seq_lens[batch_idx],