Skip to content
4 changes: 3 additions & 1 deletion python/sglang/bench_offline_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def throughput_test_once(
measurement_results["total_input_tokens"]
+ measurement_results["total_output_tokens"]
) / latency
measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
"last_gen_throughput"
]

return measurement_results

Expand Down
4 changes: 2 additions & 2 deletions python/sglang/bench_one_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def extend(reqs, model_runner):
_maybe_prepare_dp_attn_batch(batch, model_runner)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
logits_output, _ = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch)
return next_token_ids, logits_output.next_token_logits, batch

Expand All @@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
_maybe_prepare_dp_attn_batch(batch, model_runner)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
logits_output, _ = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch)
return next_token_ids, logits_output.next_token_logits

Expand Down
158 changes: 143 additions & 15 deletions python/sglang/bench_one_batch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import is_in_ci, write_github_step_summary


@dataclasses.dataclass
Expand All @@ -33,9 +34,13 @@ class BenchArgs:
batch_size: Tuple[int] = (1,)
input_len: Tuple[int] = (1024,)
output_len: Tuple[int] = (16,)
temperature: float = 0.0
return_logprob: bool = False
input_len_step_percentage: float = 0.0
result_filename: str = "result.jsonl"
base_url: str = ""
skip_warmup: bool = False
show_report: bool = False

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
Expand All @@ -49,11 +54,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
)
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument(
"--input-len-step-percentage",
type=float,
default=BenchArgs.input_len_step_percentage,
)
parser.add_argument(
"--result-filename", type=str, default=BenchArgs.result_filename
)
parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
parser.add_argument("--skip-warmup", action="store_true")
parser.add_argument("--show-report", action="store_true")

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
Expand Down Expand Up @@ -99,36 +112,89 @@ def run_one_case(
batch_size: int,
input_len: int,
output_len: int,
temperature: float,
return_logprob: bool,
input_len_step_percentage: float,
run_name: str,
result_filename: str,
):
requests.post(url + "/flush_cache")
input_lens = [
int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
for i in range(batch_size)
]
input_ids = [
[int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
for _ in range(batch_size)
[int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
for i in range(batch_size)
]

use_structured_outputs = False
if use_structured_outputs:
texts = []
for _ in range(batch_size):
texts.append(
"Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
* 50
+ "Assistant:"
)
json_schema = "$$ANY$$"
else:
json_schema = None

tic = time.time()
response = requests.post(
url + "/generate",
json={
# "text": texts,
"input_ids": input_ids,
"sampling_params": {
"temperature": 0,
"temperature": temperature,
"max_new_tokens": output_len,
"ignore_eos": True,
"json_schema": json_schema,
},
"return_logprob": return_logprob,
"stream": True,
},
stream=True,
)
latency = time.time() - tic

_ = response.json()
output_throughput = batch_size * output_len / latency
# The TTFT of the last request in the batch
ttft = 0.0
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
if "error" in data:
raise RuntimeError(f"Request has failed. {data}.")

assert (
data["meta_info"]["finish_reason"] is None
or data["meta_info"]["finish_reason"]["type"] == "length"
)
if data["meta_info"]["completion_tokens"] == 1:
ttft = time.time() - tic

latency = time.time() - tic
input_throughput = batch_size * input_len / ttft
output_throughput = batch_size * output_len / (latency - ttft)
overall_throughput = batch_size * (input_len + output_len) / latency

server_info = requests.get(url + "/get_server_info").json()
acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]

print(f"batch size: {batch_size}")
print(f"input_len: {input_len}")
print(f"output_len: {output_len}")
print(f"latency: {latency:.2f} s")
print(f"output throughput: {output_throughput:.2f} token/s")
print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
print(f"ttft: {ttft:.2f} s")
print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
print(f"Input throughput: {input_throughput:.2f} tok/s")
if output_len != 1:
print(f"output throughput: {output_throughput:.2f} tok/s")

if result_filename:
with open(result_filename, "a") as fout:
Expand All @@ -140,9 +206,21 @@ def run_one_case(
"latency": round(latency, 4),
"output_throughput": round(output_throughput, 2),
"overall_throughput": round(overall_throughput, 2),
"last_gen_throughput": round(last_gen_throughput, 2),
}
fout.write(json.dumps(res) + "\n")

return (
batch_size,
latency,
ttft,
input_throughput,
output_throughput,
overall_throughput,
last_gen_throughput,
acc_length,
)


def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
if bench_args.base_url:
Expand All @@ -152,34 +230,84 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

# warmup
if not bench_args.skip_warmup:
print("=" * 8 + " Warmup Begin " + "=" * 8)
run_one_case(
base_url,
batch_size=16,
input_len=1024,
output_len=16,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name="",
result_filename="",
)
print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

# benchmark
result = []
try:
for bs, il, ol in itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
):
run_one_case(
base_url,
bs,
il,
ol,
bench_args.run_name,
bench_args.result_filename,
result.append(
run_one_case(
base_url,
bs,
il,
ol,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
)
)
finally:
if proc:
kill_process_tree(proc.pid)

print(f"\nResults are saved to {bench_args.result_filename}")

if not bench_args.show_report:
return

summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"

for (
batch_size,
latency,
ttft,
input_throughput,
output_throughput,
overall_throughput,
last_gen_throughput,
acc_length,
) in result:
hourly_cost = 2 * server_args.tp_size # $2/hour for one H100
input_util = 0.7
accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
line = (
f"| {batch_size} | "
f"{latency:.2f} | "
f"{input_throughput:.2f} | "
f"{output_throughput:.2f} | "
f"{accept_length} | "
f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
)
summary += line

# print metrics table
print(summary)

if is_in_ci():
write_github_step_summary(
f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down
12 changes: 7 additions & 5 deletions python/sglang/bench_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ async def benchmark(
lora_names: List[str],
extra_request_body: Dict[str, Any],
profile: bool,
pd_seperated: bool = False,
pd_separated: bool = False,
flush_cache: bool = False,
warmup_requests: int = 1,
):
Expand Down Expand Up @@ -1239,12 +1239,14 @@ async def limited_request_func(request_func_input, pbar):

if "sglang" in backend:
server_info = requests.get(base_url + "/get_server_info")
if pd_seperated:
accept_length = server_info.json()["decode"][0].get(
if pd_separated:
accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
"avg_spec_accept_length", None
)
else:
accept_length = server_info.json().get("avg_spec_accept_length", None)
accept_length = server_info.json()["internal_states"][0].get(
"avg_spec_accept_length", None
)
else:
accept_length = None

Expand Down Expand Up @@ -1541,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace):
lora_names=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
pd_seperated=args.pd_seperated,
pd_separated=args.pd_separated,
flush_cache=args.flush_cache,
)
)
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/constrained/base_grammar_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def accept_token(self, token: int) -> None:
"""
raise NotImplementedError()

def rollback(self, k: int):
raise NotImplementedError()

def is_terminated(self):
raise NotImplementedError()

def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
) -> torch.Tensor:
Expand Down
4 changes: 1 addition & 3 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,19 +277,17 @@ def process_batch_result_disagg_prefill(
next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
bid,
) = (
result.logits_output,
result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
result.bid,
)

# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
if self.enable_overlap:
# wait
_, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
_, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
else:
next_token_ids = result.next_token_ids.tolist()

Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def get_server_info(self):
return {
**dataclasses.asdict(self.tokenizer_manager.server_args),
**self.scheduler_info,
**internal_states,
"internal_states": internal_states,
"version": __version__,
}

Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def get_server_info():
return {
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
**_global_state.scheduler_info,
**internal_states,
"internal_states": internal_states,
"version": __version__,
}

Expand Down
6 changes: 4 additions & 2 deletions python/sglang/srt/layers/attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(

num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
# index into req_to_token_ptr needs to be int64
offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
mask = offset < kv_end - kv_start
data = tl.load(
req_to_token_ptr
Expand Down Expand Up @@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)

for i in range(num_pages_loop):
# index into req_to_token_ptr needs to be int64
paged_offset = (
tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
) * PAGED_SIZE
paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK

Expand Down
Loading
Loading