diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 0d89ebc8818..c961dd554af 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -49,13 +49,15 @@
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
-    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+    from sgl_kernel import gelu_and_mul, silu_and_mul
 
     if _use_aiter:
         try:
             from aiter import moe_sum
         except ImportError:
             raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+    else:
+        from vllm import _custom_ops as vllm_ops
 
 
 if _is_cuda or _is_hip:
@@ -1537,7 +1539,7 @@ def fused_experts_impl(
                     gemm1_alpha,
                     gemm1_limit,
                 )
-            elif _is_cuda:
+            elif _is_cuda or _is_hip:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.silu_and_mul(
@@ -1546,7 +1548,7 @@ def fused_experts_impl(
         elif activation == "gelu":
             assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
             assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
-            if _is_cuda:
+            if _is_cuda or _is_hip:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.gelu_and_mul(
@@ -1619,10 +1621,19 @@ def fused_experts_impl(
                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 )
             else:
-                vllm_ops.moe_sum(
-                    intermediate_cache3.view(*intermediate_cache3.shape),
-                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                )
+                # According to micro benchmark results, torch.compile can get better performance for small token.
+                if tokens_in_chunk <= 32:
+                    moe_sum_reduce_torch_compile(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
+                else:
+                    moe_sum_reduce_triton(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 8f8de70280a..53a80911bc3 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -27,8 +27,6 @@
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 
-if _use_aiter:
-    from aiter.rotary_embedding import get_rope as aiter_get_rope
 
 if is_npu():
     import torch_npu
@@ -1649,222 +1647,257 @@ def extra_repr(self) -> str:
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
+if _use_aiter:
+    # The following changes are for qwen2.py and qwen3_moe.py
+    import functools
+    import inspect
 
-def get_rope(
-    head_size: int,
-    rotary_dim: int,
-    max_position: int,
-    base: int,
-    is_neox_style: bool = True,
-    rope_scaling: Optional[Dict[str, Any]] = None,
-    dtype: Optional[torch.dtype] = None,
-    partial_rotary_factor: float = 1.0,
-    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
-) -> RotaryEmbedding:
-    if dtype is None:
-        dtype = torch.get_default_dtype()
-    if rope_scaling is not None:
-        # Transforms every value that is a list into a tuple for caching calls
-        rope_scaling_tuple = {
-            k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
-        }
-        rope_scaling_args = tuple(rope_scaling_tuple.items())
-    else:
-        rope_scaling_args = None
+    import aiter.rotary_embedding as _aiter_re
+    from aiter.rotary_embedding import get_rope
 
-    if dual_chunk_attention_config is not None:
-        dual_chunk_attention_tuple = {
-            k: tuple(v) if isinstance(v, list) else v
-            for k, v in dual_chunk_attention_config.items()
-            if k != "sparse_attention_config"
-        }
-        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
-    else:
-        dual_chunk_attention_args = None
+    _orig_get_rope = _aiter_re.get_rope
 
-    if partial_rotary_factor < 1.0:
-        rotary_dim = int(rotary_dim * partial_rotary_factor)
-    key = (
-        head_size,
-        rotary_dim,
-        max_position,
-        base,
-        is_neox_style,
-        rope_scaling_args,
-        dual_chunk_attention_args,
-        dtype,
-    )
-    if key in _ROPE_DICT:
-        return _ROPE_DICT[key]
+    # Detect whether the upstream function already knows that keyword
+    if (
+        "dual_chunk_attention_config"
+        not in inspect.signature(_orig_get_rope).parameters
+    ):
 
-    if dual_chunk_attention_config is not None:
-        extra_kwargs = {
-            k: v
-            for k, v in dual_chunk_attention_config.items()
-            if k in ("chunk_size", "local_size")
-        }
-        rotary_emb = DualChunkRotaryEmbedding(
+        def _patched_get_rope(*args, dual_chunk_attention_config=None, **kwargs):
+            """
+            Wrapper around aiter.rotary_embedding.get_rope that ignores
+            `dual_chunk_attention_config` when the upstream implementation
+            does not accept it.
+            """
+            # Just throw the value away; you could also log or assert.
+            return _orig_get_rope(*args, **kwargs)
+
+        # Ensure the wrapper looks like the original
+        _patched_get_rope = functools.update_wrapper(_patched_get_rope, _orig_get_rope)
+
+        # Monkey-patch
+        _aiter_re.get_rope = _patched_get_rope
+
+    # Finally, import the (possibly patched) symbol for local use
+    get_rope = _aiter_re.get_rope
+else:
+
+    def get_rope(
+        head_size: int,
+        rotary_dim: int,
+        max_position: int,
+        base: int,
+        is_neox_style: bool = True,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        dtype: Optional[torch.dtype] = None,
+        partial_rotary_factor: float = 1.0,
+        dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
+    ) -> RotaryEmbedding:
+        if dtype is None:
+            dtype = torch.get_default_dtype()
+        if rope_scaling is not None:
+            # Transforms every value that is a list into a tuple for caching calls
+            rope_scaling_tuple = {
+                k: tuple(v) if isinstance(v, list) else v
+                for k, v in rope_scaling.items()
+            }
+            rope_scaling_args = tuple(rope_scaling_tuple.items())
+        else:
+            rope_scaling_args = None
+
+        if dual_chunk_attention_config is not None:
+            dual_chunk_attention_tuple = {
+                k: tuple(v) if isinstance(v, list) else v
+                for k, v in dual_chunk_attention_config.items()
+                if k != "sparse_attention_config"
+            }
+            dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
+        else:
+            dual_chunk_attention_args = None
+
+        if partial_rotary_factor < 1.0:
+            rotary_dim = int(rotary_dim * partial_rotary_factor)
+        key = (
             head_size,
             rotary_dim,
             max_position,
             base,
             is_neox_style,
+            rope_scaling_args,
+            dual_chunk_attention_args,
             dtype,
-            **extra_kwargs,
         )
-    elif rope_scaling is None:
-        rotary_emb = RotaryEmbedding(
-            head_size, rotary_dim, max_position, base, is_neox_style, dtype
-        )
-    else:
-        if "rope_type" in rope_scaling:
-            scaling_type = rope_scaling["rope_type"]
-        elif "type" in rope_scaling:
-            scaling_type = rope_scaling["type"]
-        else:
-            raise ValueError("Unknown RoPE scaling type")
-
-        if scaling_type == "llama3":
-            scaling_factor = rope_scaling["factor"]
-            low_freq_factor = rope_scaling["low_freq_factor"]
-            high_freq_factor = rope_scaling["high_freq_factor"]
-            original_max_position = rope_scaling["original_max_position_embeddings"]
-            rotary_emb = Llama3RotaryEmbedding(
+        if key in _ROPE_DICT:
+            return _ROPE_DICT[key]
+
+        if dual_chunk_attention_config is not None:
+            extra_kwargs = {
+                k: v
+                for k, v in dual_chunk_attention_config.items()
+                if k in ("chunk_size", "local_size")
+            }
+            rotary_emb = DualChunkRotaryEmbedding(
                 head_size,
                 rotary_dim,
                 max_position,
                 base,
                 is_neox_style,
                 dtype,
-                scaling_factor,
-                low_freq_factor,
-                high_freq_factor,
-                original_max_position,
+                **extra_kwargs,
             )
-        elif scaling_type == "default":
-            if "mrope_section" in rope_scaling:
-                rotary_emb = MRotaryEmbedding(
+        elif rope_scaling is None:
+            rotary_emb = RotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style, dtype
+            )
+        else:
+            if "rope_type" in rope_scaling:
+                scaling_type = rope_scaling["rope_type"]
+            elif "type" in rope_scaling:
+                scaling_type = rope_scaling["type"]
+            else:
+                raise ValueError("Unknown RoPE scaling type")
+
+            if scaling_type == "llama3":
+                scaling_factor = rope_scaling["factor"]
+                low_freq_factor = rope_scaling["low_freq_factor"]
+                high_freq_factor = rope_scaling["high_freq_factor"]
+                original_max_position = rope_scaling["original_max_position_embeddings"]
+                rotary_emb = Llama3RotaryEmbedding(
                     head_size,
                     rotary_dim,
                     max_position,
                     base,
                     is_neox_style,
                     dtype,
-                    mrope_section=rope_scaling["mrope_section"],
+                    scaling_factor,
+                    low_freq_factor,
+                    high_freq_factor,
+                    original_max_position,
                 )
-            else:
-                rotary_emb = RotaryEmbedding(
+            elif scaling_type == "default":
+                if "mrope_section" in rope_scaling:
+                    rotary_emb = MRotaryEmbedding(
+                        head_size,
+                        rotary_dim,
+                        max_position,
+                        base,
+                        is_neox_style,
+                        dtype,
+                        mrope_section=rope_scaling["mrope_section"],
+                    )
+                else:
+                    rotary_emb = RotaryEmbedding(
+                        head_size,
+                        rotary_dim,
+                        max_position,
+                        base,
+                        is_neox_style,
+                        dtype,
+                    )
+            elif scaling_type == "linear":
+                scaling_factor = rope_scaling["factor"]
+                rotary_emb = LinearScalingRotaryEmbedding(
                     head_size,
                     rotary_dim,
                     max_position,
                     base,
                     is_neox_style,
+                    scaling_factor,
                     dtype,
                 )
-        elif scaling_type == "linear":
-            scaling_factor = rope_scaling["factor"]
-            rotary_emb = LinearScalingRotaryEmbedding(
-                head_size,
-                rotary_dim,
-                max_position,
-                base,
-                is_neox_style,
-                scaling_factor,
-                dtype,
-            )
-        elif scaling_type == "dynamic":
-            scaling_factor = rope_scaling["factor"]
-            if "alpha" in rope_scaling:
-                rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+            elif scaling_type == "dynamic":
+                scaling_factor = rope_scaling["factor"]
+                if "alpha" in rope_scaling:
+                    rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+                        head_size,
+                        rotary_dim,
+                        max_position,
+                        base,
+                        is_neox_style,
+                        rope_scaling["alpha"],
+                        dtype,
+                    )
+                else:
+                    rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                        head_size,
+                        rotary_dim,
+                        max_position,
+                        base,
+                        is_neox_style,
+                        scaling_factor,
+                        dtype,
+                    )
+            elif scaling_type == "yarn":
+                scaling_factor = rope_scaling["factor"]
+                original_max_position = rope_scaling["original_max_position_embeddings"]
+                extra_kwargs = {
+                    k: v
+                    for k, v in rope_scaling.items()
+                    if k
+                    in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
+                }
+                rotary_emb = YaRNScalingRotaryEmbedding(
                     head_size,
                     rotary_dim,
-                    max_position,
+                    original_max_position,
                     base,
                     is_neox_style,
-                    rope_scaling["alpha"],
+                    scaling_factor,
                     dtype,
+                    **extra_kwargs,
                 )
-            else:
-                rotary_emb = DynamicNTKScalingRotaryEmbedding(
+            elif scaling_type == "deepseek_yarn":
+                scaling_factor = rope_scaling["factor"]
+                original_max_position = rope_scaling["original_max_position_embeddings"]
+                # assert max_position == original_max_position * scaling_factor
+                extra_kwargs = {
+                    k: v
+                    for k, v in rope_scaling.items()
+                    if k
+                    in (
+                        "extrapolation_factor",
+                        "attn_factor",
+                        "beta_fast",
+                        "beta_slow",
+                        "mscale",
+                        "mscale_all_dim",
+                    )
+                }
+                rotary_emb = DeepseekScalingRotaryEmbedding(
                     head_size,
                     rotary_dim,
-                    max_position,
+                    original_max_position,
                     base,
                     is_neox_style,
                     scaling_factor,
                     dtype,
+                    **extra_kwargs,
                 )
-        elif scaling_type == "yarn":
-            scaling_factor = rope_scaling["factor"]
-            original_max_position = rope_scaling["original_max_position_embeddings"]
-            extra_kwargs = {
-                k: v
-                for k, v in rope_scaling.items()
-                if k
-                in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
-            }
-            rotary_emb = YaRNScalingRotaryEmbedding(
-                head_size,
-                rotary_dim,
-                original_max_position,
-                base,
-                is_neox_style,
-                scaling_factor,
-                dtype,
-                **extra_kwargs,
-            )
-        elif scaling_type == "deepseek_yarn":
-            scaling_factor = rope_scaling["factor"]
-            original_max_position = rope_scaling["original_max_position_embeddings"]
-            # assert max_position == original_max_position * scaling_factor
-            extra_kwargs = {
-                k: v
-                for k, v in rope_scaling.items()
-                if k
-                in (
-                    "extrapolation_factor",
-                    "attn_factor",
-                    "beta_fast",
-                    "beta_slow",
-                    "mscale",
-                    "mscale_all_dim",
+            elif scaling_type == "longrope":
+                short_factor = rope_scaling["short_factor"]
+                long_factor = rope_scaling["long_factor"]
+                original_max_position = rope_scaling["original_max_position_embeddings"]
+                extra_kwargs = {
+                    k: v
+                    for k, v in rope_scaling.items()
+                    if k in ("short_mscale", "long_mscale")
+                }
+                rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    original_max_position,
+                    base,
+                    is_neox_style,
+                    dtype,
+                    short_factor,
+                    long_factor,
+                    **extra_kwargs,
                 )
-            }
-            rotary_emb = DeepseekScalingRotaryEmbedding(
-                head_size,
-                rotary_dim,
-                original_max_position,
-                base,
-                is_neox_style,
-                scaling_factor,
-                dtype,
-                **extra_kwargs,
-            )
-        elif scaling_type == "longrope":
-            short_factor = rope_scaling["short_factor"]
-            long_factor = rope_scaling["long_factor"]
-            original_max_position = rope_scaling["original_max_position_embeddings"]
-            extra_kwargs = {
-                k: v
-                for k, v in rope_scaling.items()
-                if k in ("short_mscale", "long_mscale")
-            }
-            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
-                head_size,
-                rotary_dim,
-                max_position,
-                original_max_position,
-                base,
-                is_neox_style,
-                dtype,
-                short_factor,
-                long_factor,
-                **extra_kwargs,
-            )
-        else:
-            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-    _ROPE_DICT[key] = rotary_emb
-    return rotary_emb
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        _ROPE_DICT[key] = rotary_emb
+        return rotary_emb
 
 
 # Copied from transformers
@@ -1982,6 +2015,8 @@ def get_rope_wrapper(
     device: Optional[str] = None,
 ):
     if device != "cpu":
+        if _use_aiter:
+            from aiter.rotary_embedding import get_rope as aiter_get_rope
         wrapper = aiter_get_rope if _use_aiter else get_rope
         return wrapper(
             head_size,