diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py
index 8c662b5ccb5..9ffd587936a 100644
--- a/python/sglang/srt/custom_op.py
+++ b/python/sglang/srt/custom_op.py
@@ -1,12 +1,13 @@
 from torch import nn
 
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu, is_xpu
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
+_is_xpu = is_xpu()
 
 
 class CustomOp(nn.Module):
@@ -88,5 +89,7 @@ def dispatch_forward(self):
             return self.forward_cpu
         elif _is_npu:
             return self.forward_npu
+        elif _is_xpu:
+            return self.forward_xpu
         else:
             return self.forward_native
diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 15c2ba07727..f4dc48a51dd 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -33,6 +33,7 @@
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
+    is_xpu,
     is_hip,
     is_npu,
     set_weight_attrs,
@@ -44,8 +45,9 @@
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_hip = is_hip()
+_is_xpu = is_xpu()
 
-if _is_cuda:
+if _is_cuda or _is_xpu:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 elif _is_hip:
     from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
 
     def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
         if _is_cpu_amx_available:
-            d = x.shape[-1] // 2
-            output_shape = x.shape[:-1] + (d,)
             out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
             return out
         else:
@@ -81,17 +81,20 @@ def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
         out = torch_npu.npu_swiglu(x)
         return out
 
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
         self.approximate = approximate
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
-
-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -103,6 +106,16 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out
 
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
 
 class NewGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -230,7 +243,7 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         return nn.Identity()
 
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu):
     logger.info(
         "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index a77747351b8..36490d9bae9 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -28,6 +28,7 @@
     is_hip,
     is_npu,
     supports_custom_op,
+    is_xpu,
 )
 
 _is_cuda = is_cuda()
@@ -36,6 +37,7 @@
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_xpu = is_xpu()
 
 if _is_cuda:
     from sgl_kernel import (
@@ -287,7 +289,7 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py
index cfc7f36c52a..6b4d1d6cf2b 100644
--- a/python/sglang/srt/mem_cache/memory_pool_host.py
+++ b/python/sglang/srt/mem_cache/memory_pool_host.py
@@ -8,10 +8,11 @@
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu
 
 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
         transfer_kv_all_layer_lf_pf,
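For context (not part of the patch): a minimal, self-contained sketch of the `CustomOp` dispatch pattern these changes extend. `CustomOp` binds one backend-specific `forward_*` method at construction time, and the new branch routes to `forward_xpu` when an XPU is detected. The `is_xpu()` stub and the toy `SiluAndMul` below are illustrative assumptions; only `torch` is needed to run it.

```python
# Minimal sketch of the CustomOp dispatch pattern extended by this diff.
# Platform detection is stubbed here; in sglang it comes from sglang.srt.utils.

import torch
import torch.nn.functional as F
from torch import nn


def is_xpu() -> bool:
    # Illustrative stand-in for sglang.srt.utils.is_xpu().
    return hasattr(torch, "xpu") and torch.xpu.is_available()


_is_xpu = is_xpu()


class CustomOp(nn.Module):
    """Picks a backend-specific forward_* implementation once, at init time."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_xpu(self, *args, **kwargs):
        # Ops without a dedicated XPU kernel fall back to the native path.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        # Mirrors the new branch in custom_op.py: XPU gets its own method,
        # everything else uses the pure-PyTorch implementation here.
        if _is_xpu:
            return self.forward_xpu
        return self.forward_native


class SiluAndMul(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]


if __name__ == "__main__":
    print(SiluAndMul()(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```

Ops that gain a dedicated XPU kernel (as `SiluAndMul` and `GeluAndMul` do in this diff) override `forward_xpu` to call into `sgl_kernel`; everything else keeps the native fallback.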