Skip to content

Commit 55349e3

Browse files
support mooncake store dp attention (sgl-project#9684)
1 parent e1f7cf5 commit 55349e3

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

python/sglang/srt/managers/cache_controller.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ def _mooncake_page_get(self, operation, hash_values, host_indices):
636636
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
637637
hash_values,
638638
host_indices,
639+
self.storage_config.tp_rank,
639640
)
640641
get_result = self.storage_backend.batch_get(
641642
key_strs,
@@ -838,6 +839,7 @@ def _mooncake_page_set(self, hash_values, host_indices) -> bool:
838839
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
839840
hash_values,
840841
host_indices,
842+
self.storage_config.tp_rank,
841843
)
842844
success = self.storage_backend.batch_set(
843845
key_strs,

python/sglang/srt/mem_cache/memory_pool_host.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import psutil
88
import torch
99

10-
from sglang.srt.distributed import get_tensor_model_parallel_rank
1110
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
1211
from sglang.srt.utils import is_npu
1312

@@ -464,8 +463,7 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
464463
else:
465464
raise ValueError(f"Unsupported layout: {self.layout}")
466465

467-
def get_buffer_meta(self, keys, indices):
468-
local_rank = get_tensor_model_parallel_rank()
466+
def get_buffer_meta(self, keys, indices, local_rank):
469467
ptr_list = []
470468
key_list = []
471469
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
@@ -704,7 +702,7 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
704702
else:
705703
raise ValueError(f"Unsupported layout: {self.layout}")
706704

707-
def get_buffer_meta(self, keys, indices):
705+
def get_buffer_meta(self, keys, indices, local_rank):
708706
ptr_list = []
709707
key_list = []
710708
kv_buffer_data_ptr = self.kv_buffer.data_ptr()

0 commit comments

Comments
 (0)