enable llama3.1-8B on xpu #9434

Open · wants to merge 2 commits into main

5 changes: 4 additions & 1 deletion python/sglang/srt/custom_op.py
@@ -1,12 +1,13 @@
from torch import nn

from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu, is_xpu

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_is_xpu = is_xpu()


class CustomOp(nn.Module):
@@ -88,5 +89,7 @@ def dispatch_forward(self):
return self.forward_cpu
elif _is_npu:
return self.forward_npu
elif _is_xpu:
return self.forward_xpu
else:
return self.forward_native
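
For context, the dispatch added here resolves to `forward_xpu` whenever the process sees an Intel GPU. Below is a minimal, self-contained sketch of the same pattern; the `ReluOp` class is illustrative only and probes `torch.cuda`/`torch.xpu` directly instead of using sglang's `is_cuda()`/`is_xpu()` helpers, so it is not the actual `CustomOp` implementation.

```python
import torch
from torch import nn


class ReluOp(nn.Module):
    """Toy op mirroring the CustomOp dispatch pattern above (illustration only)."""

    def __init__(self):
        super().__init__()
        # dispatch_forward analogue: resolve the backend-specific forward once.
        if torch.cuda.is_available():
            self._forward_method = self.forward_cuda
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            self._forward_method = self.forward_xpu
        else:
            self._forward_method = self.forward_native

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The tensor is expected to live on an "xpu" device; the math is unchanged.
        return torch.relu(x)
```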
31 changes: 22 additions & 9 deletions python/sglang/srt/layers/activation.py
@@ -33,6 +33,7 @@
cpu_has_amx_support,
is_cpu,
is_cuda,
is_xpu,
is_hip,
is_npu,
set_weight_attrs,
@@ -44,8 +45,9 @@
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_hip = is_hip()
_is_xpu = is_xpu()

if _is_cuda:
if _is_cuda or _is_xpu:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
elif _is_hip:
from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:

def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
if _is_cpu_amx_available:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
return out
else:
@@ -81,17 +81,20 @@ def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
out = torch_npu.npu_swiglu(x)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
silu_and_mul(x, out)
return out


class GeluAndMul(CustomOp):
def __init__(self, approximate="tanh"):
super().__init__()
self.approximate = approximate

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -103,6 +106,16 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
raise RuntimeError("GeluAndMul only support tanh or none")
return out

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self._forward_impl(x)

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self._forward_impl(x)


class NewGELU(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -230,7 +243,7 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return nn.Identity()


if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu):
logger.info(
"sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
)
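
A quick way to validate the new `SiluAndMul.forward_xpu` path is to compare it against the pure-PyTorch math it is meant to match. The snippet below is such a sanity check, not part of this diff; it assumes an XPU-enabled PyTorch build, an sgl_kernel build with XPU support installed, and the `reference_silu_and_mul` helper name is ours.

```python
import torch
import torch.nn.functional as F

from sglang.srt.layers.activation import SiluAndMul


def reference_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Same math as SiluAndMul.forward_native: split the last dim in half,
    # apply SiLU to the first half and gate it with the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


if hasattr(torch, "xpu") and torch.xpu.is_available():
    x = torch.randn(4, 2 * 128, dtype=torch.float16, device="xpu")
    torch.testing.assert_close(
        SiluAndMul()(x), reference_silu_and_mul(x), rtol=2e-2, atol=2e-2
    )
```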
4 changes: 3 additions & 1 deletion python/sglang/srt/layers/layernorm.py
@@ -28,6 +28,7 @@
is_hip,
is_npu,
supports_custom_op,
is_xpu,
)

_is_cuda = is_cuda()
@@ -36,6 +37,7 @@
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_xpu = is_xpu()

if _is_cuda:
from sgl_kernel import (
@@ -287,7 +289,7 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"


if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu):
logger.info(
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
)
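
Whichever path an XPU build ends up taking for layernorm (sgl-kernel or a fallback library), the RMSNorm math it has to reproduce is compact. The plain-PyTorch reference below is included for comparison only and is not part of this diff; the function name is ours.

```python
import torch


def rms_norm_reference(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # y = x / sqrt(mean(x**2, dim=-1) + eps) * weight, accumulated in float32.
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps) * weight.to(torch.float32)).to(orig_dtype)
```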
5 changes: 3 additions & 2 deletions python/sglang/srt/mem_cache/memory_pool_host.py
@@ -8,10 +8,11 @@
import torch

from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
from sglang.srt.utils import is_npu
from sglang.srt.utils import is_npu, is_xpu

_is_npu = is_npu()
if not _is_npu:
_is_xpu = is_xpu()
if not (_is_npu or _is_xpu):
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_lf_pf,