Skip to content

Conversation

DarkSharpness
Copy link
Collaborator

Motivation

This PR is the integration of #8884 in Python code.

Modifications

Integrate custom set_kv_buffer_kernel into MHATokenToKVPool.

Accuracy Tests

Benchmarking and Profiling

python3 -m sglang.launch_server --model meta-llama/Llama-3.1-70B-Instruct --load-format dummy --prefill-attention fa3 --decode-attention flashinfer --tp 8 --disable-radix
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 10 --random-input 4096 --random-output 1024 --random-range-ratio 1 --warmup-requests 1

Before:

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     10        
Benchmark duration (s):                  13.57     
Total input tokens:                      40960     
Total generated tokens:                  10240     
Total generated tokens (retokenized):    10752     
Request throughput (req/s):              0.74      
Input token throughput (tok/s):          3018.83   
Output token throughput (tok/s):         754.71    
Total token throughput (tok/s):          3773.54   
Concurrency:                             9.99      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   13550.59  
Median E2E Latency (ms):                 13551.41  
---------------Time to First Token----------------
Mean TTFT (ms):                          1219.83   
Median TTFT (ms):                        1320.46   
P99 TTFT (ms):                           1765.93   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           12.05     
Median ITL (ms):                         11.52     
P95 ITL (ms):                            12.03     
P99 ITL (ms):                            12.24     
Max ITL (ms):                            1479.45   
==================================================

After:

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     10        
Benchmark duration (s):                  13.36     
Total input tokens:                      40960     
Total generated tokens:                  10240     
Total generated tokens (retokenized):    10752     
Request throughput (req/s):              0.75      
Input token throughput (tok/s):          3067.01   
Output token throughput (tok/s):         766.75    
Total token throughput (tok/s):          3833.77   
Concurrency:                             9.99      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   13339.98  
Median E2E Latency (ms):                 13341.00  
---------------Time to First Token----------------
Mean TTFT (ms):                          1200.47   
Median TTFT (ms):                        1302.12   
P99 TTFT (ms):                           1727.99   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           11.87     
Median ITL (ms):                         11.35     
P95 ITL (ms):                            11.91     
P99 ITL (ms):                            12.34     
Max ITL (ms):                            1452.32   
==================================================

Checklist

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @DarkSharpness, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the performance of the KV cache management within the SGLang runtime by integrating a specialized CUDA kernel. This change streamlines the process of setting key-value buffers, leading to more efficient memory operations and slight improvements in overall inference throughput and latency.

Highlights

  • Custom KV Buffer Kernel Integration: Integrates a custom set_kv_buffer_kernel into MHATokenToKVPool to optimize K/V cache updates, specifically for CUDA devices, fulfilling the integration of #8884.
  • Performance Improvement: Replaces the previous stream-based K/V cache copy mechanism with a dedicated kernel, leading to minor improvements in request throughput (0.74 req/s to 0.75 req/s) and end-to-end latency (13550.59 ms to 13339.98 ms) as shown in benchmarks.
  • Code Simplification: Removes self.device_module and self.alt_stream attributes, simplifying the set_kv_buffer logic by offloading the operation to the new kernel for CUDA devices.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a custom CUDA kernel, set_kv_buffer_kernel, to optimize the process of setting key-value caches in MHATokenToKVPool. This change replaces the previous Python-level stream management logic with a more efficient, dedicated kernel, which, as the benchmarks show, results in a performance improvement. The new implementation is also cleaner and more maintainable. I have one suggestion to further simplify the code by removing a redundant conditional check.

Comment on lines +399 to 409
if _is_cuda:
set_kv_buffer_kernel(
k_cache=self.k_buffer[layer_id - self.start_layer],
v_cache=self.v_buffer[layer_id - self.start_layer],
k=cache_k,
v=cache_v,
loc=loc,
)
else:
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The if _is_cuda: check appears to be redundant. The set_kv_buffer_kernel function is designed to handle cases where the CUDA kernel is unavailable by falling back to a pure PyTorch implementation within its try...except block. You can simplify this code by removing the conditional logic and calling set_kv_buffer_kernel directly. This would eliminate code duplication and centralize the fallback logic within the kernel wrapper function, improving maintainability.

        set_kv_buffer_kernel(
            k_cache=self.k_buffer[layer_id - self.start_layer],
            v_cache=self.v_buffer[layer_id - self.start_layer],
            k=cache_k,
            v=cache_v,
            loc=loc,
        )

@@ -33,6 +33,7 @@
import torch
import triton
import triton.language as tl
from sgl_kernel import set_kv_buffer_kernel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we only import sgl_kernel in cuda env?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants