Skip to content

Commit a879811

Browse files
zou3519 and zhyncs authored
Fix torch.compile cacheing (sgl-project#5259)
Co-authored-by: zhyncs <me@zhyncs.com>
1 parent a222945 commit a879811

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

python/sglang/srt/model_executor/model_runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@
6464
)
6565
from sglang.srt.model_loader.utils import set_default_torch_dtype
6666
from sglang.srt.model_loader.weight_utils import default_weight_loader
67-
from sglang.srt.patch_torch import monkey_patch_torch_reductions
67+
from sglang.srt.patch_torch import (
68+
monkey_patch_torch_compile,
69+
monkey_patch_torch_reductions,
70+
)
6871
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
6972
from sglang.srt.server_args import ServerArgs
7073
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -88,6 +91,8 @@
8891
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
8992
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
9093

94+
monkey_patch_torch_compile()
95+
9196

9297
class ModelRunner:
9398
"""ModelRunner runs the forward passes of the models."""

python/sglang/srt/patch_torch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import Callable, Union
1515

1616
import torch
17+
from packaging import version
1718
from torch.multiprocessing import reductions
1819

1920

@@ -69,3 +70,13 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
6970

7071
def _modify_tuple(t, index: int, modifier: Callable):
7172
return *t[:index], modifier(t[index]), *t[index + 1 :]
73+
74+
75+
def monkey_patch_torch_compile():
76+
if version.parse(torch.__version__) < version.parse("2.8.0"):
77+
# These things are cacheable by torch.compile. torch.compile just doesn't know it.
78+
# This was fixed in PyTorch 2.8, but until then, we monkey patch.
79+
import torch._higher_order_ops.auto_functionalize as af
80+
81+
af.auto_functionalized_v2._cacheable = True
82+
af.auto_functionalized._cacheable = True

0 commit comments

Comments (0)