Skip to content
Merged
2 changes: 1 addition & 1 deletion benchmark/lora/launch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def launch_server(args):
for i in range(NUM_LORAS):
lora_name = f"lora{i}"
cmd += f"{lora_name}={lora_path} "
cmd += f"--disable-radix --disable-cuda-graph "
cmd += f"--disable-radix "
cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
cmd += f"--max-running-requests {args.max_running_requests} "
cmd += f"--lora-backend {args.lora_backend} "
Expand Down
6 changes: 3 additions & 3 deletions docs/backend/lora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" --max-loras-per-batch 1 --lora-backend triton \\\n",
" --disable-cuda-graph --disable-radix-cache\n",
" --disable-radix-cache\n",
"\"\"\"\n",
")\n",
"\n",
Expand Down Expand Up @@ -136,7 +136,7 @@
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-cuda-graph --disable-radix-cache\n",
" --disable-radix-cache\n",
"\"\"\"\n",
")\n",
"\n",
Expand Down Expand Up @@ -182,7 +182,7 @@
"source": [
"## Future Works\n",
"\n",
"The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Currently Cuda graph and radix attention are not incompatible with LoRA and must be manually disabled. Other features, including Unified Paging, Cutlass backend, and dynamic loading/unloadingm, are still under development."
"The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Currently radix attention is incompatible with LoRA and must be manually disabled. Other features, including Unified Paging, Cutlass backend, and dynamic loading/unloading, are still under development."
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion docs/backend/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s

| Arguments | Description | Defaults |
|----------|-------------|---------|
| `lora_paths` | List of adapters to apply to your model. Each batch element uses the proper LoRA adapter. `cuda_graph` and `radix_attention` are not supported with this, so they must be disabled manually. See related [issues](https://github.com/sgl-project/sglang/issues/2929). | None |
| `lora_paths` | List of adapters to apply to your model. Each batch element uses the proper LoRA adapter. `radix_attention` is not supported with this, so it must be disabled manually. See related [issues](https://github.com/sgl-project/sglang/issues/2929). | None |
| `max_loras_per_batch` | Maximum number of LoRAs allowed in a running batch, including the base model. | `8` |
| `lora_backend` | Backend used to run GEMM kernels for LoRA modules. Can be `triton` or `flashinfer`. | `triton` |

Expand Down
44 changes: 35 additions & 9 deletions python/sglang/srt/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,19 @@ def set_lora_info(
self.set_lora = True
self.A_buffer_gate_up = A_buffer
if self.lora_backend.fuse_stacked_lora_b:
# TODO: avoid using contiguous() in GPU.
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
self.B_buffer_gate_up = torch.cat(
(B_buffer[0], B_buffer[1]), dim=-2
).contiguous()
if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
self.B_buffer_gate_up = torch.empty(
(
B_buffer[0].shape[0],
2 * B_buffer[0].shape[1],
B_buffer[0].shape[2],
),
dtype=B_buffer[0].dtype,
device=B_buffer[0].device,
)
self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
else:
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])

Expand Down Expand Up @@ -171,7 +179,7 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def init__(
def __init__(
self,
base_layer: QKVParallelLinear,
lora_backend: BaseLoRABackend,
Expand All @@ -194,12 +202,30 @@ def set_lora_info(
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]

# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
self.B_buffer_qkv = torch.empty(
(
B_buffer_q[0].shape[0],
output_dim_q + 2 * output_dim_kv,
B_buffer_q[0].shape[2],
),
dtype=B_buffer_q[0].dtype,
device=B_buffer_q[0].device,
)
self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
B_buffer_kv[0]
)
self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
B_buffer_kv[1]
)

# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
if not hasattr(self, "output_offset") or self.output_offset is None:
self.output_offset = torch.empty(
4, dtype=torch.int32, device=B_buffer_q.device
)
self.output_offset[:4] = torch.tensor(
[
0,
output_dim_q,
Expand Down
113 changes: 82 additions & 31 deletions python/sglang/srt/lora/lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ def __init__(
self.init_loras()
self.init_lora_memory_pool()

def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
    """Preallocate the LoRA batch-info buffers sized for CUDA graph batches.

    Records ``max_bs_in_cuda_graph`` on the manager and stores a
    ``LoRABatchInfo`` in ``self.cuda_graph_batch_info`` whose tensors are
    zero-initialized: the per-request fields are sized by the given batch
    size and the per-adapter fields by ``self.max_loras_per_batch``.

    Args:
        max_bs_in_cuda_graph: Batch size used to size the per-request buffers.
    """
    self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
    max_bs = self.max_bs_in_cuda_graph

    # Tensors created inside this context are placed on the CUDA device.
    with torch.device("cuda"):
        seg_lens = torch.zeros(max_bs, dtype=torch.int32)
        # One more entry than seg_lens.
        seg_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)
        weight_indices = torch.zeros(max_bs, dtype=torch.int32)
        # Per-adapter fields: one entry per LoRA allowed in a batch.
        lora_ranks = torch.zeros(self.max_loras_per_batch, dtype=torch.int32)
        scalings = torch.zeros(self.max_loras_per_batch, dtype=torch.float)

        self.cuda_graph_batch_info = LoRABatchInfo(
            bs=max_bs,
            seg_lens=seg_lens,
            seg_indptr=seg_indptr,
            max_len=0,
            weight_indices=weight_indices,
            lora_ranks=lora_ranks,
            scalings=scalings,
        )

def init_loras(self):
# Config of each LoRA adapter
self.configs: Dict[str, LoRAConfig] = {}
Expand Down Expand Up @@ -140,39 +157,73 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
if cur_uids == set([None]):
return

# set up batch info shared by all lora moruldes
# set up batch info shared by all lora modules
bs = forward_batch.batch_size
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device=self.device)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)

lora_ranks = torch.empty(
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
)
scalings = torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device="cuda"
)
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
scalings[weight_indices[i]] = lora.scaling

batch_info = LoRABatchInfo(
bs=bs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
max_len=max_len,
weight_indices=weight_indices,
lora_ranks=lora_ranks,
scalings=scalings,
)
if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
# Do in-place updates when CUDA graph is enabled. Note that
# if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
# will also use these preallocated buffers, no matter whether
# the batch can use CUDA graph or not.
self.cuda_graph_batch_info.bs = bs
if forward_batch.forward_mode.is_extend():
self.cuda_graph_batch_info.seg_lens[:bs].copy_(
forward_batch.extend_seq_lens
)
else:
self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
torch.cumsum(
self.cuda_graph_batch_info.seg_lens[:bs],
dim=0,
out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
)
self.cuda_graph_batch_info.max_len = int(
torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
)

for i, lora_path in enumerate(forward_batch.lora_paths):
self.cuda_graph_batch_info.weight_indices[i] = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we need a triton kernel here to set cuda_graph_batch_info, so memcpy synchronization can be avoided.

Copy link
Collaborator

@Fridge003 Fridge003 Apr 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be delayed to next PR

self.memory_pool.get_buffer_id(lora_path)
)
lora = self.loras[lora_path]
self.cuda_graph_batch_info.lora_ranks[
self.cuda_graph_batch_info.weight_indices[i]
] = lora.config.hf_config["r"]
self.cuda_graph_batch_info.scalings[
self.cuda_graph_batch_info.weight_indices[i]
] = lora.scaling
batch_info = self.cuda_graph_batch_info
else:
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device=self.device)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)

