@@ -211,6 +211,8 @@ def __init__(
         self.gpu_memory_utilization = gpu_memory_utilization
         self.num_gpu_blocks_override = num_gpu_blocks_override
         self.kv_cache_ratio = kv_cache_ratio
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            self.kv_cache_ratio = 1.0
         self.enc_dec_block_num = enc_dec_block_num
         self.prealloc_dec_block_slot_num_threshold = prealloc_dec_block_slot_num_threshold
         self.cache_dtype = cache_dtype
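Note on the flag: `envs.ENABLE_V1_KVCACHE_SCHEDULER` gates the new behavior; with it set, the configured prefill/decode split is ignored and the ratio is pinned to 1.0. A minimal standalone sketch of that override (the direct environment read and the example ratio are illustrative stand-ins, not the PR's `envs` module):

import os

# Illustrative stand-in for envs.ENABLE_V1_KVCACHE_SCHEDULER.
enable_v1_kvcache_scheduler = bool(int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')))

kv_cache_ratio = 0.75  # example user-configured value
if enable_v1_kvcache_scheduler:
    # v1 scheduler: no fixed prefill/decode split, so the ratio is forced to 1.0.
    kv_cache_ratio = 1.0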
@@ -291,7 +293,10 @@ def postprocess(self, num_total_tokens, number_of_tasks):
         self.dec_token_num = self.enc_dec_block_num * self.block_size
         if self.num_gpu_blocks_override is not None:
             self.total_block_num = self.num_gpu_blocks_override
-            self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
+            if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                self.prefill_kvcache_block_num = self.total_block_num
+            else:
+                self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
         else:
             length = num_total_tokens // number_of_tasks
             block_num = (length + self.block_size - 1 + self.dec_token_num) // self.block_size
@@ -304,7 +309,10 @@ def reset(self, num_gpu_blocks):
         reset gpu block number
         """
         self.total_block_num = num_gpu_blocks
-        self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            self.prefill_kvcache_block_num = self.total_block_num
+        else:
+            self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
         llm_logger.info(
             f"Reset block num, the total_block_num:{self.total_block_num},"
             f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"
@@ -796,7 +804,10 @@ def postprocess(self):
             if self.cache_config.enable_chunked_prefill:
                 self.max_num_batched_tokens = 2048
             else:
-                self.max_num_batched_tokens = self.max_model_len
+                if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+                    self.max_num_batched_tokens = self.max_model_len
+                else:
+                    self.max_num_batched_tokens = 8192
 
         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
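Unlike the earlier hunks, this one and the check() change below read the flag directly via `os.getenv` rather than through `envs`. A standalone sketch of the resulting default (the 8192 cap comes from the diff; the surrounding values are illustrative):

import os

max_model_len = 32768            # example value
enable_chunked_prefill = False   # example value

if enable_chunked_prefill:
    max_num_batched_tokens = 2048
else:
    if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
        # Legacy path: a whole prompt must fit into one batch.
        max_num_batched_tokens = max_model_len
    else:
        # v1 scheduler path: check() no longer asserts >= max_model_len,
        # so a fixed 8192-token budget is used instead.
        max_num_batched_tokens = 8192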
@@ -844,10 +855,11 @@ def check(self):
         )
 
         if not self.cache_config.enable_chunked_prefill:
-            assert self.max_num_batched_tokens >= self.max_model_len, (
-                f"max_num_batched_tokens: {self.max_num_batched_tokens} "
-                f"should be larger than or equal to max_model_len: {self.max_model_len}"
-            )
+            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+                assert self.max_num_batched_tokens >= self.max_model_len, (
+                    f"max_num_batched_tokens: {self.max_num_batched_tokens} "
+                    f"should be larger than or equal to max_model_len: {self.max_model_len}"
+                )
         else:
             assert self.max_num_batched_tokens >= self.cache_config.block_size, (
                 f"max_num_batched_tokens: {self.max_num_batched_tokens} "