From ce45d15b87fc96a374953d2fd37b13b68cfb95f3 Mon Sep 17 00:00:00 2001 From: Sundara Raman Ramachandran Date: Wed, 20 Aug 2025 05:51:26 +0000 Subject: [PATCH] Dynamic Batch Tokenizer --- benchmark/api/bench_common.py | 671 ++++++++++++++++++ benchmark/api/bench_embeddings.py | 122 ++++ benchmark/api/bench_score.py | 157 ++++ benchmark/score/bench_score.py | 603 ---------------- .../managers/async_dynamic_batch_tokenizer.py | 170 +++++ .../sglang/srt/managers/tokenizer_manager.py | 117 ++- python/sglang/srt/server_args.py | 29 + .../srt/test_async_dynamic_batch_tokenizer.py | 295 ++++++++ 8 files changed, 1550 insertions(+), 614 deletions(-) create mode 100644 benchmark/api/bench_common.py create mode 100644 benchmark/api/bench_embeddings.py create mode 100644 benchmark/api/bench_score.py delete mode 100644 benchmark/score/bench_score.py create mode 100644 python/sglang/srt/managers/async_dynamic_batch_tokenizer.py create mode 100644 test/srt/test_async_dynamic_batch_tokenizer.py diff --git a/benchmark/api/bench_common.py b/benchmark/api/bench_common.py new file mode 100644 index 00000000000..3953f0509ff --- /dev/null +++ b/benchmark/api/bench_common.py @@ -0,0 +1,671 @@ +""" +Common utilities for SGLang benchmark scripts. + +This module contains shared code for benchmarking different SGLang APIs +including scoring, embeddings, and other endpoints. +""" + +import asyncio +import concurrent.futures +import json +import os +import random +from statistics import mean +from typing import Any, Callable, Dict, List, Optional, Tuple + +import aiohttp +import numpy as np +from tqdm import tqdm +from transformers import AutoTokenizer + + +class BenchmarkConfig: + """Configuration for benchmark parameters.""" + + def __init__(self): + # Common benchmark settings + self.server_type = "HTTP" + self.rps_values = [70] + self.duration_secs_values = [60] + self.num_unique_requests = 100 + self.distribution = "POISSON" # Options: "CONSTANT", "POISSON" + self.profile = False + + # Special token for text generation + self.special_replicated_token = "<|im_start|>" + + +def generate_text_with_token_count( + model_path: str, num_tokens: int, special_token: str = "<|im_start|>" +) -> str: + """ + Generate text with precise token count using a replicated token. 
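+
+    Example (illustrative; assumes the special token encodes to a single id):
+        generate_text_with_token_count("Qwen/Qwen3-0.6B", 3)
+        # -> "<|im_start|><|im_start|><|im_start|>"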
+ + Args: + model_path: Path to the model for tokenizer + num_tokens: Target number of tokens + special_token: Token to replicate + + Returns: + Generated text with approximately the target token count + """ + tokenizer = AutoTokenizer.from_pretrained(model_path) + + # Verify token count + special_token_count = len(tokenizer.encode(special_token, add_special_tokens=False)) + + if special_token_count == 1: + # Simple case: token maps to exactly 1 token + return special_token * num_tokens + else: + print(f"Special token '{special_token}' produces {special_token_count} tokens") + # Handle case where special token produces multiple tokens + repetitions = (num_tokens + special_token_count - 1) // special_token_count + text = special_token * repetitions + + # Verify we got the expected token count + actual_tokens = len(tokenizer.encode(text, add_special_tokens=False)) + if actual_tokens < num_tokens: + print(f"Warning: Generated {actual_tokens} tokens, expected {num_tokens}") + + return text + + +def prepare_all_requests_parallel( + num_requests: int, + item_count: int, + build_request_func: Callable[[int, int], Tuple[int, Any]], + config: BenchmarkConfig, + description: str = "requests", +) -> List[Any]: + """ + Generic function to generate unique requests in parallel, then reuse them. + + Args: + num_requests: Total number of requests needed + item_count: Number of items per request (batch size) + build_request_func: Function that takes (index, item_count) and returns (index, request_data) + config: Benchmark configuration + description: Description for progress bars + + Returns: + List of request data objects + """ + + def build_request_wrapper(index): + """Wrapper to call the provided build_request_func.""" + try: + return build_request_func(index, item_count) + except Exception as e: + print(f"Error building request {index}: {e}") + return (index, None) + + # Generate only the unique requests + unique_requests = [None] * config.num_unique_requests + max_workers = min(8, os.cpu_count() or 1) # Limit to 8 threads max + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for i in tqdm( + range(config.num_unique_requests), + desc=f"Submitting {description} generation tasks", + ): + future = executor.submit(build_request_wrapper, i) + futures.append(future) + + # Collect results as they complete + for f in tqdm( + concurrent.futures.as_completed(futures), + desc=f"Building unique {description}", + total=config.num_unique_requests, + ): + try: + index, req_data = f.result() + if req_data is not None: + unique_requests[index] = req_data + else: + print(f"Failed to build request {index}") + except Exception as e: + print(f"Error processing request result: {e}") + + # Check if we have any valid requests + valid_requests = [req for req in unique_requests if req is not None] + if not valid_requests: + raise RuntimeError("Failed to generate any valid requests") + + print( + f"Successfully generated {len(valid_requests)} out of " + f"{config.num_unique_requests} unique {description}" + ) + + # Create the full request list by cycling through unique requests + print( + f"Reusing {len(valid_requests)} unique {description} to create " + f"{num_requests} total requests..." 
+ ) + all_requests = [] + for i in tqdm(range(num_requests), desc=f"Reusing {description}"): + unique_index = i % len(valid_requests) + all_requests.append(valid_requests[unique_index]) + + print(f"All {description} prepared.\n") + return all_requests + + +async def sleep_with_distribution(distribution: str, rps: float) -> None: + """ + Sleep according to the specified distribution pattern. + + Args: + distribution: "CONSTANT" or "POISSON" + rps: Requests per second rate + """ + if distribution == "CONSTANT": + interval = 1 / rps + await asyncio.sleep(interval) + elif distribution == "POISSON": + # For Poisson process, inter-arrival times follow exponential distribution + interval = random.expovariate(rps) + await asyncio.sleep(interval) + else: + raise ValueError( + f"Unknown distribution: {distribution}. Use 'CONSTANT' or 'POISSON'." + ) + + +def build_http_request_json(request_data: Any) -> str: + """ + Generic function to build HTTP request JSON. + + Args: + request_data: The data to serialize to JSON + + Returns: + JSON string representation of the request data + """ + return json.dumps(request_data) + + +async def make_http_call( + session: aiohttp.ClientSession, + request_data: Any, + request_id: int, + results_queue: asyncio.Queue, + http_url: str, + response_validator: Callable[[Dict[str, Any]], bool], + api_name: str = "API", +) -> None: + """ + Generic HTTP call function for API requests. + + Args: + session: aiohttp client session + request_data: Data to send in the request + request_id: Unique identifier for this request + results_queue: Queue to put results + http_url: URL to send the request to + response_validator: Function to validate the response JSON + api_name: Name of the API for error messages + """ + try: + start_time = asyncio.get_event_loop().time() + + request_json = build_http_request_json(request_data) + headers = {"Content-Type": "application/json"} + + async with session.post(http_url, data=request_json, headers=headers) as resp: + resp_text = await resp.text() + + if resp.status != 200: + print( + f"[HTTP] {api_name} Request {request_id} failed with status " + f"{resp.status}: {resp_text}" + ) + completion_time = asyncio.get_event_loop().time() + await results_queue.put((request_id, 0, False, completion_time)) + return + + # Parse and validate response + try: + response_data = json.loads(resp_text) + success = response_validator(response_data) + if not success: + print( + f"[HTTP] {api_name} Request {request_id} failed response validation" + ) + except json.JSONDecodeError: + print( + f"[HTTP] {api_name} Request {request_id} failed to parse JSON response" + ) + success = False + + completion_time = asyncio.get_event_loop().time() + elapsed_time = (completion_time - start_time) * 1000 + await results_queue.put((request_id, elapsed_time, success, completion_time)) + + except Exception as e: + print(f"[HTTP] {api_name} Error for request {request_id}: {e}") + completion_time = asyncio.get_event_loop().time() + await results_queue.put((request_id, 0, False, completion_time)) + + +async def send_profile_request( + profile_text: str, http_url: str, session: Optional[aiohttp.ClientSession] = None +) -> None: + """ + Send a profile request (START_PROFILE or STOP_PROFILE) and wait for completion. 
+ + Args: + profile_text: "START_PROFILE" or "STOP_PROFILE" + http_url: Base HTTP URL (will derive profile endpoints from this) + session: Optional aiohttp session to use + """ + try: + if session: + print(f"Sending {profile_text} request via HTTP...") + + # Determine the correct endpoint + if "/v1/" in http_url: + base_url = http_url.rsplit("/v1/", 1)[0] # Remove /v1/xxx + else: + base_url = http_url.rsplit("/", 1)[0] # Remove last path component + + if profile_text == "START_PROFILE": + endpoint_url = f"{base_url}/start_profile" + elif profile_text == "STOP_PROFILE": + endpoint_url = f"{base_url}/stop_profile" + else: + print(f"Unknown profile request: {profile_text}") + return + + headers = {"Content-Type": "application/json"} + + async with session.post(endpoint_url, headers=headers) as resp: + resp_text = await resp.text() + if resp.status == 200: + print(f"{profile_text} request completed") + else: + print( + f"{profile_text} request failed with status " + f"{resp.status}: {resp_text}" + ) + else: + print(f"Cannot send {profile_text} request - missing session") + + except Exception as e: + print(f"Error sending {profile_text} request: {e}") + + +async def process_results( + results_queue: asyncio.Queue, + num_requests: int, + send_duration: float, + total_duration: float, + rps: int, + duration_secs: int, + item_count: int, + test_start_time: float, + config: BenchmarkConfig, + http_mode: str = "UNKNOWN", +) -> List[Dict[str, Any]]: + """ + Process benchmark results and group them by minute intervals. + + Args: + results_queue: Queue containing result tuples + num_requests: Total number of requests sent + send_duration: Time taken to send all requests + total_duration: Total time for all requests to complete + rps: Target requests per second + duration_secs: Test duration in seconds + item_count: Number of items per request + test_start_time: Start time of the test + config: Benchmark configuration + http_mode: Description of the HTTP mode/API being tested + + Returns: + List of dictionaries containing minute-by-minute results + """ + all_results = [] + + # Collect all results + for _ in range(num_requests): + result = await results_queue.get() + request_id, elapsed_time, success, completion_time = result + all_results.append( + { + "request_id": request_id, + "elapsed_time": elapsed_time, + "success": success, + "completion_time": completion_time, + } + ) + + # Group results by minute intervals + minute_results = [] + num_minutes = int(duration_secs // 60) + (1 if duration_secs % 60 > 0 else 0) + + for minute in range(num_minutes): + minute_start = test_start_time + (minute * 60) + minute_end = test_start_time + ((minute + 1) * 60) + + # Filter results that completed in this minute + minute_data = [ + r for r in all_results if minute_start <= r["completion_time"] < minute_end + ] + + response_times = [r["elapsed_time"] for r in minute_data if r["success"]] + successful_requests = len([r for r in minute_data if r["success"]]) + failed_requests = len([r for r in minute_data if not r["success"]]) + + avg_response_time = mean(response_times) if response_times else 0 + + # Calculate percentiles using numpy + if response_times: + p50 = np.percentile(response_times, 50) + p90 = np.percentile(response_times, 90) + p99 = np.percentile(response_times, 99) + else: + p50 = p90 = p99 = 0 + + minute_result = { + "test_duration_secs": duration_secs, + "minute_interval": minute + 1, + "target_rps": rps, + "item_count": item_count, + "server_type": config.server_type, + "distribution": 
config.distribution, + "unique_requests": config.num_unique_requests, + "total_requests": len(minute_data), + "successful_requests": successful_requests, + "failed_requests": failed_requests, + "send_duration_secs": send_duration, + "total_duration_secs": total_duration, + "avg_response_time_ms": avg_response_time, + "p50_response_time_ms": p50, + "p90_response_time_ms": p90, + "p99_response_time_ms": p99, + } + + minute_results.append(minute_result) + + print( + f"\nMinute {minute + 1} Summary for RPS {rps}, " + f"Duration {duration_secs}s, Item Count {item_count}:" + ) + print(f" Requests completed in minute: {len(minute_data)}") + print(f" Successful requests: {successful_requests}") + print(f" Failed requests: {failed_requests}") + print(f" Average response time: {avg_response_time:.2f} ms") + print(f" P50 response time: {p50:.2f} ms") + print(f" P90 response time: {p90:.2f} ms") + print(f" P99 response time: {p99:.2f} ms") + + # Print overall summary + all_response_times = [r["elapsed_time"] for r in all_results if r["success"]] + total_successful = len([r for r in all_results if r["success"]]) + total_failed = len([r for r in all_results if not r["success"]]) + + overall_avg = mean(all_response_times) if all_response_times else 0 + if all_response_times: + overall_p50 = np.percentile(all_response_times, 50) + overall_p90 = np.percentile(all_response_times, 90) + overall_p99 = np.percentile(all_response_times, 99) + else: + overall_p50 = overall_p90 = overall_p99 = 0 + + print( + f"\nOverall Summary for RPS {rps}, Duration {duration_secs}s, " + f"Item Count {item_count}:" + ) + print(f" Test duration: {duration_secs} seconds") + print(f" Server type: {config.server_type}") + print(f" HTTP mode: {http_mode}") + print(f" Target RPS: {rps}") + print(f" Item count: {item_count}") + print(f" Distribution: {config.distribution}") + print(f" Unique requests generated: {config.num_unique_requests}") + print(f" Total requests sent: {num_requests}") + print(f" Successful requests: {total_successful}") + print(f" Failed requests: {total_failed}") + print(f" Time to send all requests: {send_duration:.2f} seconds") + print(f" Time for all requests to complete: {total_duration:.2f} seconds") + print(f" Average response time: {overall_avg:.2f} ms") + print(f" P50 response time: {overall_p50:.2f} ms") + print(f" P90 response time: {overall_p90:.2f} ms") + print(f" P99 response time: {overall_p99:.2f} ms\n") + + return minute_results + + +def print_csv_results(all_results: List[Dict[str, Any]]) -> None: + """ + Print benchmark results in CSV format. 
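+
+    Column order matches the keys of the per-minute dictionaries returned by
+    process_results.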
+ + Args: + all_results: List of result dictionaries from process_results + """ + print("\n" + "=" * 80) + print("FINAL CSV RESULTS:") + print("=" * 80) + + # CSV Header + headers = [ + "test_duration_secs", + "minute_interval", + "target_rps", + "item_count", + "server_type", + "distribution", + "unique_requests", + "total_requests", + "successful_requests", + "failed_requests", + "send_duration_secs", + "total_duration_secs", + "avg_response_time_ms", + "p50_response_time_ms", + "p90_response_time_ms", + "p99_response_time_ms", + ] + print(",".join(headers)) + + # CSV Data + for result in all_results: + row = [ + result["test_duration_secs"], + result["minute_interval"], + result["target_rps"], + result["item_count"], + result["server_type"], + result["distribution"], + result["unique_requests"], + result["total_requests"], + result["successful_requests"], + result["failed_requests"], + f"{result['send_duration_secs']:.2f}", + f"{result['total_duration_secs']:.2f}", + f"{result['avg_response_time_ms']:.2f}", + f"{result['p50_response_time_ms']:.2f}", + f"{result['p90_response_time_ms']:.2f}", + f"{result['p99_response_time_ms']:.2f}", + ] + print(",".join(map(str, row))) + + +async def run_benchmark_main( + config: BenchmarkConfig, + run_single_benchmark_func, + benchmark_name: str, + http_url: str, + item_count_values: List[int], + additional_info: Optional[Dict[str, Any]] = None, +) -> None: + """ + Main benchmark orchestration function. + + Args: + config: Benchmark configuration + run_single_benchmark_func: Async function to run a single benchmark + benchmark_name: Name of the benchmark (e.g., "SCORING", "EMBEDDINGS") + http_url: URL of the API endpoint + item_count_values: List of item counts to test + additional_info: Additional information to print in the header + """ + total_combinations = ( + len(config.duration_secs_values) + * len(config.rps_values) + * len(item_count_values) + ) + + print( + f"Running benchmarks for {len(config.duration_secs_values)} duration " + f"values, {len(config.rps_values)} RPS values, and " + f"{len(item_count_values)} item count values = " + f"{total_combinations} total combinations" + ) + print(f"Server Type: {config.server_type}") + print(f"HTTP Mode: {benchmark_name}") + print(f"API URL: {http_url}") + + if additional_info: + for key, value in additional_info.items(): + print(f"{key}: {value}") + + print(f"Items per request (batch size): {item_count_values}") + print(f"Profiling Enabled: {config.profile}") + print(f"Duration values: {config.duration_secs_values}") + print(f"RPS values: {config.rps_values}") + print(f"Item count values: {item_count_values}") + print("=" * 80) + + all_results = [] + + for duration_secs in config.duration_secs_values: + for rps in config.rps_values: + for item_count in item_count_values: + result = await run_single_benchmark_func(rps, duration_secs, item_count) + all_results.extend(result) # Extend with minute results + + print_csv_results(all_results) + + +async def run_generic_benchmark( + rps: int, + duration_secs: int, + item_count: int, + config: BenchmarkConfig, + http_url: str, + build_request_func: Callable[[int, int], Tuple[int, Any]], + response_validator: Callable[[Dict[str, Any]], bool], + api_name: str, + request_description: str = "requests", +) -> List[Dict[str, Any]]: + """ + Generic benchmark runner that can be used for different APIs. 
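+
+    See bench_score.py and bench_embeddings.py in this directory for concrete
+    build_request_func and response_validator implementations wired into this runner.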
+ + Args: + rps: Requests per second + duration_secs: Duration of the test in seconds + item_count: Number of items per request (batch size) + config: Benchmark configuration + http_url: URL of the API endpoint + build_request_func: Function to build individual requests + response_validator: Function to validate API responses + api_name: Name of the API for logging + request_description: Description for progress bars + + Returns: + List of dictionaries containing minute-by-minute results + """ + num_requests = int(rps * duration_secs) + print( + f"Starting benchmark with RPS={rps}, Duration={duration_secs}s, " + f"Item Count={item_count}, num_requests={num_requests}" + ) + print(f"Server Type: {config.server_type}") + print(f"HTTP Mode: {api_name}") + print(f"Profiling Enabled: {config.profile}") + + # Build requests in parallel (unmeasured) + all_requests = prepare_all_requests_parallel( + num_requests, item_count, build_request_func, config, request_description + ) + + results_queue = asyncio.Queue() + tasks = [] + + # Track timing for sending requests + send_start_time = asyncio.get_event_loop().time() + + # HTTP implementation + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=300) + ) as session: + + # Send START_PROFILE if profiling is enabled + if config.profile: + await send_profile_request("START_PROFILE", http_url, session=session) + + # Add progress bar for sending requests + with tqdm( + total=len(all_requests), + desc=f"Sending HTTP {request_description} at {rps} RPS", + unit="req", + ) as pbar: + for i, request_data in enumerate(all_requests): + request_id = i + 1 + tasks.append( + asyncio.create_task( + make_http_call( + session, + request_data, + request_id, + results_queue, + http_url, + response_validator, + api_name, + ) + ) + ) + + # Update progress bar + pbar.update(1) + + # Throttle based on distribution + if i < len(all_requests) - 1: + await sleep_with_distribution(config.distribution, rps) + + send_end_time = asyncio.get_event_loop().time() + send_duration = send_end_time - send_start_time + + # Wait for all requests to complete with progress tracking + print(f"Waiting for {len(tasks)} HTTP {request_description} to complete...") + with tqdm( + total=len(tasks), desc=f"Completing HTTP {request_description}", unit="req" + ) as completion_pbar: + completed_tasks = [] + for task in asyncio.as_completed(tasks): + await task + completed_tasks.append(task) + completion_pbar.update(1) + + # Send STOP_PROFILE if profiling is enabled + if config.profile: + await send_profile_request("STOP_PROFILE", http_url, session=session) + + completion_end_time = asyncio.get_event_loop().time() + total_duration = completion_end_time - send_start_time + + return await process_results( + results_queue, + num_requests, + send_duration, + total_duration, + rps, + duration_secs, + item_count, + send_start_time, + config, + api_name, + ) diff --git a/benchmark/api/bench_embeddings.py b/benchmark/api/bench_embeddings.py new file mode 100644 index 00000000000..0261cf3c9b7 --- /dev/null +++ b/benchmark/api/bench_embeddings.py @@ -0,0 +1,122 @@ +""" +SGLang Embeddings Benchmark Script + +This script benchmarks SGLang's /v1/embeddings API performance using HTTP requests. 
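+
+Each request POSTs a JSON body of the form (illustrative):
+    {"input": "<text>" or ["<text>", ...], "model": EMBEDDINGS_MODEL_PATH}
+to HTTP_URL; a response counts as successful if it contains a "data" field.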
+ +Features: +- HTTP-only implementation +- Uses /v1/embeddings API endpoint directly +- Configurable RPS, duration, and batch sizes +- Progress tracking and detailed metrics +- Poisson and constant request distributions + +Usage: +- Update configuration variables at the top of the file +- Ensure SGLang server is running on the configured HTTP_URL +- Run: python bench_embeddings.py +""" + +import asyncio +import logging + +from bench_common import ( + BenchmarkConfig, + generate_text_with_token_count, + run_benchmark_main, + run_generic_benchmark, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +############################################################################### +# CONFIG +############################################################################### +# Create benchmark configuration +config = BenchmarkConfig() +config.rps_values = [500] +config.duration_secs_values = [60] +config.num_unique_requests = 100 +config.distribution = "POISSON" +config.profile = False + +# HTTP Configuration +HTTP_URL = "http://localhost:30000/v1/embeddings" + +# Embeddings API Config +EMBEDDINGS_MODEL_PATH = "/shared/public/sharing/suramach/Qwen3-0.6B" +BATCH_SIZE = [1] # Number of items per request (batch size) + +# Configurable input token length +EMBEDDINGS_INPUT_TOKENS = 500 # Default token length + +# Generate input text with the specified token length +EMBEDDINGS_INPUT_TEXT = generate_text_with_token_count( + EMBEDDINGS_MODEL_PATH, EMBEDDINGS_INPUT_TOKENS, config.special_replicated_token +) + + +############################################################################### +# REQUEST GENERATION (in parallel) +############################################################################### +def build_embeddings_request(index: int, item_count: int) -> tuple: + """Build a single embeddings request.""" + try: + # For embeddings, input can be a string or list of strings + if item_count == 1: + input_data = EMBEDDINGS_INPUT_TEXT + else: + input_data = [EMBEDDINGS_INPUT_TEXT for _ in range(item_count)] + req = { + "input": input_data, + "model": EMBEDDINGS_MODEL_PATH, + } + return (index, req) + except Exception as e: + logger.error(f"Error building request {index}: {e}") + return (index, None) + + +def validate_embeddings_response(response_data: dict) -> bool: + """Validate embeddings API response.""" + return "data" in response_data + + +############################################################################### +# MAIN +############################################################################### +async def run_benchmark(rps, duration_secs, item_count): + """Run a single embeddings benchmark with the given RPS value.""" + return await run_generic_benchmark( + rps=rps, + duration_secs=duration_secs, + item_count=item_count, + config=config, + http_url=HTTP_URL, + build_request_func=build_embeddings_request, + response_validator=validate_embeddings_response, + api_name="EMBEDDINGS", + request_description="embeddings requests", + ) + + +async def main(): + additional_info = { + "Input text length": f"{EMBEDDINGS_INPUT_TOKENS} tokens", + "Input text preview": ( + EMBEDDINGS_INPUT_TEXT[:100] + "..." 
+ if len(EMBEDDINGS_INPUT_TEXT) > 100 + else EMBEDDINGS_INPUT_TEXT + ), + } + + await run_benchmark_main( + config, run_benchmark, "EMBEDDINGS", HTTP_URL, BATCH_SIZE, additional_info + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmark/api/bench_score.py b/benchmark/api/bench_score.py new file mode 100644 index 00000000000..126974923f6 --- /dev/null +++ b/benchmark/api/bench_score.py @@ -0,0 +1,157 @@ +""" +SGLang Scoring Benchmark Script + +This script benchmarks SGLang's scoring API performance using HTTP requests. + +Current Features: +- HTTP-only implementation (open source compatible) +- Uses /v1/score API endpoint directly +- Single item scoring with batching support +- Configurable RPS, duration, and batch sizes +- Progress tracking and detailed metrics +- Poisson and constant request distributions + +Usage: +- Update configuration variables at the top of the file +- Ensure SGLang server is running on the configured HTTP_URL +- Run: python bench_score.py +- Each request will contain ITEM_COUNT_VALUES items for batch scoring + +""" + +import asyncio +import os + +from bench_common import ( + BenchmarkConfig, + generate_text_with_token_count, + run_benchmark_main, + run_generic_benchmark, +) +from transformers import AutoTokenizer + +############################################################################### +# CONFIG +############################################################################### +# Create benchmark configuration +config = BenchmarkConfig() +config.rps_values = [70] +config.duration_secs_values = [60] +config.num_unique_requests = 100 +config.distribution = "POISSON" +config.profile = False + +# HTTP Configuration +HTTP_URL = "http://localhost:30000/v1/score" # Use score API directly + +# Score API Config +# ITEM_COUNT_VALUES determines number of items per score request (batch size) +SCORE_QUERY_TOKENS = 120 +SCORE_ITEM_TOKENS = 180 +SCORE_MODEL_PATH = "Qwen/Qwen3-0.6B" +SCORE_LABEL_TOKEN_IDS = [9454, 2753] # Yes/No token IDs +ITEM_COUNT_VALUES = [10] # Number of items per request +# Directory for profiler output +SGLANG_TORCH_PROFILER_DIR = "/shared/user/sglang-oss-trace/remove-decode" +if config.profile: + os.environ["SGLANG_TORCH_PROFILER_DIR"] = SGLANG_TORCH_PROFILER_DIR + +# Special token to replicate for precise token counting +SPECIAL_REPLICATED_TOKEN = "<|im_start|>" + + +############################################################################### +# REQUEST GENERATION (in parallel) +############################################################################### +def create_score_request_builder(): + """Create a score request builder function with shared tokenizer.""" + # Load tokenizer once here to verify special token and get precise counts + print("Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(SCORE_MODEL_PATH) + + # Verify that our special token produces exactly 1 token + special_token_count = len( + tokenizer.encode(config.special_replicated_token, add_special_tokens=False) + ) + print( + f"Special token '{config.special_replicated_token}' produces " + f"{special_token_count} token(s)" + ) + + def generate_text_with_token_count_local(num_toks): + """Generate text with precise token count using replicated token.""" + return generate_text_with_token_count( + SCORE_MODEL_PATH, num_toks, config.special_replicated_token + ) + + def build_score_request(index: int, item_count: int) -> tuple: + """Build a single score request.""" + try: + # Generate query and items for score API + query = 
generate_text_with_token_count_local(SCORE_QUERY_TOKENS) + items = [ + generate_text_with_token_count_local(SCORE_ITEM_TOKENS) + for _ in range(item_count) + ] + + # Return as dict for score API format + score_data = { + "query": query, + "items": items, + "label_token_ids": SCORE_LABEL_TOKEN_IDS, + "model": SCORE_MODEL_PATH, + } + return (index, score_data) + + except Exception as e: + print(f"Error building request {index}: {e}") + return (index, None) + + return build_score_request + + +def validate_score_response(response_data: dict) -> bool: + """Validate score API response.""" + return "scores" in response_data or "logprobs" in response_data + + +############################################################################### +# MAIN +############################################################################### +async def run_benchmark(rps, duration_secs, item_count): + """Run a single benchmark with the given RPS value.""" + # Create the request builder function with shared tokenizer + build_request_func = create_score_request_builder() + + return await run_generic_benchmark( + rps=rps, + duration_secs=duration_secs, + item_count=item_count, + config=config, + http_url=HTTP_URL, + build_request_func=build_request_func, + response_validator=validate_score_response, + api_name="SINGLE_ITEM_SCORING", + request_description="score requests", + ) + + +async def main(): + """Main function that runs benchmarks for all RPS values.""" + additional_info = { + "Query tokens per request": SCORE_QUERY_TOKENS, + "Item tokens per item": SCORE_ITEM_TOKENS, + } + + await run_benchmark_main( + config, + run_benchmark, + "SINGLE_ITEM_SCORING", + HTTP_URL, + ITEM_COUNT_VALUES, + additional_info, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmark/score/bench_score.py b/benchmark/score/bench_score.py deleted file mode 100644 index 60bcea24c51..00000000000 --- a/benchmark/score/bench_score.py +++ /dev/null @@ -1,603 +0,0 @@ -""" -SGLang Scoring Benchmark Script - -This script benchmarks SGLang's scoring API performance using HTTP requests. 
- -Current Features: -- HTTP-only implementation (open source compatible) -- Uses /v1/score API endpoint directly -- Single item scoring with batching support -- Configurable RPS, duration, and batch sizes -- Progress tracking and detailed metrics -- Poisson and constant request distributions - -Usage: -- Update configuration variables at the top of the file -- Ensure SGLang server is running on the configured HTTP_URL -- Run: python bench_score.py -- Each request will contain ITEM_COUNT_VALUES items for batch scoring - -""" - -import asyncio -import concurrent.futures # For parallel prompt generation -import json -import os -import random -from statistics import mean - -import aiohttp -import numpy as np -from tqdm import tqdm -from transformers import AutoTokenizer - -############################################################################### -# CONFIG -############################################################################### -# Server Configuration -SERVER_TYPE = "HTTP" # Fixed to HTTP for open source - -# HTTP Configuration -HTTP_URL = "http://localhost:30000/v1/score" # Use score API directly - -# Score API Config -# ITEM_COUNT_VALUES determines number of items per score request (batch size) -SCORE_QUERY_TOKENS = 120 -SCORE_ITEM_TOKENS = 180 -SCORE_MODEL_PATH = "Qwen/Qwen3-0.6B" -SCORE_LABEL_TOKEN_IDS = [9454, 2753] # Yes/No token IDs - -# Array of RPS values to test -RPS_VALUES = [70] -# Array of duration values to test -DURATION_SECS_VALUES = [60] # Duration values in seconds -# Array of item count values to test -ITEM_COUNT_VALUES = [10] # Number of items per request -# Number of unique requests to generate (will be reused) -NUM_UNIQUE_REQUESTS = 100 -DISTRIBUTION = "POISSON" # Options: "CONSTANT", "POISSON" - -# Profiling Configuration -PROFILE = False # Enable profiling with START_PROFILE/STOP_PROFILE prompts -# Directory for profiler output -SGLANG_TORCH_PROFILER_DIR = "/shared/user/sglang-oss-trace/remove-decode" -if PROFILE: - os.environ["SGLANG_TORCH_PROFILER_DIR"] = SGLANG_TORCH_PROFILER_DIR - -# Special token to replicate for precise token counting -SPECIAL_REPLICATED_TOKEN = "<|im_start|>" - - -############################################################################### -# REQUEST GENERATION (in parallel) -############################################################################### -def prepare_all_requests_parallel(num_requests, item_count): - """ - Generates unique requests in parallel, then reuses them to create the - full request list. Returns a list of str prompts for HTTP. - """ - # Load tokenizer once here to verify special token and get precise counts - print("Loading tokenizer...") - tokenizer = AutoTokenizer.from_pretrained(SCORE_MODEL_PATH) - - # Verify that our special token produces exactly 1 token - special_token_count = len( - tokenizer.encode(SPECIAL_REPLICATED_TOKEN, add_special_tokens=False) - ) - print( - f"Special token '{SPECIAL_REPLICATED_TOKEN}' produces " - f"{special_token_count} token(s)" - ) - - def generate_text_with_token_count(num_toks): - """Generate text with precise token count using replicated token.""" - if special_token_count == 1: - # Simple case: token maps to exactly 1 token - return SPECIAL_REPLICATED_TOKEN * num_toks - else: - print( - f"Special token '{SPECIAL_REPLICATED_TOKEN}' produces more than 1 token!!!" 
- ) - # Handle case where special token produces multiple tokens - # Repeat the token enough times to get at least num_toks tokens - repetitions = (num_toks + special_token_count - 1) // special_token_count - text = SPECIAL_REPLICATED_TOKEN * repetitions - - # Verify we got the expected token count (approximately) - actual_tokens = len(tokenizer.encode(text, add_special_tokens=False)) - if actual_tokens < num_toks: - print( - f"Warning: Generated {actual_tokens} tokens, " - f"expected {num_toks}" - ) - - return text - - def build_request(index): - """Build a single request using the shared tokenizer.""" - try: - # Generate query and items for score API - query = generate_text_with_token_count(SCORE_QUERY_TOKENS) - items = [ - generate_text_with_token_count(SCORE_ITEM_TOKENS) - for _ in range(item_count) - ] - - # Return as dict for score API format - score_data = { - "query": query, - "items": items, - "label_token_ids": SCORE_LABEL_TOKEN_IDS, - "model": SCORE_MODEL_PATH, - } - return (index, score_data) - - except Exception as e: - print(f"Error building request {index}: {e}") - return (index, None) - - # Generate only the unique requests - unique_requests = [None] * NUM_UNIQUE_REQUESTS - - # Use ThreadPoolExecutor instead of ProcessPoolExecutor to avoid - # tokenizer loading issues across processes - max_workers = min(8, os.cpu_count() or 1) # Limit to 8 threads max - - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for i in tqdm( - range(NUM_UNIQUE_REQUESTS), desc="Submitting prompt generation tasks" - ): - future = executor.submit(build_request, i) - futures.append(future) - - # Collect results as they complete - for f in tqdm( - concurrent.futures.as_completed(futures), - desc="Building unique requests", - total=NUM_UNIQUE_REQUESTS, - ): - try: - index, req_data = f.result() - if req_data is not None: - unique_requests[index] = req_data - else: - print(f"Failed to build request {index}") - except Exception as e: - print(f"Error processing request result: {e}") - - # Check if we have any valid requests - valid_requests = [req for req in unique_requests if req is not None] - if not valid_requests: - raise RuntimeError("Failed to generate any valid requests") - - print( - f"Successfully generated {len(valid_requests)} out of " - f"{NUM_UNIQUE_REQUESTS} unique requests" - ) - - # Create the full request list by cycling through unique requests - print( - f"Reusing {len(valid_requests)} unique requests to create " - f"{num_requests} total requests..." 
- ) - all_requests = [] - for i in tqdm(range(num_requests), desc="Reusing requests"): - unique_index = i % len(valid_requests) - all_requests.append(valid_requests[unique_index]) - - print("All prompts/requests prepared.\n") - return all_requests - - -############################################################################### -# PROFILING HELPERS -############################################################################### -async def send_profile_request(profile_text, item_count, session=None): - """Send a profile request and wait for completion.""" - try: - if session: - print(f"Sending {profile_text} request via HTTP...") - - # Determine the correct endpoint - base_url = HTTP_URL.rsplit("/", 2)[0] # Remove /v1/score - if profile_text == "START_PROFILE": - endpoint_url = f"{base_url}/start_profile" - elif profile_text == "STOP_PROFILE": - endpoint_url = f"{base_url}/stop_profile" - else: - print(f"Unknown profile request: {profile_text}") - return - - headers = {"Content-Type": "application/json"} - - async with session.post(endpoint_url, headers=headers) as resp: - resp_text = await resp.text() - if resp.status == 200: - print(f"{profile_text} request completed") - else: - print( - f"{profile_text} request failed with status " - f"{resp.status}: {resp_text}" - ) - else: - print(f"Cannot send {profile_text} request - missing session") - - except Exception as e: - print(f"Error sending {profile_text} request: {e}") - - -############################################################################### -# HTTP CALLS -############################################################################### -def build_http_request_json(score_data): - """Build HTTP request JSON for /v1/score endpoint. - - Score API format: - { - "query": "Generated query text with SCORE_QUERY_TOKENS tokens", - "items": ["item1", "item2", ...], # Items to score with SCORE_ITEM_TOKENS each - "label_token_ids": [token_id1, token_id2], # Target token IDs - "model": "/path/to/model" - } - - Args: - score_data: A dict containing query, items, label_token_ids, and model - """ - # score_data is already in the correct format from build_request - return json.dumps(score_data) - - -async def make_http_call(session, score_data, request_id, results_queue): - """HTTP call to /v1/score endpoint.""" - try: - start_time = asyncio.get_event_loop().time() - - request_json = build_http_request_json(score_data) - headers = {"Content-Type": "application/json"} - - async with session.post(HTTP_URL, data=request_json, headers=headers) as resp: - resp_text = await resp.text() - - if resp.status != 200: - print( - f"[HTTP] Request {request_id} failed with status " - f"{resp.status}: {resp_text}" - ) - completion_time = asyncio.get_event_loop().time() - await results_queue.put((request_id, 0, False, completion_time)) - return - - # Parse score API response - try: - response_data = json.loads(resp_text) - # Score API returns scores for each item - # For now, just verify we got a valid response - if "scores" in response_data or "logprobs" in response_data: - success = True - else: - print( - f"[HTTP] Request {request_id} missing expected fields in response" - ) - success = False - except json.JSONDecodeError: - print(f"[HTTP] Request {request_id} failed to parse JSON response") - success = False - - completion_time = asyncio.get_event_loop().time() - elapsed_time = (completion_time - start_time) * 1000 - await results_queue.put((request_id, elapsed_time, success, completion_time)) - - except Exception as e: - print(f"[HTTP] Error for request 
{request_id}: {e}") - completion_time = asyncio.get_event_loop().time() - await results_queue.put((request_id, 0, False, completion_time)) - - -############################################################################### -# RESULTS -############################################################################### -async def process_results( - results_queue, - num_requests, - send_duration, - total_duration, - rps, - duration_secs, - item_count, - test_start_time, -): - """Processes results and groups them by minute intervals. - Returns a list of dictionaries, one for each minute.""" - all_results = [] - - # Collect all results - for _ in range(num_requests): - result = await results_queue.get() - request_id, elapsed_time, success, completion_time = result - all_results.append( - { - "request_id": request_id, - "elapsed_time": elapsed_time, - "success": success, - "completion_time": completion_time, - } - ) - - # Group results by minute intervals - minute_results = [] - num_minutes = int(duration_secs // 60) + (1 if duration_secs % 60 > 0 else 0) - - for minute in range(num_minutes): - minute_start = test_start_time + (minute * 60) - minute_end = test_start_time + ((minute + 1) * 60) - - # Filter results that completed in this minute - minute_data = [ - r for r in all_results if minute_start <= r["completion_time"] < minute_end - ] - - response_times = [r["elapsed_time"] for r in minute_data if r["success"]] - successful_requests = len([r for r in minute_data if r["success"]]) - failed_requests = len([r for r in minute_data if not r["success"]]) - - avg_response_time = mean(response_times) if response_times else 0 - - # Calculate percentiles using numpy - if response_times: - p50 = np.percentile(response_times, 50) - p90 = np.percentile(response_times, 90) - p99 = np.percentile(response_times, 99) - else: - p50 = p90 = p99 = 0 - - minute_result = { - "test_duration_secs": duration_secs, - "minute_interval": minute + 1, - "target_rps": rps, - "item_count": item_count, - "server_type": SERVER_TYPE, - "distribution": DISTRIBUTION, - "unique_requests": NUM_UNIQUE_REQUESTS, - "total_requests": len(minute_data), - "successful_requests": successful_requests, - "failed_requests": failed_requests, - "send_duration_secs": send_duration, - "total_duration_secs": total_duration, - "avg_response_time_ms": avg_response_time, - "p50_response_time_ms": p50, - "p90_response_time_ms": p90, - "p99_response_time_ms": p99, - } - - minute_results.append(minute_result) - - print( - f"\nMinute {minute + 1} Summary for RPS {rps}, " - f"Duration {duration_secs}s, Item Count {item_count}:" - ) - print(f" Requests completed in minute: {len(minute_data)}") - print(f" Successful requests: {successful_requests}") - print(f" Failed requests: {failed_requests}") - print(f" Average response time: {avg_response_time:.2f} ms") - print(f" P50 response time: {p50:.2f} ms") - print(f" P90 response time: {p90:.2f} ms") - print(f" P99 response time: {p99:.2f} ms") - - # Also print overall summary - all_response_times = [r["elapsed_time"] for r in all_results if r["success"]] - total_successful = len([r for r in all_results if r["success"]]) - total_failed = len([r for r in all_results if not r["success"]]) - - overall_avg = mean(all_response_times) if all_response_times else 0 - if all_response_times: - overall_p50 = np.percentile(all_response_times, 50) - overall_p90 = np.percentile(all_response_times, 90) - overall_p99 = np.percentile(all_response_times, 99) - else: - overall_p50 = overall_p90 = overall_p99 = 0 - - print( - 
f"\nOverall Summary for RPS {rps}, Duration {duration_secs}s, " - f"Item Count {item_count}:" - ) - print(f" Test duration: {duration_secs} seconds") - print(f" Server type: {SERVER_TYPE}") - print(f" HTTP mode: SINGLE_ITEM_SCORING") - print(f" Target RPS: {rps}") - print(f" Item count: {item_count}") - print(f" Distribution: {DISTRIBUTION}") - print(f" Unique requests generated: {NUM_UNIQUE_REQUESTS}") - print(f" Total requests sent: {num_requests}") - print(f" Successful requests: {total_successful}") - print(f" Failed requests: {total_failed}") - print(f" Time to send all requests: {send_duration:.2f} seconds") - print(f" Time for all requests to complete: {total_duration:.2f} seconds") - print(f" Average response time: {overall_avg:.2f} ms") - print(f" P50 response time: {overall_p50:.2f} ms") - print(f" P90 response time: {overall_p90:.2f} ms") - print(f" P99 response time: {overall_p99:.2f} ms\n") - - return minute_results - - -############################################################################### -# MAIN -############################################################################### -async def run_benchmark(rps, duration_secs, item_count): - """Run a single benchmark with the given RPS value.""" - num_requests = int(rps * duration_secs) - print( - f"Starting benchmark with RPS={rps}, Duration={duration_secs}s, " - f"Item Count={item_count}, num_requests={num_requests}" - ) - print(f"Server Type: {SERVER_TYPE}") - print(f"HTTP Mode: SINGLE_ITEM_SCORING") - print(f"Profiling Enabled: {PROFILE}") - - # Build requests in parallel (unmeasured) - all_requests = prepare_all_requests_parallel(num_requests, item_count) - - results_queue = asyncio.Queue() - tasks = [] - - # Track timing for sending requests - send_start_time = asyncio.get_event_loop().time() - - # HTTP implementation (open source only supports HTTP with /v1/score API) - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=300) - ) as session: - - # Send START_PROFILE if profiling is enabled - if PROFILE: - await send_profile_request("START_PROFILE", item_count, session=session) - - # Add progress bar for sending requests - with tqdm( - total=len(all_requests), - desc=f"Sending HTTP score requests at {rps} RPS", - unit="req", - ) as pbar: - for i, score_data in enumerate(all_requests): - request_id = i + 1 - tasks.append( - asyncio.create_task( - make_http_call(session, score_data, request_id, results_queue) - ) - ) - - # Update progress bar - pbar.update(1) - - # Throttle based on distribution - if i < len(all_requests) - 1: - if DISTRIBUTION == "CONSTANT": - interval = 1 / rps - await asyncio.sleep(interval) - elif DISTRIBUTION == "POISSON": - # For Poisson process, inter-arrival times follow - # exponential distribution - interval = random.expovariate(rps) - await asyncio.sleep(interval) - else: - raise ValueError( - f"Unknown distribution: {DISTRIBUTION}. " - f"Use 'CONSTANT' or 'POISSON'." 
- ) - - send_end_time = asyncio.get_event_loop().time() - send_duration = send_end_time - send_start_time - - # Wait for all requests to complete with progress tracking - print(f"Waiting for {len(tasks)} HTTP score requests to complete...") - with tqdm( - total=len(tasks), desc="Completing HTTP score requests", unit="req" - ) as completion_pbar: - completed_tasks = [] - for task in asyncio.as_completed(tasks): - await task - completed_tasks.append(task) - completion_pbar.update(1) - - # Send STOP_PROFILE if profiling is enabled - if PROFILE: - await send_profile_request("STOP_PROFILE", item_count, session=session) - - completion_end_time = asyncio.get_event_loop().time() - total_duration = completion_end_time - send_start_time - - return await process_results( - results_queue, - num_requests, - send_duration, - total_duration, - rps, - duration_secs, - item_count, - send_start_time, - ) - - -async def main(): - """Main function that runs benchmarks for all RPS values.""" - total_combinations = ( - len(DURATION_SECS_VALUES) * len(RPS_VALUES) * len(ITEM_COUNT_VALUES) - ) - print( - f"Running benchmarks for {len(DURATION_SECS_VALUES)} duration " - f"values, {len(RPS_VALUES)} RPS values, and " - f"{len(ITEM_COUNT_VALUES)} item count values = " - f"{total_combinations} total combinations" - ) - print(f"Server Type: {SERVER_TYPE}") - print(f"HTTP Mode: SINGLE_ITEM_SCORING") - print(f"Score API URL: {HTTP_URL}") - print(f"Query tokens per request: {SCORE_QUERY_TOKENS}") - print(f"Item tokens per item: {SCORE_ITEM_TOKENS}") - print(f"Items per request (batch size): {ITEM_COUNT_VALUES}") - print(f"Profiling Enabled: {PROFILE}") - print(f"Duration values: {DURATION_SECS_VALUES}") - print(f"RPS values: {RPS_VALUES}") - print(f"Item count values: {ITEM_COUNT_VALUES}") - print("=" * 80) - - all_results = [] - - for duration_secs in DURATION_SECS_VALUES: - for rps in RPS_VALUES: - for item_count in ITEM_COUNT_VALUES: - result = await run_benchmark(rps, duration_secs, item_count) - all_results.extend(result) # Extend with minute results - - # Print CSV header and results - print("\n" + "=" * 80) - print("FINAL CSV RESULTS:") - print("=" * 80) - - # CSV Header - headers = [ - "test_duration_secs", - "minute_interval", - "target_rps", - "item_count", - "server_type", - "distribution", - "unique_requests", - "total_requests", - "successful_requests", - "failed_requests", - "send_duration_secs", - "total_duration_secs", - "avg_response_time_ms", - "p50_response_time_ms", - "p90_response_time_ms", - "p99_response_time_ms", - ] - print(",".join(headers)) - - # CSV Data - for result in all_results: - row = [ - result["test_duration_secs"], - result["minute_interval"], - result["target_rps"], - result["item_count"], - result["server_type"], - result["distribution"], - result["unique_requests"], - result["total_requests"], - result["successful_requests"], - result["failed_requests"], - f"{result['send_duration_secs']:.2f}", - f"{result['total_duration_secs']:.2f}", - f"{result['avg_response_time_ms']:.2f}", - f"{result['p50_response_time_ms']:.2f}", - f"{result['p90_response_time_ms']:.2f}", - f"{result['p99_response_time_ms']:.2f}", - ] - print(",".join(map(str, row))) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/sglang/srt/managers/async_dynamic_batch_tokenizer.py b/python/sglang/srt/managers/async_dynamic_batch_tokenizer.py new file mode 100644 index 00000000000..ef1a8307f3c --- /dev/null +++ b/python/sglang/srt/managers/async_dynamic_batch_tokenizer.py @@ -0,0 +1,170 @@ +""" 
+Asynchronous dynamic batch tokenizer for SGLang. + +This module provides an async tokenizer with dynamic batching capabilities +to reduce tokenization overhead when multiple requests arrive concurrently. +""" + +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class AsyncDynamicbatchTokenizer: + """Asynchronous tokenizer with dynamic batching for single string prompts. + + Dynamically batches pending encode requests from a queue to reduce overhead. + Only handles single string prompts - regular batch processing of multiple + strings per request should be handled at a higher level. + A single-thread ThreadPoolExecutor is used so the event loop stays responsive. + + Note: Uses lazy initialization for asyncio components because this class + is instantiated in TokenizerManager.__init__() before the event loop starts. + """ + + def __init__( + self, + tokenizer, + max_batch_size: int = 32, + batch_wait_timeout_s: float = 0.002, + ) -> None: + self.tokenizer = tokenizer + self.max_batch_size = max_batch_size + self.batch_wait_timeout_s = batch_wait_timeout_s + + # Single queue for all encode requests - initialized lazily + self._queue: Optional[asyncio.Queue] = None + self._batcher_task: Optional[asyncio.Task] = None + + # Single-thread executor for blocking tokenizer calls + self._executor = ThreadPoolExecutor(max_workers=1) + self._initialized = False + + def _ensure_initialized(self): + """Lazy initialization of event loop dependent components.""" + if not self._initialized: + self._queue = asyncio.Queue() + self._batcher_task = asyncio.create_task(self._dynamic_batch_loop()) + self._initialized = True + + async def __call__(self, prompt: str, **kwargs) -> Any: + """Encode a single prompt.""" + return await self.encode(prompt, **kwargs) + + async def encode(self, prompt: str, **kwargs) -> Any: + """Encode a single prompt.""" + self._ensure_initialized() + result_future: asyncio.Future = asyncio.get_running_loop().create_future() + await self._queue.put((prompt, kwargs, result_future)) + return await result_future + + async def _dynamic_batch_loop(self): + """Dynamically batch incoming encode requests for efficiency.""" + while True: + try: + # Get the first request + prompt, kwargs, result_future = await self._queue.get() + + # Collect requests into dynamic batch + prompts = [prompt] + kwargs_list = [kwargs] + result_futures = [result_future] + + # Check if there are more items immediately available in the queue + # If queue is empty, process single item immediately without timeout + if self._queue.empty(): + # No other requests waiting, process immediately + pass + else: + # There might be more requests, wait for dynamic batching opportunity + start_time = asyncio.get_running_loop().time() + + # Collect more requests up to max_batch_size or batch_wait_timeout_s + while len(prompts) < self.max_batch_size: + elapsed = asyncio.get_running_loop().time() - start_time + if elapsed >= self.batch_wait_timeout_s: + break + + remaining_time = self.batch_wait_timeout_s - elapsed + try: + prompt, kwargs, result_future = await asyncio.wait_for( + self._queue.get(), remaining_time + ) + prompts.append(prompt) + kwargs_list.append(kwargs) + result_futures.append(result_future) + except asyncio.TimeoutError: + break + + # Log dynamic batch information + logger.debug( + f"AsyncDynamicbatchTokenizer: Processing dynamic batch of size {len(prompts)}" + ) + + 
# Process the dynamic batch + await self._process_dynamic_batch(prompts, kwargs_list, result_futures) + + except Exception as e: + logger.error(f"Error in dynamic batch loop: {e}") + # Continue the loop to handle other requests + + async def _process_dynamic_batch( + self, + prompts: List[str], + kwargs_list: List[Dict], + result_futures: List[asyncio.Future], + ) -> None: + """Process a dynamic batch of encode requests for single string prompts.""" + # Check if all kwargs are identical for efficient batch processing + can_batch = len(set(str(sorted(kw.items())) for kw in kwargs_list)) == 1 + kwargs = kwargs_list[0] if can_batch else None + + try: + # If every request uses identical kwargs we can run a single + # batch tokenizer call for a big speed-up. + if can_batch and len(prompts) > 1: + encode_fn = partial(self.tokenizer, prompts, **kwargs) + results = await asyncio.get_running_loop().run_in_executor( + self._executor, encode_fn + ) + + for i, fut in enumerate(result_futures): + if not fut.done(): + data = {k: v[i] for k, v in results.items()} + fut.set_result(data) + else: + # Process each request individually due to different kwargs + if len(prompts) > 1 and not can_batch: + logger.warning( + f"AsyncDynamicbatchTokenizer: Dynamic batching disabled for batch of {len(prompts)} " + f"requests due to differing kwargs. This reduces performance benefits. " + f"Consider using consistent tokenization parameters across requests." + ) + + encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [ + self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs_list) + ] + results = await asyncio.get_running_loop().run_in_executor( + self._executor, encode_fn + ) + + for fut, res in zip(result_futures, results): + if not fut.done(): + fut.set_result(res) + except Exception as e: + logger.error(f"Error in dynamic batch processing: {e}") + for fut in result_futures: + if not fut.done(): + fut.set_exception(e) + + def __del__(self): + """Clean up background tasks.""" + if hasattr(self, "_batcher_task") and self._batcher_task: + if not self._batcher_task.done(): + self._batcher_task.cancel() + if hasattr(self, "_executor"): + self._executor.shutdown(wait=False) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3a81a363679..790ca8cbf44 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -65,6 +65,7 @@ get_tokenizer_from_processor, ) from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry +from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -253,6 +254,18 @@ def __init__( trust_remote_code=server_args.trust_remote_code, revision=server_args.revision, ) + # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal) + if ( + server_args.enable_dynamic_batch_tokenizer + and not server_args.skip_tokenizer_init + ): + self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer( + self.tokenizer, + max_batch_size=server_args.dynamic_batch_tokenizer_batch_size, + batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout, + ) + else: + self.async_dynamic_batch_tokenizer = None # Init inter-process communication context = zmq.asyncio.Context(2) @@ -500,6 +513,82 @@ async def generate_request( ): yield response + async def _tokenize_texts( + self, texts: Union[str, List[str]], is_cross_encoder: bool = False + 
) -> Union[ + Tuple[List[int], Optional[List[int]]], + Tuple[List[List[int]], Optional[List[List[int]]]], + ]: + """ + Tokenize text(s) using the appropriate tokenizer strategy. + + This method chooses between async dynamic batch tokenizer (for single texts only) + and regular tokenizer (for batch texts or fallback). + + Args: + texts: Single string or list of strings to tokenize. + is_cross_encoder: Whether to return token_type_ids for cross-encoder models. + Used for tasks like sentence similarity where token types matter. + + Returns: + For single text input: + Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids) + For batch text input: + Tuple[List[List[int]], Optional[List[List[int]]]]: (input_ids_batch, token_type_ids_batch) + + token_type_ids is None unless is_cross_encoder=True. + """ + if not texts or self.tokenizer is None: + raise ValueError("texts cannot be empty and tokenizer must be initialized") + + is_single: bool = isinstance(texts, str) + # normalized to list format + text_list: List[str] = [texts] if is_single else texts + kwargs: Dict[str, Any] = ( + {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {} + ) + + # Use async dynamic batch tokenizer only for single texts + # For multiple texts in a batch, use regular tokenizer which is efficient + if self.async_dynamic_batch_tokenizer is not None and is_single: + logger.debug("Using async dynamic batch tokenizer for single text") + result: Dict[str, Any] = await self.async_dynamic_batch_tokenizer.encode( + text_list[0], **kwargs + ) + + # Extract and wrap in batch format for consistency + input_ids: List[List[int]] = [result["input_ids"]] + token_type_ids: Optional[List[List[int]]] = ( + [result["token_type_ids"]] + if is_cross_encoder and result.get("token_type_ids") + else None + ) + else: + # Use regular tokenizer - much more efficient for batch requests + logger.debug(f"Using regular tokenizer for {len(text_list)} texts") + encoded: Dict[str, Any] = self.tokenizer(text_list, **kwargs) + + # input_ids is nested since we pass List[str] to tokenizer + # Example: [[101, 7592, 102]] or [[101, 7592, 102], [101, 2088, 102]] + input_ids: List[List[int]] = encoded["input_ids"] + + # token_type_ids is nested as well + # Example: [[0, 0, 0]] or [[0, 0, 0], [0, 0, 0]] or None + token_type_ids: Optional[List[List[int]]] = ( + encoded.get("token_type_ids") if is_cross_encoder else None + ) + + # Return in the expected format + if is_single: + # Extract single sequence from batch format + single_input_ids: List[int] = input_ids[0] + single_token_type_ids: Optional[List[int]] = ( + token_type_ids[0] if token_type_ids else None + ) + return single_input_ids, single_token_type_ids + else: + return input_ids, token_type_ids + async def _tokenize_one_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -530,14 +619,10 @@ async def _tokenize_one_request( "accept text prompts. Please provide input_ids or re-initialize " "the engine with skip_tokenizer_init=False." 
) - encoded = self.tokenizer( - input_text, return_token_type_ids=is_cross_encoder_request - ) - input_ids = encoded["input_ids"] - if is_cross_encoder_request: - input_ids = encoded["input_ids"][0] - token_type_ids = encoded.get("token_type_ids", [None])[0] + input_ids, token_type_ids = await self._tokenize_texts( + input_text, is_cross_encoder_request + ) if self.mm_processor and obj.contains_mm_input(): if not isinstance(obj.image_data, list): @@ -692,17 +777,27 @@ async def _batch_tokenize_and_process( requests = [obj[i] for i in range(batch_size)] texts = [req.text for req in requests] - # Batch tokenize all texts - encoded = self.tokenizer(texts) - input_ids_list = encoded["input_ids"] + # Check if any request is a cross-encoder request + is_cross_encoder_request = any( + isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request + for req in requests + ) + + # Batch tokenize all texts using unified method + input_ids_list, token_type_ids_list = await self._tokenize_texts( + texts, is_cross_encoder_request + ) # Process all requests tokenized_objs = [] for i, req in enumerate(requests): self._validate_one_request(obj[i], input_ids_list[i]) + token_type_ids = ( + token_type_ids_list[i] if token_type_ids_list is not None else None + ) tokenized_objs.append( self._create_tokenized_object( - req, req.text, input_ids_list[i], None, None + req, req.text, input_ids_list[i], None, None, token_type_ids ) ) logger.debug(f"Completed batch processing for {batch_size} requests") diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 78515e898ee..c852c33c463 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -262,6 +262,11 @@ class ServerArgs: enable_return_hidden_states: bool = False scheduler_recv_interval: int = 1 + # Dynamic batch tokenizer + enable_dynamic_batch_tokenizer: bool = False + dynamic_batch_tokenizer_batch_size: int = 32 + dynamic_batch_tokenizer_batch_timeout: float = 0.002 + # Debug tensor dumps debug_tensor_dump_output_folder: Optional[str] = None debug_tensor_dump_input_file: Optional[str] = None @@ -704,6 +709,13 @@ def print_deprecated_warning(message: str): self.disable_cuda_graph = True logger.warning("Cuda graph is disabled for prefill server") + # Validation: prevent both tokenizer batching features from being enabled + if self.enable_tokenizer_batch_encode and self.enable_dynamic_batch_tokenizer: + raise ValueError( + "Cannot enable both --enable-tokenizer-batch-encode and --enable-dynamic-batch-tokenizer. " + "Please choose one tokenizer batching approach." + ) + # Propagate env vars os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = ( "1" if self.enable_torch_compile else "0" @@ -1886,6 +1898,23 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Only dump the tensors for prefill requests (i.e. 
batch size > 1).", ) + parser.add_argument( + "--enable-dynamic-batch-tokenizer", + action="store_true", + help="Enable async dynamic batch tokenizer for improved performance when multiple requests arrive concurrently.", + ) + parser.add_argument( + "--dynamic-batch-tokenizer-batch-size", + type=int, + default=ServerArgs.dynamic_batch_tokenizer_batch_size, + help="[Only used if --enable-dynamic-batch-tokenizer is set] Maximum batch size for dynamic batch tokenizer.", + ) + parser.add_argument( + "--dynamic-batch-tokenizer-batch-timeout", + type=float, + default=ServerArgs.dynamic_batch_tokenizer_batch_timeout, + help="[Only used if --enable-dynamic-batch-tokenizer is set] Timeout in seconds for batching tokenization requests.", + ) # PD disaggregation parser.add_argument( diff --git a/test/srt/test_async_dynamic_batch_tokenizer.py b/test/srt/test_async_dynamic_batch_tokenizer.py new file mode 100644 index 00000000000..930e23e549b --- /dev/null +++ b/test/srt/test_async_dynamic_batch_tokenizer.py @@ -0,0 +1,295 @@ +""" +Unit tests for AsyncDynamicbatchTokenizer. + +Tests the async dynamic batching functionality for tokenization, +including batch efficiency, timeout handling, and error cases. +""" + +import asyncio +import logging +import time +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from transformers import AutoTokenizer + +from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer + + +class TestAsyncDynamicbatchTokenizer: + """Test suite for AsyncDynamicbatchTokenizer.""" + + @pytest.fixture + def mock_tokenizer(self): + """Create a mock tokenizer that behaves like HuggingFace tokenizer.""" + + def mock_encode(texts, **kwargs): + is_single = isinstance(texts, str) + if is_single: + texts = [texts] + + # Simulate tokenization - convert text to mock token ids + input_ids = [] + token_type_ids = [] + + for text in texts: + # Simple mock: text length determines number of tokens + tokens = [i for i in range(len(text.split()))] + input_ids.append(tokens) + + if kwargs.get("return_token_type_ids", False): + token_type_ids.append([0] * len(tokens)) + + result = {"input_ids": input_ids} + if kwargs.get("return_token_type_ids", False): + result["token_type_ids"] = token_type_ids + + # For single inputs, return individual result (not wrapped in a list) + if is_single: + result = {"input_ids": input_ids[0]} + if kwargs.get("return_token_type_ids", False): + result["token_type_ids"] = token_type_ids[0] + + # Create a proper BatchEncoding-like object that supports dict operations + class MockBatchEncoding(dict): + def __init__(self, data): + super().__init__(data) + for key, value in data.items(): + setattr(self, key, value) + + return MockBatchEncoding(result) + + # Return the function directly - the AsyncDynamicbatchTokenizer will call it + return mock_encode + + @pytest.fixture + def async_tokenizer(self, mock_tokenizer): + """Create AsyncDynamicbatchTokenizer instance.""" + return AsyncDynamicbatchTokenizer( + tokenizer=mock_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01 + ) + + @pytest.mark.asyncio + async def test_single_request(self, async_tokenizer): + """Test tokenizing a single request.""" + text = "hello world" + result = await async_tokenizer.encode(text) + + assert "input_ids" in result + assert result["input_ids"] == [0, 1] # 2 words -> 2 tokens + + @pytest.mark.asyncio + async def test_single_request_with_token_type_ids(self, async_tokenizer): + """Test tokenizing with token type IDs.""" + text = "hello world" + result = 
await async_tokenizer.encode(text, return_token_type_ids=True) + + assert "input_ids" in result + assert "token_type_ids" in result + assert result["input_ids"] == [0, 1] + assert result["token_type_ids"] == [0, 0] + + @pytest.mark.asyncio + async def test_concurrent_requests_same_kwargs(self, async_tokenizer): + """Test that concurrent requests with same kwargs get batched.""" + texts = ["hello world", "how are you", "fine thanks", "good morning"] + + # Start all requests concurrently + tasks = [async_tokenizer.encode(text) for text in texts] + results = await asyncio.gather(*tasks) + + # Verify all results + assert len(results) == 4 + for i, result in enumerate(results): + assert "input_ids" in result + expected_tokens = list(range(len(texts[i].split()))) + assert result["input_ids"] == expected_tokens + + @pytest.mark.asyncio + async def test_concurrent_requests_different_kwargs(self, async_tokenizer): + """Test that requests with different kwargs are processed individually.""" + text1 = "hello world" + text2 = "how are you" + + # One with token_type_ids, one without + task1 = async_tokenizer.encode(text1, return_token_type_ids=True) + task2 = async_tokenizer.encode(text2) + + result1, result2 = await asyncio.gather(task1, task2) + + # First result should have token_type_ids + assert "input_ids" in result1 + assert "token_type_ids" in result1 + assert result1["input_ids"] == [0, 1] + assert result1["token_type_ids"] == [0, 0] + + # Second result should not have token_type_ids + assert "input_ids" in result2 + assert "token_type_ids" not in result2 + assert result2["input_ids"] == [0, 1, 2] + + @pytest.mark.asyncio + async def test_batch_timeout(self, async_tokenizer): + """Test that batching respects timeout.""" + # Send first request + task1 = asyncio.create_task(async_tokenizer.encode("hello world")) + + # Wait longer than batch timeout + await asyncio.sleep(0.02) # Longer than 0.01s timeout + + # Send second request + task2 = asyncio.create_task(async_tokenizer.encode("how are you")) + + results = await asyncio.gather(task1, task2) + + # Both should complete successfully + assert len(results) == 2 + assert results[0]["input_ids"] == [0, 1] + assert results[1]["input_ids"] == [0, 1, 2] + + @pytest.mark.asyncio + async def test_max_batch_size_limit(self, async_tokenizer): + """Test that batching respects max_batch_size.""" + # Send more requests than max_batch_size (4) + texts = [f"text {i}" for i in range(6)] + tasks = [async_tokenizer.encode(text) for text in texts] + + results = await asyncio.gather(*tasks) + + # All should complete successfully + assert len(results) == 6 + for i, result in enumerate(results): + assert "input_ids" in result + assert result["input_ids"] == [0, 1] # "text i" -> 2 tokens + + @pytest.mark.asyncio + async def test_callable_interface(self, async_tokenizer): + """Test that the tokenizer is callable.""" + text = "hello world" + result = await async_tokenizer(text) + + assert "input_ids" in result + assert result["input_ids"] == [0, 1] + + @pytest.mark.asyncio + async def test_lazy_initialization(self, mock_tokenizer): + """Test that initialization happens lazily.""" + tokenizer = AsyncDynamicbatchTokenizer(mock_tokenizer) + + # Should not be initialized yet + assert not tokenizer._initialized + + # First encode should initialize + await tokenizer.encode("hello") + + # Should now be initialized + assert tokenizer._initialized + + @pytest.mark.asyncio + async def test_error_handling_in_tokenizer(self, mock_tokenizer): + """Test error handling when tokenizer 
fails.""" + + # Create a new async tokenizer with a failing tokenizer + def failing_tokenizer(*args, **kwargs): + raise ValueError("Tokenizer error") + + async_tokenizer = AsyncDynamicbatchTokenizer( + tokenizer=failing_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01 + ) + + with pytest.raises(ValueError, match="Tokenizer error"): + await async_tokenizer.encode("hello world") + + @pytest.mark.asyncio + async def test_batch_processing_logs(self, async_tokenizer, caplog): + """Test that batch processing logs are generated.""" + caplog.set_level(logging.DEBUG) + + # Send multiple requests to trigger batching + tasks = [ + async_tokenizer.encode("hello world"), + async_tokenizer.encode("how are you"), + ] + + await asyncio.gather(*tasks) + + # Should have batch processing log + assert any( + "Processing dynamic batch of size" in record.message + for record in caplog.records + ) + + @pytest.mark.asyncio + async def test_empty_queue_immediate_processing(self, async_tokenizer): + """Test that single requests are processed immediately when queue is empty.""" + start_time = time.time() + result = await async_tokenizer.encode("hello world") + end_time = time.time() + + # Should complete quickly (much less than batch timeout) + assert end_time - start_time < 0.005 # 5ms should be plenty + assert result["input_ids"] == [0, 1] + + @pytest.mark.asyncio + async def test_real_tokenizer_integration(self): + """Test with a real HuggingFace tokenizer.""" + try: + # Use a small, fast tokenizer for testing + real_tokenizer = AutoTokenizer.from_pretrained("gpt2") + async_tokenizer = AsyncDynamicbatchTokenizer( + tokenizer=real_tokenizer, max_batch_size=2, batch_wait_timeout_s=0.01 + ) + + text = "Hello, world!" + result = await async_tokenizer.encode(text) + + # Should get actual token IDs + assert "input_ids" in result + assert isinstance(result["input_ids"], list) + assert len(result["input_ids"]) > 0 + assert all(isinstance(token_id, int) for token_id in result["input_ids"]) + + except Exception as e: + pytest.skip(f"Real tokenizer test skipped: {e}") + + @pytest.mark.asyncio + async def test_concurrent_mixed_requests(self, async_tokenizer): + """Test mixing single and batched requests.""" + # Start some requests + task1 = asyncio.create_task(async_tokenizer.encode("hello")) + task2 = asyncio.create_task(async_tokenizer.encode("world")) + + # Wait a bit + await asyncio.sleep(0.005) + + # Start more requests + task3 = asyncio.create_task(async_tokenizer.encode("how are")) + task4 = asyncio.create_task(async_tokenizer.encode("you doing")) + + results = await asyncio.gather(task1, task2, task3, task4) + + # All should complete successfully + assert len(results) == 4 + for result in results: + assert "input_ids" in result + assert isinstance(result["input_ids"], list) + + def test_cleanup_on_destruction(self, mock_tokenizer): + """Test that resources are cleaned up properly.""" + tokenizer = AsyncDynamicbatchTokenizer(mock_tokenizer) + + # Mock the executor and task + tokenizer._executor = Mock() + tokenizer._batcher_task = Mock() + tokenizer._batcher_task.done.return_value = False + + # Call destructor + tokenizer.__del__() + + # Should cancel task and shutdown executor + tokenizer._batcher_task.cancel.assert_called_once() + tokenizer._executor.shutdown.assert_called_once_with(wait=False) + + +if __name__ == "__main__": + pytest.main([__file__])
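
Usage note (not part of the patch): the sketch below exercises the AsyncDynamicbatchTokenizer introduced above directly, mirroring the real-tokenizer test in this file. The "gpt2" checkpoint and the main() wrapper are illustrative assumptions; the constructor arguments and the awaited encode() call follow the class API shown in the diff, and the numeric values mirror the ServerArgs defaults (dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002). On a running server the same path is enabled with --enable-dynamic-batch-tokenizer (tunable via --dynamic-batch-tokenizer-batch-size and --dynamic-batch-tokenizer-batch-timeout), and per the validation added in server_args.py it cannot be combined with --enable-tokenizer-batch-encode.

    import asyncio

    from transformers import AutoTokenizer

    from sglang.srt.managers.async_dynamic_batch_tokenizer import (
        AsyncDynamicbatchTokenizer,
    )


    async def main() -> None:
        # Wrap a regular HuggingFace tokenizer; the batching parameters
        # mirror the ServerArgs defaults introduced in this patch.
        hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        async_tokenizer = AsyncDynamicbatchTokenizer(
            hf_tokenizer,
            max_batch_size=32,
            batch_wait_timeout_s=0.002,
        )

        # Concurrent single-prompt encode() calls with identical kwargs are
        # coalesced into one underlying tokenizer call by the background batcher.
        prompts = ["Hello, world!", "How are you?", "Dynamic batching test."]
        results = await asyncio.gather(*(async_tokenizer.encode(p) for p in prompts))
        for prompt, result in zip(prompts, results):
            print(prompt, "->", result["input_ids"])


    if __name__ == "__main__":
        asyncio.run(main())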