27 changes: 27 additions & 0 deletions docs/backend/sampling_params.md
@@ -64,6 +64,7 @@ Please refer to our dedicated guide on [constrained decoding](./structured_outpu
| ignore_eos | `bool = False` | Don't stop generation when EOS token is sampled. |
| skip_special_tokens | `bool = True` | Remove special tokens during decoding. |
| custom_params | `Optional[List[Optional[Dict[str, Any]]]] = None` | Used when employing `CustomLogitProcessor`. For usage, see below. |
| thinking_budget | `Optional[int] = None` | The maximum number of reasoning tokens that can be generated for a request. |

## Examples

@@ -296,3 +297,29 @@ response = requests.post(
)
print(response.json())
```

### Thinking Budget

Launch a server with a `--reasoning-parser` matching the model, so the end of the thinking section can be detected:

```bash
python3 -m sglang.launch_server --model Qwen/Qwen3-8B --reasoning-parser qwen3
```

Send a request:

```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "9.11 and 9.8, which is greater?",
        "sampling_params": {
            "temperature": 0.3,
            "max_new_tokens": 256,
            "thinking_budget": 20,
        },
    },
)
print(response.json())
```
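
The field is also accepted by the OpenAI-compatible chat completions endpoint. A minimal sketch mirroring `test/srt/test_thinking_budget.py` from this change, against the same server and model as above:

```python
import requests

response = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-8B",
        "messages": [
            {"role": "user", "content": "9.11 and 9.8, which is greater?"}
        ],
        "temperature": 0,
        "separate_reasoning": True,
        "chat_template_kwargs": {"enable_thinking": True},
        "thinking_budget": 20,
    },
)
# With separate_reasoning enabled, the truncated reasoning is returned in
# message.reasoning_content rather than in message.content.
print(response.json()["choices"][0]["message"]["reasoning_content"])
```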
6 changes: 5 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -1140,7 +1140,9 @@ def sample(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        sampling_info = forward_batch.sampling_info
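        # Over-budget requests have their logits rewritten so that only the
        # end-of-thinking token can be sampled.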
        if sampling_info.thinking_budgets is not None:
            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
        self._preprocess_logits(logits_output, forward_batch.sampling_info)

# Sample the next tokens
@@ -1151,6 +1153,8 @@
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
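        # A sampled end-of-thinking token retires that request's budget.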
        if sampling_info.thinking_budgets is not None:
            sampling_info.update_thinking_budgets(next_token_ids)
        return next_token_ids

    @property
2 changes: 2 additions & 0 deletions python/sglang/srt/openai_api/adapter.py
@@ -529,6 +529,7 @@ def v1_generate_request(
                "temperature": request.temperature,
                "max_new_tokens": request.max_tokens,
                "min_new_tokens": request.min_tokens,
                "thinking_budget": request.thinking_budget,
                "stop": request.stop,
                "stop_token_ids": request.stop_token_ids,
                "top_p": request.top_p,
@@ -1101,6 +1102,7 @@ def v1_chat_generate_request(
                "temperature": request.temperature,
                "max_new_tokens": request.max_tokens or request.max_completion_tokens,
                "min_new_tokens": request.min_tokens,
                "thinking_budget": request.thinking_budget,
                "stop": stop,
                "stop_token_ids": request.stop_token_ids,
                "top_p": request.top_p,
8 changes: 8 additions & 0 deletions python/sglang/srt/openai_api/protocol.py
@@ -172,6 +172,7 @@ class CompletionRequest(BaseModel):
    top_k: int = -1
    min_p: float = 0.0
    min_tokens: int = 0
    thinking_budget: Optional[int] = None
    json_schema: Optional[str] = None
    regex: Optional[str] = None
    ebnf: Optional[str] = None
@@ -350,6 +351,13 @@ class ChatCompletionRequest(BaseModel):
        description="The maximum number of completion tokens for a chat completion request, "
        "including visible output tokens and reasoning tokens. Input tokens are not included. ",
    )
    thinking_budget: Optional[int] = Field(
        default=None,
        description="The maximum number of reasoning tokens that can be generated for a request. "
        "This setting does not affect the model's thinking process itself. "
        "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
        "the reasoning content will be truncated and the final response content will be generated immediately.",
    )
    n: int = 1
    presence_penalty: float = 0.0
    response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
4 changes: 2 additions & 2 deletions python/sglang/srt/reasoning_parser.py
@@ -32,7 +32,7 @@ def detect_and_parse(self, text: str) -> StreamingParseResult:
        One-time parsing: Detects and parses reasoning sections in the provided text.
        Returns both reasoning content and normal text separately.
        """
-        text = text.replace(self.think_start_token, "").strip()
+        text = text.replace(self.think_start_token, "")
        if self.think_end_token not in text:
            # Assume reasoning was truncated before `</think>` token
            return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
            normal_text = current_text[end_idx + len(self.think_end_token) :]

            return StreamingParseResult(
-                normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
+                normal_text=normal_text, reasoning_text=reasoning_text
            )

        # Continue with reasoning content
56 changes: 54 additions & 2 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -30,8 +30,13 @@ class SamplingBatchInfo:
    # Whether any request needs min_p sampling
    need_min_p_sampling: bool

    # Use thinking_budget to truncate thinking
    num_thinking_tokens: Optional[torch.Tensor] = None
    think_end_ids: Optional[torch.Tensor] = None
    thinking_budgets: Optional[torch.Tensor] = None

    # Masking tensors for grammar-guided structured outputs
-    vocab_size: int
+    vocab_size: int = 0
    grammars: Optional[List] = None
    vocab_mask: Optional[torch.Tensor] = None
    apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -76,7 +81,22 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
        min_ps = torch.tensor(
            [r.sampling_params.min_p for r in reqs], dtype=torch.float
        ).to(device, non_blocking=True)

        if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
            think_end_ids = torch.tensor(
                [getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
                dtype=torch.int64,
            ).to(device, non_blocking=True)
            num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
                device, non_blocking=True
            )
            thinking_budgets = torch.tensor(
                [r.sampling_params.thinking_budget or -1 for r in reqs],
                dtype=torch.int64,
            ).to(device, non_blocking=True)
        else:
            think_end_ids = None
            num_thinking_tokens = None
            thinking_budgets = None
        # Check if any request has custom logit processor
        has_custom_logit_processor = (
            batch.enable_custom_logit_processor  # check the flag first.
@@ -132,6 +152,9 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
            top_ps=top_ps,
            top_ks=top_ks,
            min_ps=min_ps,
            think_end_ids=think_end_ids,
            num_thinking_tokens=num_thinking_tokens,
            thinking_budgets=thinking_budgets,
            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
            vocab_size=vocab_size,
@@ -146,6 +169,35 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
    def __len__(self):
        return len(self.temperatures)

    def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
        has_budget = self.thinking_budgets > 0
        if not has_budget.any():
            return
        # Count one more thinking token for every request that has a budget.
        torch.where(
            has_budget,
            self.num_thinking_tokens + 1,
            self.num_thinking_tokens,
            out=self.num_thinking_tokens,
        )
        should_stop = has_budget & (
            self.num_thinking_tokens - 1 > self.thinking_budgets
        )
        # Mask all logits of over-budget requests (unsqueeze(1) broadcasts the
        # per-request mask across the vocab dimension), then re-enable only the
        # end-of-thinking token so the sampler is forced to close the section.
        next_token_logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))
        batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
        if len(batch_indices) > 0:
            end_token_indices = self.think_end_ids[batch_indices]
            next_token_logits[batch_indices, end_token_indices] = 0.0

    def update_thinking_budgets(self, next_token_ids: torch.Tensor):
        if not torch.any(self.thinking_budgets > 0):
            return
        # Once the end-of-thinking token is actually sampled, retire that
        # request's budget by setting it to -1.
        torch.where(
            next_token_ids == self.think_end_ids,
            torch.tensor(-1, device=self.thinking_budgets.device),
            self.thinking_budgets,
            out=self.thinking_budgets,
        )

    def update_regex_vocab_mask(self):
        if not self.grammars:
            self.vocab_mask = None
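
For intuition, here is a minimal, self-contained sketch of the masking that `apply_thinking_budgets` performs once a request runs past its budget; the tensors and the token id standing in for `</think>` are made up for illustration and are not SGLang's API:

```python
import torch

# Two requests over a toy 5-token vocabulary. Request 0 has a budget of 2
# thinking tokens and has already produced 4; request 1 has no budget (-1).
thinking_budgets = torch.tensor([2, -1])
num_thinking_tokens = torch.tensor([4, 0])
think_end_ids = torch.tensor([3, -1])  # pretend token id 3 is </think>
logits = torch.zeros(2, 5)

has_budget = thinking_budgets > 0
should_stop = has_budget & (num_thinking_tokens - 1 > thinking_budgets)
# Mask every logit of the over-budget request, then re-enable </think>.
logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))
rows = torch.nonzero(should_stop, as_tuple=True)[0]
logits[rows, think_end_ids[rows]] = 0.0

print(logits)  # row 0 is -inf everywhere except index 3; row 1 is untouched
```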
2 changes: 2 additions & 0 deletions python/sglang/srt/sampling/sampling_params.py
@@ -30,6 +30,7 @@ class SamplingParams:
    def __init__(
        self,
        max_new_tokens: int = 128,
        thinking_budget: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: float = 1.0,
@@ -57,6 +58,7 @@ def __init__(
            self.stop_token_ids = set(stop_token_ids)
        else:
            self.stop_token_ids = None
        self.thinking_budget = thinking_budget
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -61,6 +61,7 @@ class TestFile:
        TestFile("test_radix_attention.py", 167),
        TestFile("test_reasoning_content.py", 89),
        TestFile("test_enable_thinking.py", 70),
        TestFile("test_thinking_budget.py", 60),
        TestFile("test_regex_constrained.py", 64),
        TestFile("test_release_memory_occupation.py", 44),
        TestFile("test_request_length_validation.py", 31),
95 changes: 95 additions & 0 deletions test/srt/test_thinking_budget.py
@@ -0,0 +1,95 @@
"""
Usage:
python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_20
python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_200
"""

import unittest

import requests
from transformers import AutoTokenizer

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestThinkingBudget(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "Qwen/Qwen3-8B"
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-1234"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--reasoning-parser",
                "qwen3",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_chat_completion_with_thinking_budget_20(self):
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [
                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
                ],
                "temperature": 0,
                "separate_reasoning": True,
                "chat_template_kwargs": {"enable_thinking": True},
                "thinking_budget": 20,
            },
        )
        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
        data = response.json()
        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
        tokens = self.tokenizer.encode(reasoning_content)
        self.assertEqual(
            len(tokens),
            20,
            f"Reasoning content length: {len(tokens)} not equal to 20, tokens: {tokens}, reasoning_content: {reasoning_content}",
        )

    def test_chat_completion_with_thinking_budget_200(self):
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [
                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
                ],
                "temperature": 0,
                "separate_reasoning": True,
                "chat_template_kwargs": {"enable_thinking": True},
                "thinking_budget": 200,
            },
        )
        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
        data = response.json()
        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
        tokens = self.tokenizer.encode(reasoning_content)
        self.assertEqual(
            len(tokens),
            200,
            f"Reasoning content length {len(tokens)} not equal to 200, tokens: {tokens}, reasoning_content: {reasoning_content}",
        )


if __name__ == "__main__":
    unittest.main()