Fix swa eagle verify accuracy for Triton backend #9279
Conversation
Summary of Changes
Hello @ispobock, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses an issue in the Triton backend's handling of sliding window attention (SWA) by introducing a mechanism to correctly apply custom masks. The core change involves adding an offset to accurately read custom masks within the sliding window, which is essential for proper attention calculation. This fix significantly improves model accuracy, as evidenced by an increase in the MMLU score from 0.777 to 0.816, ensuring more reliable and performant SWA operations.
Highlights
- **Sliding Window Attention (SWA) Offset Handling:** Introduced `window_kv_offsets` to correctly calculate the start position for custom masks within sliding window layers, ensuring proper attention mask application in the Triton backend.
- **Triton Backend Integration:** Modified `TritonAttnBackend` to manage and pass the new `window_kv_offsets` through various forward and CUDA graph initialization paths, integrating it into the attention computation flow.
- **Attention Kernel Logic Update:** Updated the `extend_attention.py` Triton kernel to use the `window_kv_offset` when loading custom masks, which is crucial for accurate SWA behavior.
- **Decode Attention Batch Calculation:** Corrected the batch size calculation in the `decode_attention.py` kernels by using `q.shape[0]` instead of `kv_indptr.shape[0] - 1`, ensuring the correct batch dimension for attention operations.
- **Accuracy Improvement:** The changes improve model accuracy, raising the MMLU score from 0.777 to 0.816.
Code Review
This pull request addresses a bug in the Triton backend for sliding window attention when used with EAGLE speculative decoding. The core of the fix is to correctly calculate and propagate the key-value cache offset (`window_kv_offsets`) for sliding windows. This offset is then used in the Triton attention kernel to accurately index the custom attention mask, which is crucial for the correctness of speculative decoding verification.
The changes are well-structured:
- In `python/sglang/srt/layers/attention/triton_backend.py`, `window_kv_offsets` is added to the `ForwardMetadata` and plumbed through the various initialization functions, including those for CUDA graphs. The offset is correctly computed by `update_sliding_window_buffer`.
- In `python/sglang/srt/layers/attention/triton_ops/extend_attention.py`, the Triton kernel `_fwd_kernel` is updated to accept and use this offset when calculating the indices for the custom mask.
- A minor but welcome refactoring in `python/sglang/srt/layers/attention/triton_ops/decode_attention.py` makes the batch size calculation more direct and robust.
The implementation appears correct and addresses the issue described in the pull request, as supported by the significant accuracy improvement shown in the MMLU benchmark. The code is clean and the changes are consistent. I have no further recommendations.
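The batch-size refactoring can be illustrated with a toy example. The shapes and values below are hypothetical; `q` and `kv_indptr` merely mimic the kernel inputs:

```python
import numpy as np

# Hypothetical illustration of the decode_attention.py batch-size change.
# In the ragged KV layout, kv_indptr holds batch+1 prefix sums of the
# per-request KV lengths, so both expressions agree when kv_indptr
# covers exactly the batch; reading q.shape[0] is simply more direct.
q = np.zeros((3, 8, 64))             # assumed [batch, num_heads, head_dim]
kv_indptr = np.array([0, 5, 9, 12])  # prefix sums of per-request KV lengths

batch_from_q = q.shape[0]
batch_from_indptr = kv_indptr.shape[0] - 1
print(batch_from_q, batch_from_indptr)  # → 3 3
```

Deriving the batch from `q` also avoids depending on how `kv_indptr` happens to be sized, which matters when auxiliary buffers are preallocated (e.g. for CUDA graph capture).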
Motivation
For sliding window layers, we should add an offset when reading the custom mask within the sliding window.
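As a hypothetical illustration of where such an offset comes from (the formula below is an assumption for exposition, not code from this PR):

```python
# Assumed relation: the window KV offset for a request is the number of
# its earliest tokens that fell out of the sliding window, i.e. the
# shift between window-local and full-sequence KV indices.
def window_kv_offset(seq_len: int, window_size: int) -> int:
    return max(0, seq_len - window_size)

print(window_kv_offset(10, 4))  # → 6: the window keeps tokens 6..9
print(window_kv_offset(3, 4))   # → 0: the sequence still fits in the window
```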
Accuracy Test
main branch:
this PR: