Skip to content

Commit 968ef51

Browse files
Support aiter RMSNorm in AMD (#5510)
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
1 parent 1343200 commit 968ef51

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

python/sglang/srt/layers/layernorm.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
import torch.nn as nn
2121

2222
from sglang.srt.custom_op import CustomOp
23-
from sglang.srt.utils import is_cuda
23+
from sglang.srt.utils import is_cuda, is_hip
24+
25+
logger = logging.getLogger(__name__)
2426

2527
_is_cuda = is_cuda()
28+
_is_hip = is_hip()
2629

2730
if _is_cuda:
2831
from sgl_kernel import (
@@ -32,8 +35,20 @@
3235
rmsnorm,
3336
)
3437

38+
if _is_hip:
3539

36-
logger = logging.getLogger(__name__)
40+
from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
41+
42+
rmsnorm = rms_norm
43+
44+
def fused_add_rmsnorm(
45+
x: torch.Tensor,
46+
residual: torch.Tensor,
47+
w: torch.Tensor,
48+
eps: float,
49+
) -> Tuple[torch.Tensor, torch.Tensor]:
50+
rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
51+
return x, residual
3752

3853

3954
class RMSNorm(CustomOp):
@@ -139,7 +154,7 @@ def extra_repr(self):
139154
return f"{tuple(self.weight.shape)}, eps={self.eps}"
140155

141156

142-
if not _is_cuda:
157+
if not (_is_cuda or _is_hip):
143158
logger.info(
144159
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
145160
)

0 commit comments

Comments
 (0)