-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Feat: support cuda graph for LoRA #4115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
c9ff88a
enable cuda graph for lora
Qiaolin-Yu 799a3bb
Merge branch 'main' into lora_cuda_graph
Qiaolin-Yu d1b4578
clean code
Qiaolin-Yu 24d5035
delete comments
Qiaolin-Yu f4ee33f
refine
Qiaolin-Yu a66e40c
fix doc
Qiaolin-Yu 56c7394
refine
Qiaolin-Yu e4e94b0
refine comments
Qiaolin-Yu 156be8e
Merge branch 'main' into lora_cuda_graph
Fridge003 f7df354
add ci
Qiaolin-Yu bc9b84e
fix
Qiaolin-Yu 1f6b592
fix
Qiaolin-Yu 39f4a8d
Merge branch 'main' into lora_cuda_graph
Qiaolin-Yu ed17a57
add comments back
Qiaolin-Yu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,6 +72,23 @@ def __init__( | |
self.init_loras() | ||
self.init_lora_memory_pool() | ||
|
||
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int): | ||
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph | ||
with torch.device("cuda"): | ||
self.cuda_graph_batch_info = LoRABatchInfo( | ||
bs=self.max_bs_in_cuda_graph, | ||
seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32), | ||
seg_indptr=torch.zeros( | ||
self.max_bs_in_cuda_graph + 1, dtype=torch.int32 | ||
), | ||
max_len=0, | ||
weight_indices=torch.zeros( | ||
self.max_bs_in_cuda_graph, dtype=torch.int32 | ||
), | ||
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), | ||
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), | ||
) | ||
|
||
def init_loras(self): | ||
# Config of each LoRA adapter | ||
self.configs: Dict[str, LoRAConfig] = {} | ||
|
@@ -140,39 +157,73 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): | |
if cur_uids == set([None]): | ||
return | ||
|
||
# set up batch info shared by all lora moruldes | ||
# set up batch info shared by all lora modules | ||
bs = forward_batch.batch_size | ||
seg_lens = ( | ||
forward_batch.extend_seq_lens | ||
if forward_batch.forward_mode.is_extend() | ||
else torch.ones(bs, device=self.device) | ||
) | ||
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device) | ||
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) | ||
max_len = int(torch.max(seg_lens)) | ||
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device) | ||
|
||
lora_ranks = torch.empty( | ||
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda" | ||
) | ||
scalings = torch.empty( | ||
(self.max_loras_per_batch,), dtype=torch.float, device="cuda" | ||
) | ||
for i, lora_path in enumerate(forward_batch.lora_paths): | ||
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) | ||
lora = self.loras[lora_path] | ||
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"] | ||
scalings[weight_indices[i]] = lora.scaling | ||
|
||
batch_info = LoRABatchInfo( | ||
bs=bs, | ||
seg_lens=seg_lens, | ||
seg_indptr=seg_indptr, | ||
max_len=max_len, | ||
weight_indices=weight_indices, | ||
lora_ranks=lora_ranks, | ||
scalings=scalings, | ||
) | ||
if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph: | ||
# Do in-place updates when CUDA graph is enabled. Note that | ||
# if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph | ||
# will also use these preallocated buffers, no matter whether | ||
# the batch can use CUDA graph or not. | ||
self.cuda_graph_batch_info.bs = bs | ||
if forward_batch.forward_mode.is_extend(): | ||
self.cuda_graph_batch_info.seg_lens[:bs].copy_( | ||
forward_batch.extend_seq_lens | ||
) | ||
else: | ||
self.cuda_graph_batch_info.seg_lens[:bs].fill_(1) | ||
torch.cumsum( | ||
self.cuda_graph_batch_info.seg_lens[:bs], | ||
dim=0, | ||
out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1], | ||
) | ||
self.cuda_graph_batch_info.max_len = int( | ||
torch.max(self.cuda_graph_batch_info.seg_lens[:bs]) | ||
) | ||
|
||
for i, lora_path in enumerate(forward_batch.lora_paths): | ||
self.cuda_graph_batch_info.weight_indices[i] = ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we need a triton kernel here to set cuda_graph_batch_info, so memcpy synchronization can be avoided. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can be delayed to next PR |
||
self.memory_pool.get_buffer_id(lora_path) | ||
) | ||
lora = self.loras[lora_path] | ||
self.cuda_graph_batch_info.lora_ranks[ | ||
self.cuda_graph_batch_info.weight_indices[i] | ||
] = lora.config.hf_config["r"] | ||
self.cuda_graph_batch_info.scalings[ | ||
self.cuda_graph_batch_info.weight_indices[i] | ||
] = lora.scaling | ||
batch_info = self.cuda_graph_batch_info | ||
else: | ||
seg_lens = ( | ||
forward_batch.extend_seq_lens | ||
if forward_batch.forward_mode.is_extend() | ||
else torch.ones(bs, device=self.device) | ||
) | ||
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device) | ||
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) | ||
max_len = int(torch.max(seg_lens)) | ||
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device) | ||
|
||
lora_ranks = torch.empty( | ||
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda" | ||
) | ||
scalings = torch.empty( | ||
(self.max_loras_per_batch,), dtype=torch.float, device="cuda" | ||
) | ||
for i, lora_path in enumerate(forward_batch.lora_paths): | ||
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) | ||
lora = self.loras[lora_path] | ||
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"] | ||
scalings[weight_indices[i]] = lora.scaling | ||
batch_info = LoRABatchInfo( | ||
bs=bs, | ||
seg_lens=seg_lens, | ||
seg_indptr=seg_indptr, | ||
max_len=max_len, | ||
weight_indices=weight_indices, | ||
lora_ranks=lora_ranks, | ||
scalings=scalings, | ||
) | ||
self.lora_backend.set_batch_info(batch_info) | ||
|
||
# call set_lora_info for each lora modules | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.