[new feat] ascend backend support fia fusion kernel #8328
Conversation
Summary of Changes
Hello @ZhengdQin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly expands the SGLang framework's hardware compatibility by introducing comprehensive support for DeepSeek-v3 models on Ascend NPU devices. It achieves this through the integration of a new NPU-optimized attention backend, custom fusion kernels for key computational and communication patterns, and specialized memory management and quantization techniques tailored for NPU architecture. The changes aim to unlock high-performance inference for DeepSeek-v3 on NPU, particularly with dynamic w8a8 precision.
Highlights
- NPU DeepSeek-v3 Support: Implemented a new `npumla` attention backend specifically for DeepSeek-v3 models on Ascend NPU devices, enabling dynamic w8a8 precision and leveraging 16 NPUs.
- Custom NPU Fusion Kernels: Introduced a suite of NPU-optimized fusion kernels for attention, quantization, and Mixture-of-Experts (MoE) operations, including `_npu_fused_infer_attention_score`, `_npu_dequant_swiglu_quant`, `npu_grouped_matmul`, `npu_moe_distribute_combine`, `npu_moe_distribute_dispatch`, `npu_moe_re_routing`, `npu_moe_init_routing_v2`, `npu_moe_finalize_routing`, `npu_add_rms_norm`, `npu_dynamic_quant`, `npu_quant_matmul`, and `npu_moe_gating_top_k`.
- Enhanced MoE Dispatching: Developed `NpuDeepEPMoE` and `NpuDeepEPDispatcher` to optimize MoE operations on NPU, utilizing specialized NPU kernels for the routing, dispatch, and combine phases, including support for dynamic quantization scales.
- Paged Attention and Memory Management: Integrated paged attention support for NPU with a fixed page size of 128, and introduced `MLATokenToKVPool` for efficient KV cache buffer management on NPU, including NPU-specific allocation kernels.
- W8A8 Int8 Quantization: Extended W8A8 Int8 quantization support to NPU, including dynamic quantization and NPU-specific handling for RMSNorm and quantized matrix multiplications.
- Comprehensive Testing: Added new unit tests (`test_npu_mla_backend.py`) to validate the functionality and performance of the NPU MLA backend, specifically for DeepSeek-v3 with w8a8 quantization on NPU.
Code Review
The code changes introduce the ability to use Ascend NPU for the deepseek-v3 model. I've identified a few issues, including one critical bug that would prevent the code from running, some high-severity issues related to code correctness, and several medium-severity issues concerning code clarity and maintainability. Addressing these will improve the quality and robustness of the new NPU support.
self.attn_backends.append(
    NpuMLABackend(
        model_runner,
        skip_prefill=True,
        kv_indptr_buf=self.kv_indptr[i],
        kv_last_page_len_buf=None,
    )
)
The `NpuMLABackend` constructor is being called with `kv_indptr_buf` and `kv_last_page_len_buf` arguments, but its `__init__` method does not accept them. This will raise a `TypeError` at runtime.
The `__init__` signature for `NpuMLABackend` is `__init__(self, model_runner: ModelRunner, skip_prefill: bool = False)`.
Please remove the extra arguments from the call to `NpuMLABackend`:
self.attn_backends.append(
NpuMLABackend(
model_runner,
skip_prefill=True,
)
)
have fixed
python/sglang/srt/utils.py (Outdated)
def support_triton(attn_backend: str) -> bool:
    return attn_backend not in ["torch_native", "intel_amx", "npumla"]
The function `support_triton` is redefined here; it was previously defined at line 199. This can lead to unexpected behavior depending on import order. The old implementation should be removed, and this new one should be placed where the old one was to maintain code organization.
The old implementation is:
return backend not in ["torch_native", "intel_amx", "ascend"]
The new implementation correctly includes "npumla". Please consolidate the definitions.
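A minimal sketch of the consolidated helper, assuming the intended exclusion list is the union of both versions (the same list appears in the later revision further down this review):

```python
def support_triton(backend: str) -> bool:
    """Return True if the given attention backend can use Triton kernels."""
    # Backends that bypass Triton entirely: torch-native/CPU paths and the
    # Ascend NPU backends ("ascend", "npumla").
    return backend not in ["torch_native", "intel_amx", "ascend", "npumla"]
```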
have fixed
if False:
    block_kv_indices = self.indices_updater_decode.update(
        forward_batch.req_pool_indices,
        forward_batch.seq_lens,
        forward_batch.seq_lens_sum,
        init_metadata_replay=False,
    )
This `if False:` block contains dead code. It seems to be a placeholder or a disabled feature. To improve code clarity and maintainability, please remove this unreachable block.
else:
max_seqlen_pad = (
forward_batch.seq_lens.max().item() + PAGE_SIZE - 1
) // PAGE_SIZE
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=forward_batch.seq_lens.device,
)
create_flashmla_kv_indices(
bs,
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
have fixed
if False:
    attn_weights = (
        torch.matmul(q, k.transpose(1, 2)) * layer.scaling
    )  # (bs, n, n)
    # assert attention_mask is not None
    # if attention_mask is not None:
    #     attn_weights += attention_mask

    attn_weights = torch.nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32
    ).to(q.dtype)

    # v = v[..., :self.kv_lora_rank]
    attn_ouput = torch.matmul(attn_weights, v)  # (bs, n, v_dim)
    # attn_ouput = attn_ouput.transpose(1,2).contiguous()
This `if False:` block appears to contain a reference implementation and is currently dead code. Please remove it to keep the codebase clean and reduce confusion.
else:
bs = forward_batch.batch_size
if use_gqa:
attn_ouput = torch.empty(
bs_qlen, q_heads, v_dim, device=q.device, dtype=q.dtype
)
q_len_offset = 0
for q_len in forward_batch.seq_len:
attn_ouput[q_len_offset : q_len_offset + q_len] = (
torch.ops.npu.npu_fused_infer_attention_score(
q[None, q_len_offset : q_len_offset + q_len],
k[None, q_len_offset : q_len_offset + q_len],
v[None, q_len_offset : q_len_offset + q_len],
num_heads=q_heads,
num_key_value_heads=k_heads,
input_layout="BSND", # todo, TND not supports q_heads!=k_heads
atten_mask=self.attn_mask.unsqueeze(0),
sparse_mode=3,
scale=layer.scaling,
next_tokens=0,
)[0]
)
q_len_offset += q_len
else: # MHA
if q_dim != v_dim:
q_nope, q_rope = q.split(
[self.v_head_dim, self.qk_rope_head_dim], dim=-1
)
k_nope, k_rope = k.split(
[self.v_head_dim, self.qk_rope_head_dim], dim=-1
)
attn_ouput, _ = torch.ops.npu.npu_fused_infer_attention_score(
q_nope,
k_nope,
v,
query_rope=q_rope,
key_rope=k_rope,
num_heads=q_heads,
input_layout="TND",
atten_mask=self.attn_mask,
sparse_mode=3,
actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
scale=layer.scaling,
next_tokens=0,
)
else:
attn_ouput, _ = torch.ops.npu.npu_fused_infer_attention_score(
q,
k,
v,
num_heads=q_heads,
input_layout="TND",
atten_mask=self.attn_mask,
sparse_mode=3,
actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
scale=layer.scaling,
next_tokens=0,
)
attn_ouput = attn_ouput[..., : layer.v_head_dim]
have fixed
# k_nope, k_rope = k_cache.split(
#     [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
# )  # todo, there is a bug
have fixed
if False:
    req_to_token = forward_batch.req_to_token_pool.req_to_token
    attn_output = q_nope.new_zeros((b, s, n, k_dim))
    for seq_idx in range(b):
        seq_len_kv = forward_batch.seq_len[seq_idx]
        req_pool_idx = forward_batch.req_pool_indices[seq_idx]
        per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
        k = k_cache[per_req_tokens]
        v = v_cache[per_req_tokens]
        attn_ouput_idx, _ = torch.ops.npu.npu_fused_infer_attention_score(
            q_nope[seq_idx].unsqueeze(0),
            k.unsqueeze(0),
            v.unsqueeze(0),
            num_heads=n,
            num_key_value_heads=k_heads,
            input_layout="BSND",
            atten_mask=None,
            sparse_mode=0,
            scale=layer.scaling,
        )
        attn_ouput[seq_idx] = attn_ouput_idx.squeeze(0)
This `if False:` block appears to contain a reference implementation and is currently dead code. Please remove it to improve readability and maintainability.
else:
seq_len_kv = forward_batch.seq_lens
attn_ouput, _ = torch.ops.npu.npu_fused_infer_attention_score(
q_nope,
k_cache.view(-1, PAGE_SIZE, k_heads * k_dim),
v_cache.view(-1, PAGE_SIZE, k_heads * k_dim),
num_heads=n,
num_key_value_heads=k_heads,
input_layout="BSND",
atten_mask=None,
block_size=PAGE_SIZE,
block_table=self.forward_metadata.block_kv_indices,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_list,
scale=layer.scaling,
)
have fixed
@@ -53,6 +56,7 @@
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_npu = is_npu()
have fixed
@@ -43,6 +43,7 @@
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()
if _is_cuda:
    from sgl_kernel import int8_scaled_mm
_is_npu = is_npu()
have fixed
if (
    1
):  # get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"): TODO: Not support fusion ops
    return self.forward_native(positions, query, key, offsets)
Using `if (1):` to temporarily disable a code path is unconventional and can be confusing. For better readability and to adhere to common Python style, please use `if True:` instead. If this is a temporary change for debugging, it should be reverted to the original logic or removed before merging.
Suggested change:
if (
    True
):  # get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"): TODO: Not support fusion ops
    return self.forward_native(positions, query, key, offsets)
have fixed
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()
_is_npu = is_npu()
have fixed
Hi mate, thank you for your contribution to help SGLang run well on Ascend NPU. However, we have a better support plan and would like to discuss it with you; please contact zl19940307@163.com if you are interested.
if _is_npu:
    device = "npu"
Remove this, because line 644 does the same thing.
Thanks, we have fixed it.
@@ -92,3 +93,63 @@ def create_flashmla_kv_indices_triton(
        data // PAGED_SIZE,
        mask=mask_out,
    )


def create_flashinfer_kv_indices(
Please add a unit test.
We only support the paged-attention (PA) scenario and no longer need this function, so we have removed it.
    ] = data


def create_flashmla_kv_indices(
Please add a unit test.
Thanks for the comment. create_npumla_kv_indices is different from create_flashmla_kv_indices_triton: we implement it with torch-native ops, so it can run directly on CPU and NPU, and only PAGE_SIZE = 128 is supported.
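For illustration, a torch-native helper along these lines could look like the sketch below; the tensor names and shapes are assumptions, not the PR's exact code:

```python
import torch

PAGE_SIZE = 128  # the npumla backend assumes a fixed page size of 128


def create_npumla_kv_indices(
    req_to_token: torch.Tensor,      # (max_reqs, max_context_len) token-slot table
    req_pool_indices: torch.Tensor,  # (bs,) row in req_to_token per request
    seq_lens: torch.Tensor,          # (bs,) current KV length per request
    block_kv_indices: torch.Tensor,  # (bs, max_pages) output, pre-filled with -1
) -> None:
    """Build a paged-attention block table with plain torch ops (CPU or NPU)."""
    for i in range(req_pool_indices.numel()):
        seq_len = int(seq_lens[i])
        num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
        # Sampling one token slot per page is enough to recover its page index,
        # since slots within a page are contiguous.
        slots = req_to_token[req_pool_indices[i], : num_pages * PAGE_SIZE : PAGE_SIZE]
        block_kv_indices[i, :num_pages] = slots // PAGE_SIZE
```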
    n_routed_experts_per_rank=0,
):
    world_size = get_tensor_model_parallel_world_size()
    if world_size > 1 and n_routed_experts_per_rank >= 1:
What happens if world_size == 1 and n_routed_experts_per_rank >= 1?
Thanks, we have fixed it.
    params_dtype=params_dtype,
    weight_loader=self.weight_loader,
)
kwargs = {
We should modify the keyword argument name of `W8A8Int8MoEMethod`.
Thank you. We will immediately open an issue to resolve this problem, just as we discussed yesterday.
@@ -766,6 +766,16 @@ def set_mla_kv_buffer_triton(
    )


def set_mla_kv_buffer_npu(
We already have `AscendTokenToKVPool`.
Thanks for the comment. For the npumla backend, using AscendTokenToKVPool is not very convenient, because in MLATokenToKVPool the allocation and usage of KV caches are independent across layers. If we used a contiguous KV-cache buffer, additional slice operations might be required.
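For context, a simplified sketch of the per-layer layout this refers to; the sizes and names are illustrative, not the PR's actual class:

```python
import torch

# Per-layer KV buffers in the MLATokenToKVPool style: each layer owns its
# own tensor, so layer i indexes kv_buffer[i] directly instead of slicing
# one contiguous buffer shared across layers.
num_layers, num_slots = 61, 16384           # illustrative sizes
kv_lora_rank, qk_rope_head_dim = 512, 64    # DeepSeek-v3 MLA dimensions
kv_buffer = [
    torch.empty(num_slots, 1, kv_lora_rank + qk_rope_head_dim, dtype=torch.bfloat16)
    for _ in range(num_layers)
]
```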
    layer = self.layers[i]
    hidden_states, residual = layer(
        positions, hidden_states, forward_batch, residual, zero_allocator
    )
else:
    with get_global_expert_distribution_recorder().with_current_layer(i):
revert this change
Thanks, we have reverted this.
@@ -249,6 +249,16 @@ def init_attention_backend(self):
        self.topk,
        self.speculative_num_steps,
    )
elif self.server_args.attention_backend == "npumla":
Split into another PR
Thanks, we have reverted the files that refer to MTP.
Please revert the changes to this file.
Thanks, we have reverted this file.
python/sglang/srt/utils.py (Outdated)
    )
except:
    is_intel_amx_backend_available = False
return backend not in ["torch_native", "intel_amx", "ascend", "npumla"]
fix here
Thanks, we have fixed it.
Thanks, we have fixed it.
Force-pushed from fe49ec2 to 900e3d3
Force-pushed from f22df63 to 5d96e6a
What is the motivation for these changes? Could you please make performance measurements with and without FIA? I believe it is slower.
The motivation is:
I think there is a small typo here: it should be export ASCEND_USE_FIA=true (not ture).
Motivation
In this MR, we implemented the NPU fusion kernel npu_fused_infer_attention_score in the Qwen2.5-7B, DeepSeek-V2-Lite, and DeepSeek-V3 models; this fusion kernel is suitable for graph mode. One needs to export ASCEND_USE_FIA=true to activate it.
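In Python terms, the activation switch amounts to setting the environment variable before the engine starts; only the variable name comes from this PR, the snippet itself is illustrative:

```python
import os

# Gate the fused-infer-attention (FIA) path; must be set before the
# SGLang engine/server process initializes its attention backend.
os.environ["ASCEND_USE_FIA"] = "true"
```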
Modifications
- Ascend backend: support the npu_fused_infer_attention_score kernel.
- Add unit tests: test_ascend_tp_fia_bf16.py and test_ascend_mla_fia_w8a8int8.py.
Memory Management Advancement
We modified the AscendMLAPagedTokenToKVPool class to split the KV buffer into separate k_buffer and v_buffer, removing the split op in MLA attention.
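A minimal sketch of that layout, with illustrative shapes (the real AscendMLAPagedTokenToKVPool differs in detail):

```python
import torch


class SplitPagedKVPool:
    """Illustrative only: store the latent ("nope") part and the rope part
    of the MLA KV cache in separate paged buffers, so attention can read
    k_buffer/v_buffer directly instead of splitting one fused buffer."""

    def __init__(self, num_pages: int, page_size: int = 128,
                 kv_lora_rank: int = 512, qk_rope_head_dim: int = 64):
        # One latent vector and one rope vector per token slot.
        self.k_buffer = torch.empty(num_pages, page_size, 1, kv_lora_rank)
        self.v_buffer = torch.empty(num_pages, page_size, 1, qk_rope_head_dim)
```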
Testing Framework
Unit tests for the Ascend attention backend have been added and can be found at: /test/srt/ascend/test_ascend_tp_fia_bf16.py and /test/srt/ascend/test_ascend_mla_fia_w8a8int8.py
Checklist
Accuracy and performance result
python -m unittest test_npu_mla_backend.TestNpuMlaBackend.test_gsm8k
Pre-commit check