lora_ranks = torch.empty(
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
)
scalings = torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device="cuda"
)
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
scalings[weight_indices[i]] = lora.scaling
batch_info = LoRABatchInfo(
bs=bs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
max_len=max_len,
weight_indices=weight_indices,
lora_ranks=lora_ranks,
scalings=scalings,
)
self.lora_backend.set_batch_info(batch_info)

# call set_lora_info for each lora modules
Expand Down
14 changes: 14 additions & 0 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ def __init__(self, model_runner: ModelRunner):
if self.enable_torch_compile:
set_torch_compile_config()

if self.model_runner.server_args.lora_paths is not None:
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)

# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
Expand Down Expand Up @@ -403,6 +406,13 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
self.capture_hidden_mode = (
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
)
if self.model_runner.server_args.lora_paths is not None:
# Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
# different logic to handle lora, so we need to set `lora_paths` to a list of non-None
# values if lora is enabled.
lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
else:
lora_paths = None

forward_batch = ForwardBatch(
forward_mode=self.capture_forward_mode,
Expand All @@ -424,8 +434,12 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=self.capture_hidden_mode,
lora_paths=lora_paths,
)

if lora_paths is not None:
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs,
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,6 @@ def check_server_args(self):
assert (
self.max_loras_per_batch > 0
# FIXME
and (self.lora_paths is None or self.disable_cuda_graph)
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
Expand Down
4 changes: 2 additions & 2 deletions test/srt/models/lora/test_lora_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
DEFAULT_PROMPTS,
TORCH_DTYPES,
LoRAModelCase,
run_batch_lora_test,
run_lora_test_one_by_one,
)

from sglang.test.test_utils import CustomTestCase, is_in_ci
Expand All @@ -42,7 +42,7 @@ def _run_backend_on_model_cases(self, model_cases: List[LoRAModelCase]):
)
for torch_dtype in TORCH_DTYPES:
for backend in BACKENDS:
run_batch_lora_test(
run_lora_test_one_by_one(
prompts,
model_case,
torch_dtype,
Expand Down
Loading
Loading