Skip to content

Commit b02df20

Browse files
authored
Minor add utility to read expert distribution recorder output (sgl-project#7134)
1 parent bd7cfbd commit b02df20

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import reader
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from collections import defaultdict
2+
from pathlib import Path
3+
4+
import torch
5+
from tqdm import tqdm
6+
7+
from sglang.srt.managers.expert_distribution import (
8+
_convert_global_physical_count_to_logical_count,
9+
)
10+
11+
# Public alias of the private helper imported above: binds the same function
# object to a name without the leading underscore.
convert_global_physical_count_to_logical_count = (
12+
_convert_global_physical_count_to_logical_count
13+
)
14+
15+
16+
def read_mode_per_pass(dir_data: Path):
    """Read data from ExpertDistributionRecorder when recorded with mode `per_pass`.

    Loads every ``*.pt`` file directly inside ``dir_data``, groups each
    record's ``global_physical_count`` by ``forward_pass_id`` and ``rank``,
    sums the per-rank counts of every forward pass, and stacks those sums in
    ascending ``forward_pass_id`` order.

    Args:
        dir_data: Directory containing the ``*.pt`` dump files.

    Returns:
        A dict with three entries:

        * ``global_physical_count_of_forward_pass``: tensor whose first
          dimension follows the order of ``forward_pass_ids``; each slice is
          the sum over ranks of that pass's ``global_physical_count``.
        * ``last_physical_to_logical_map``: the value stored under that key in
          the last file visited (files are visited in sorted path order).
          NOTE(review): presumably every file carries the same map -- confirm.
        * ``forward_pass_ids``: sorted list of the forward pass ids found.

    Raises:
        RuntimeError: If ``dir_data`` has no ``*.pt`` files, or the files
            contain no records at all.
        AssertionError: If the same ``(forward_pass_id, rank)`` pair appears
            more than once.
    """
    # Sort so the traversal order (and therefore which file supplies
    # `last_physical_to_logical_map`) does not depend on the filesystem.
    paths = sorted(dir_data.glob("*.pt"))
    if not paths:
        # Fail early with a readable message instead of an opaque
        # `torch.stack` error on an empty list further down.
        raise RuntimeError(f"No *.pt files found in {dir_data}")

    # gpc := global_physical_count
    gpc_of_forward_pass_and_rank = defaultdict(dict)
    for path in tqdm(paths):
        data_pack = torch.load(path, weights_only=True)
        last_physical_to_logical_map = data_pack["last_physical_to_logical_map"]
        for record in data_pack["records"]:
            forward_pass_id = record["forward_pass_id"]
            rank = record["rank"]
            gpc_of_rank = gpc_of_forward_pass_and_rank[forward_pass_id]
            if rank in gpc_of_rank:
                # Raised explicitly (rather than via `assert`) so the check
                # still runs under `python -O`; the exception type is kept.
                raise AssertionError(f"Duplicated {forward_pass_id=} {rank=}")
            gpc_of_rank[rank] = record["global_physical_count"]

    if not gpc_of_forward_pass_and_rank:
        raise RuntimeError(f"No records found in the *.pt files of {dir_data}")

    forward_pass_ids = sorted(gpc_of_forward_pass_and_rank)
    print(f"Make {forward_pass_ids=} into array")

    items = []
    for forward_pass_id in forward_pass_ids:
        gpc_of_rank = gpc_of_forward_pass_and_rank[forward_pass_id]
        # Stack the per-rank tensors in rank order, then collapse the rank
        # dimension by summation.
        gpc_of_rank_tensor = torch.stack(
            [gpc_of_rank[rank] for rank in sorted(gpc_of_rank)]
        ).sum(dim=0)
        items.append(gpc_of_rank_tensor)

    gpc_of_forward_pass = torch.stack(items)
    print(f"{gpc_of_forward_pass.shape=}")

    return dict(
        global_physical_count_of_forward_pass=gpc_of_forward_pass,
        last_physical_to_logical_map=last_physical_to_logical_map,
        forward_pass_ids=forward_pass_ids,
    )

0 commit comments

Comments
 (0)