Merged
6 changes: 3 additions & 3 deletions docs/backend/server_arguments.md
@@ -138,7 +138,7 @@ Please consult the documentation below to learn more about the parameters you may

## Kernel backend

* `attention_backend`: The backend for attention computation and KV cache management.
* `attention_backend`: The backend for attention computation and KV cache management. One of `fa3`, `flashinfer`, `triton`, or `torch_native`. When deploying DeepSeek models, this argument specifies the MLA backend to use.
* `sampling_backend`: The backend for sampling.

## Constrained Decoding
@@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you may
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. When providing this argument, `attention_backend` argument is overridden.
* `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on.
* `enable_flashinfer_mla`: Use the attention backend with the flashinfer MLA wrapper for DeepSeek models. **This argument will be deprecated soon! Please use `--attention-backend flashinfer` instead to switch on flashinfer MLA!**
* `flashinfer_mla_disable_ragged`: Disable the ragged prefill wrapper for the flashinfer MLA attention backend. Should only be used when flashinfer is the MLA attention backend.
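The deprecation above amounts to a one-flag change on the command line. A sketch of the migration (model path and other values are illustrative, not part of this PR):

```shell
# Old style (soon deprecated): flashinfer MLA via a dedicated flag
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 \
    --enable-flashinfer-mla

# New style: select the backend directly; for DeepSeek models,
# "flashinfer" implies the MLA wrapper
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 \
    --attention-backend flashinfer
```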
4 changes: 2 additions & 2 deletions docs/references/deepseek.md
@@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be

- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.

- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off.
- **MLA Attention Backends**: SGLang currently supports several optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html), and Triton. The backend can be selected with the `--attention-backend` argument.

- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.

@@ -149,7 +149,7 @@ python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --spec
```
- The draft models are available on Hugging Face: [lmsys/DeepSeek-V3-0324-NextN](https://huggingface.co/lmsys/DeepSeek-V3-0324-NextN), [lmsys/DeepSeek-R1-NextN](https://huggingface.co/lmsys/DeepSeek-R1-NextN). They can also be exported from the original DeepSeek-V3/R1 models with the [export_deepseek_nextn.py](https://github.com/sgl-project/sglang/blob/main/scripts/export_deepseek_nextn.py) script.
- The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk`, and `--speculative-num-draft-tokens` can be searched with the [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for a given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
- Currently when using flashinfer mla wrapper (`--enable-flashinfer-mla`) and speculative decoding together, the `--speculative-eagle-topk` parameter should be set to `1`.
- Currently, when using the flashinfer MLA wrapper (`--attention-backend flashinfer`) together with speculative decoding, `--speculative-eagle-topk` should be set to `1`. The MTP feature on the FlashAttention 3 backend is still in beta.
- To enable DeepSeek MTP for large batch sizes (>32), some parameters should be changed (reference: [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
	- Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP; for larger batch sizes, increase it beyond that.
- Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it.
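Putting the notes above together, a hypothetical launch line for MTP at larger batch sizes might look like the following. All values are illustrative, and the flag names (`--speculative-algorithm`, `--speculative-draft-model-path`) are assumptions to be checked against your SGLang version:

```shell
# Illustrative only: MTP with tuned large-batch settings
python3 -m sglang.launch_server \
    --model-path deepseek-ai/DeepSeek-V3-0324 \
    --speculative-algorithm NEXTN \
    --speculative-draft-model-path lmsys/DeepSeek-V3-0324-NextN \
    --speculative-num-steps 1 \
    --speculative-eagle-topk 1 \
    --speculative-num-draft-tokens 2 \
    --max-running-requests 64 \
    --cuda-graph-bs 1 2 4 8 16 32 64
```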
@@ -71,8 +71,6 @@ def __init__(
self.device = model_runner.device
self.skip_prefill = skip_prefill

global_config.enable_flashinfer_mla = True

# Allocate buffers
global global_workspace_buffer
if global_workspace_buffer is None:
3 changes: 1 addition & 2 deletions python/sglang/srt/managers/schedule_batch.py
@@ -76,7 +76,6 @@
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
"enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -1435,7 +1434,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:

# Create seq_lens_cpu when needed
if (
global_server_args_dict["enable_flashinfer_mla"]
global_server_args_dict["attention_backend"] == "flashinfer_mla"
or global_server_args_dict["enable_flashmla"]
or global_server_args_dict["attention_backend"] == "fa3"
):
9 changes: 6 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
@@ -151,7 +151,6 @@ def __init__(
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"enable_flashmla": server_args.enable_flashmla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
@@ -223,10 +222,14 @@ def model_specific_adjustment(self):
):
# TODO: add MLA optimization on CPU
if server_args.device != "cpu":
if server_args.enable_flashinfer_mla:
if (
server_args.attention_backend == "flashinfer"
or server_args.enable_flashinfer_mla
):
logger.info(
"MLA optimization is turned on. Use flashinfer mla backend."
"MLA optimization is turned on. Use flashinfer backend."
)
# Here we use a special flashinfer_mla tag to differentiate it from normal flashinfer backend
server_args.attention_backend = "flashinfer_mla"
elif server_args.enable_flashmla:
logger.info("MLA optimization is turned on. Use flashmla decode.")
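The backend-selection change in `model_specific_adjustment()` above can be summarized by a small pure function. This is a hypothetical paraphrase of the diff for illustration, not the actual implementation; the function name is invented:

```python
def resolve_attention_backend(attention_backend: str,
                              enable_flashinfer_mla: bool) -> str:
    """Map user-facing backend settings to the internal backend tag.

    Mirrors the logic in the diff: both the new-style
    `--attention-backend flashinfer` and the deprecated
    `--enable-flashinfer-mla` flag resolve to the special
    "flashinfer_mla" tag, which distinguishes MLA use from the
    normal flashinfer backend.
    """
    if attention_backend == "flashinfer" or enable_flashinfer_mla:
        return "flashinfer_mla"
    return attention_backend
```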
3 changes: 1 addition & 2 deletions python/sglang/srt/models/deepseek_v2.py
@@ -684,15 +684,14 @@ def __init__(
self.w_vc = None
self.w_scale = None

self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.attention_backend = global_server_args_dict["attention_backend"]
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

def no_absorb(self, forward_batch: ForwardBatch) -> bool:
if self.enable_flashinfer_mla:
if self.attention_backend == "flashinfer_mla":
# Flashinfer MLA: Do not absorb when enabling ragged prefill
return (
not self.flashinfer_mla_disable_ragged
4 changes: 2 additions & 2 deletions python/sglang/srt/server_args.py
@@ -179,7 +179,7 @@ class ServerArgs:
tool_call_parser: Optional[str] = None
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
enable_flashinfer_mla: bool = False
enable_flashinfer_mla: bool = False # TODO: remove this argument
enable_flashmla: bool = False
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
@@ -836,7 +836,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--enable-flashinfer-mla",
> Review comment (Collaborator Author): Currently the flashinfer backend is used by default. Flashinfer MLA will be enabled when `--attention-backend` isn't passed.

action="store_true",
help="Enable FlashInfer MLA optimization",
help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead to switch on flashinfer MLA!",
)
parser.add_argument(
"--enable-flashmla",
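The argparse change above pairs naturally with a compatibility shim that folds the deprecated flag into the new argument. A minimal, self-contained sketch, not the actual sglang parser; names mirror the diff but the `build_parser`/`normalize` helpers are invented for illustration:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Sketch: keep the deprecated boolean flag alongside its replacement.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attention-backend",
        choices=["fa3", "flashinfer", "triton", "torch_native"],
        default=None,
    )
    parser.add_argument(
        "--enable-flashinfer-mla",
        action="store_true",
        help="Deprecated; use '--attention-backend flashinfer' instead.",
    )
    return parser


def normalize(args: argparse.Namespace) -> argparse.Namespace:
    # Fold the deprecated flag into the new argument, without
    # overriding an explicitly chosen backend.
    if args.enable_flashinfer_mla and args.attention_backend is None:
        args.attention_backend = "flashinfer"
    return args
```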
10 changes: 6 additions & 4 deletions test/srt/test_mla_flashinfer.py
@@ -26,7 +26,8 @@ def setUpClass(cls):
"--enable-torch-compile",
"--cuda-graph-max-bs",
"2",
"--enable-flashinfer-mla",
"--attention-backend",
"flashinfer",
]
)
cls.process = popen_launch_server(
@@ -69,8 +70,8 @@ def setUpClass(cls):
"--disable-cuda-graph",
"--cuda-graph-max-bs",
"4",
"--enable-flashinfer-mla",
"--flashinfer-mla-disable-ragged",
"--attention-backend",
"flashinfer",
]
)
cls.process = popen_launch_server(
@@ -125,7 +126,8 @@ def setUpClass(cls):
"1",
"--speculative-num-draft-tokens",
"4",
"--enable-flashinfer-mla",
"--attention-backend",
"flashinfer",
]
)
cls.process = popen_launch_server(