diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py
index 97e9814bfba..9e325478ca5 100644
--- a/python/sglang/srt/function_call/function_call_parser.py
+++ b/python/sglang/srt/function_call/function_call_parser.py
@@ -19,6 +19,7 @@
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.function_call.step3_detector import Step3Detector
+from sglang.srt.function_call.seed_oss_detector import SeedOssDetector
logger = logging.getLogger(__name__)
@@ -43,6 +44,7 @@ class FunctionCallParser:
"glm45": Glm4MoeDetector,
"step3": Step3Detector,
"gpt-oss": GptOssDetector,
+ "seed_oss": SeedOssDetector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
diff --git a/python/sglang/srt/function_call/seed_oss_detector.py b/python/sglang/srt/function_call/seed_oss_detector.py
new file mode 100644
index 00000000000..23dc0891645
--- /dev/null
+++ b/python/sglang/srt/function_call/seed_oss_detector.py
@@ -0,0 +1,471 @@
+import json
+import logging
+import re
+from typing import Any, List
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+ StreamingParseResult,
+ StructureInfo,
+ ToolCallItem,
+ _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.function_call.utils import _is_complete_json
+
+logger = logging.getLogger(__name__)
+
+
+class SeedOssDetector(BaseFormatDetector):
+    """Detector for the Seed-OSS (Seed Open Source) tool call format, which uses XML-style tags."""
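+
+    # Markup this detector parses (illustrative example only; the tool and parameter
+    # names are hypothetical, and the tag strings mirror the tokens defined in __init__):
+    #
+    #   <seed:tool_call>
+    #   <function=get_weather>
+    #   <parameter=city>Beijing</parameter>
+    #   </function>
+    #   </seed:tool_call>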
+
+ def __init__(self):
+ super().__init__()
+ self._buffer = ""
+ self.current_tool_name_sent: bool = False
+ self.prev_tool_call_arr: list[dict] = []
+ self.current_tool_id: int = -1
+ self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
+
+        # Format-specific tokens (Seed-OSS XML-style tags)
+        self.tool_call_start_token: str = "<seed:tool_call>"
+        self.tool_call_end_token: str = "</seed:tool_call>"
+        self.bot_token: str = self.tool_call_start_token  # For base class compatibility
+        self.eot_token: str = self.tool_call_end_token  # For base class compatibility
+
+        # Sentinel tokens for streaming mode
+        self.function_prefix: str = "<function="
+        self.function_end_token: str = "</function>"
+        self.parameter_prefix: str = "<parameter="
+        self.parameter_end_token: str = "</parameter>"
+        self.is_function: bool = False
+        self._last_arguments: str = ""
+
+        # Regex patterns for locating tool calls, functions, and parameters
+        self.tool_call_regex = re.compile(r"<seed:tool_call>(.*?)</seed:tool_call>|<seed:tool_call>(.*?)$", re.DOTALL)
+        self.tool_call_function_regex = re.compile(r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
+        self.tool_call_parameter_regex = re.compile(r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a Seed-OSS format tool call."""
+        return self.tool_call_start_token in text
+
+ def _parse_xml_function_call(self, function_call_str: str, tools: List[Tool]) -> ToolCallItem:
+ """
+ Parse a function call from the XML format.
+
+ Args:
+ function_call_str: The function call string in XML format
+ tools: List of available tools
+
+ Returns:
+ A ToolCallItem representing the parsed function call
+ """
+ def get_arguments_config(func_name: str) -> dict:
+ if not tools:
+ return {}
+ for i, tool in enumerate(tools):
+ if tool.function and tool.function.name == func_name:
+ if not tool.function.parameters:
+ return {}
+ params = tool.function.parameters
+ if isinstance(params, dict) and "properties" in params:
+ return params["properties"]
+ elif isinstance(params, dict):
+ return params
+ else:
+ return {}
+ logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
+ return {}
+
+        def convert_param_value(param_value: str, param_name: str, param_config: dict, func_name: str) -> Any:
+ # Handle null value for any type
+ if param_value.lower() == "null":
+ return None
+
+ if param_name not in param_config:
+ if param_config != {}:
+ logger.warning(
+ f"Parsed parameter '{param_name}' is not defined in the tool "
+ f"parameters for tool '{func_name}', directly returning the string value."
+ )
+ return param_value
+
+ if (isinstance(param_config[param_name], dict) and "type" in param_config[param_name]):
+ param_type = str(param_config[param_name]["type"]).strip().lower()
+ else:
+ param_type = "string"
+
+ if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
+ return param_value
+ elif (param_type.startswith("int") or param_type.startswith("uint") or
+ param_type.startswith("long") or param_type.startswith("short") or
+ param_type.startswith("unsigned")):
+ try:
+ param_value = int(param_value)
+                except (ValueError, TypeError):
+ logger.warning(
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
+ f"'{func_name}', degenerating to string."
+ )
+ return param_value
+ elif param_type.startswith("num") or param_type.startswith("float"):
+ try:
+ float_param_value = float(param_value)
+ param_value = float_param_value if float_param_value - int(float_param_value) != 0 else int(float_param_value)
+                except (ValueError, TypeError):
+ logger.warning(
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
+ f"'{func_name}', degenerating to string."
+ )
+ return param_value
+ elif param_type in ["boolean", "bool", "binary"]:
+ param_value = param_value.lower()
+ if param_value not in ["true", "false"]:
+ logger.warning(
+                        f"Parsed value '{param_value}' of parameter '{param_name}' is not a boolean (`true` or `false`) in tool '{func_name}', degenerating to false."
+ )
+ return param_value == "true"
+ else:
+ if param_type == "object" or param_type.startswith("dict"):
+ try:
+ param_value = json.loads(param_value)
+ return param_value
+                    except (json.JSONDecodeError, TypeError):
+ logger.warning(
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not a valid JSON object in tool "
+ f"'{func_name}', will try other methods to parse it."
+ )
+ try:
+ param_value = eval(param_value)
+                except Exception:
+ logger.warning(
+ f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `eval()` in tool '{func_name}', degenerating to string."
+ )
+ return param_value
+
+ # Extract function name
+ end_index = function_call_str.index(">")
+ function_name = function_call_str[:end_index]
+ param_config = get_arguments_config(function_name)
+ parameters = function_call_str[end_index + 1:]
+ param_dict = {}
+
+ for match in self.tool_call_parameter_regex.findall(parameters):
+ match_text = match[0] if match[0] else match[1]
+ idx = match_text.index(">")
+ param_name = match_text[:idx]
+ param_value = str(match_text[idx + 1:])
+ # Remove prefix and trailing \n
+ if param_value.startswith("\n"):
+ param_value = param_value[1:]
+ if param_value.endswith("\n"):
+ param_value = param_value[:-1]
+
+ param_dict[param_name] = convert_param_value(param_value, param_name, param_config, function_name)
+
+ return ToolCallItem(
+ tool_index=-1, # To be updated by the caller based on tools list
+ name=function_name,
+ parameters=json.dumps(param_dict, ensure_ascii=False)
+ )
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ One-time parsing: Detects and parses tool calls in the provided text.
+
+ Args:
+ text: The complete text to parse.
+ tools: List of available tools.
+
+ Returns:
+            StreamingParseResult containing the normal text that precedes the tool calls and the parsed tool calls.
+ """
+ if not self.has_tool_call(text):
+ return StreamingParseResult(normal_text=text, calls=[])
+
+ try:
+ # Extracting the content before tool call
+ content_index = text.find(self.tool_call_start_token)
+ content = text[:content_index] if content_index >= 0 else ""
+
+ # Find all tool calls
+ matched_ranges = self.tool_call_regex.findall(text)
+ raw_tool_calls = [match[0] if match[0] else match[1] for match in matched_ranges]
+
+ # Back-off strategy if no tool_call tags found
+ if len(raw_tool_calls) == 0:
+ return StreamingParseResult(normal_text=text, calls=[])
+
+ # Extract function calls from tool calls
+ raw_function_calls = []
+ for tool_call in raw_tool_calls:
+ raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))
+
+ function_calls = [match[0] if match[0] else match[1] for match in raw_function_calls]
+
+ if len(function_calls) == 0:
+ return StreamingParseResult(normal_text=text, calls=[])
+
+ # Parse each function call
+ tool_calls = []
+ for idx, function_call_str in enumerate(function_calls):
+ tool_call = self._parse_xml_function_call(function_call_str, tools)
+ if tool_call:
+ # Update tool index based on position in response
+ tool_call.tool_index = idx
+ tool_calls.append(tool_call)
+
+ # Store in prev_tool_call_arr for later use
+ if idx >= len(self.prev_tool_call_arr):
+ self.prev_tool_call_arr.append({})
+ self.prev_tool_call_arr[idx] = {
+ "name": tool_call.name,
+ "arguments": tool_call.parameters
+ }
+
+ return StreamingParseResult(normal_text=content, calls=tool_calls)
+
+ except Exception as e:
+ logger.error(f"Error in detect_and_parse: {e}")
+ # Return the normal text if parsing fails
+ return StreamingParseResult(normal_text=text)
+
+ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ Streaming incremental parsing for the SeedOss format.
+
+ Args:
+ new_text: The new text increment to parse.
+ tools: List of available tools.
+
+ Returns:
+ StreamingParseResult with parsed calls or normal text.
+ """
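+        # Strategy (summary of the logic below): accumulate chunks in self._buffer,
+        # emit a ToolCallItem carrying only the function name once a "<function="
+        # prefix with a complete name is seen, then stream argument diffs as each
+        # "<parameter=...>...</parameter>" pair completes, and reset the per-call
+        # state when "</function>" and "</seed:tool_call>" arrive.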
+ self._buffer += new_text
+ current_text = self._buffer
+
+ # Check if we have a tool call
+ has_tool_call = (
+ self.tool_call_start_token in current_text or
+ self.function_prefix in current_text
+ )
+
+ if not has_tool_call:
+ self._buffer = ""
+ # Clean up any end tokens in the normal text
+ for e_token in [self.tool_call_end_token, self.function_end_token, self.parameter_end_token]:
+ if e_token in new_text:
+ new_text = new_text.replace(e_token, "")
+ return StreamingParseResult(normal_text=new_text)
+
+ # Initialize tool indices if not already done
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = {
+ tool.function.name: i
+ for i, tool in enumerate(tools)
+ if tool.function and tool.function.name
+ }
+
+ calls = []
+ try:
+ # Check for function start
+ if self.function_prefix in current_text and not self.current_tool_name_sent:
+ # Extract function name
+ func_start = current_text.find(self.function_prefix) + len(self.function_prefix)
+ func_end = current_text.find(">", func_start)
+
+ if func_end != -1:
+ function_name = current_text[func_start:func_end]
+
+ # Initialize state if this is the first tool call
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = [""]
+
+ # Ensure we have enough entries in our tracking arrays
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=function_name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ self.is_function = True
+
+ # Store the tool call info
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": function_name,
+ "arguments": {},
+ }
+
+ # Check for parameter
+ elif self.is_function and self.parameter_prefix in current_text:
+ # Handle parameter
+ param_matches = self.tool_call_parameter_regex.findall(current_text)
+ if param_matches:
+ # Process each parameter match
+ for match in param_matches:
+ match_text = match[0] if match[0] else match[1]
+ if not match_text:
+ continue
+
+ idx = match_text.find(">")
+ if idx == -1:
+ continue
+
+ param_name = match_text[:idx]
+ param_value = str(match_text[idx + 1:])
+
+ # Clean up parameter value
+ if param_value.startswith("\n"):
+ param_value = param_value[1:]
+ if param_value.endswith("\n"):
+ param_value = param_value[:-1]
+
+ # Check if parameter is complete
+ is_complete = self.parameter_end_token in current_text
+
+ # Extract the part we haven't sent yet
+ if param_value:
+ arguments_diff = json.dumps({param_name: param_value})
+ if self._last_arguments:
+ # Only send what's new
+ last_args_obj = json.loads(self._last_arguments)
+ if param_name in last_args_obj and last_args_obj[param_name] == param_value:
+ # No change, skip
+ continue
+
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=arguments_diff,
+ )
+ )
+ self._last_arguments = arguments_diff
+
+ # Update tracked parameters
+ if is_complete:
+ self.prev_tool_call_arr[self.current_tool_id]["arguments"] = json.loads(arguments_diff)
+
+ # Parameter is complete, clean up
+ self._last_arguments = ""
+
+ # Check if function is also complete
+ if self.function_end_token in current_text:
+ self.is_function = False
+ self.current_tool_name_sent = False
+ self.current_tool_id += 1
+
+ # Remove the processed function from buffer
+ func_end_pos = current_text.find(self.function_end_token) + len(self.function_end_token)
+ self._buffer = current_text[func_end_pos:]
+
+ # Check if a function has ended without us catching a specific parameter
+ elif self.is_function and self.function_end_token in current_text:
+ self.is_function = False
+ self.current_tool_name_sent = False
+ self.current_tool_id += 1
+
+ # Remove the processed function from buffer
+ func_end_pos = current_text.find(self.function_end_token) + len(self.function_end_token)
+ self._buffer = current_text[func_end_pos:]
+
+ # Check if the entire tool call section has ended
+ if self.tool_call_end_token in current_text:
+ # Clean up buffer
+ tool_end_pos = current_text.find(self.tool_call_end_token) + len(self.tool_call_end_token)
+ self._buffer = current_text[tool_end_pos:]
+
+ # Reset state for next potential tool call
+ if not self.is_function:
+ self.current_tool_id = -1
+
+ return StreamingParseResult(normal_text="", calls=calls)
+
+ except Exception as e:
+ logger.error(f"Error in parse_streaming_increment: {e}")
+ return StreamingParseResult(normal_text="")
+
+ def structure_info(self) -> _GetInfoFunc:
+ """Return metadata about the structure of SeedOss tool calls."""
+ def _get_info() -> StructureInfo:
+ return StructureInfo(
+ start_phrase=self.tool_call_start_token,
+ end_phrase=self.tool_call_end_token,
+ likely_content_before_start=True,
+ graceful_recovery_possible=True,
+ can_generate_content_after_end=True,
+ additional_escape_phrases=[
+ self.function_prefix,
+ self.function_end_token,
+ self.parameter_prefix,
+ self.parameter_end_token
+ ]
+ )
+ return _get_info
+
+ def build_ebnf(self, tools: List[Tool]) -> str:
+ """Build an EBNF grammar for the SeedOss format."""
+ composer = EBNFComposer()
+
+ # Define the overall structure
+ composer.add_production("tool_call", f'"{self.tool_call_start_token}" function "{self.tool_call_end_token}"')
+
+ # Define what a function looks like
+ composer.add_production("function", f'"{self.function_prefix}" function_name ">" parameters "{self.function_end_token}"')
+
+ # Function name is any of the available tool names
+ function_names = " | ".join([f'"{tool.function.name}"' for tool in tools if tool.function and tool.function.name])
+ composer.add_production("function_name", function_names if function_names else '"unknown_function"')
+
+ # Parameters are a sequence of parameter definitions
+ composer.add_production("parameters", "parameter*")
+
+ # Each parameter has a name and value
+ composer.add_production("parameter", f'"{self.parameter_prefix}" parameter_name ">" parameter_value "{self.parameter_end_token}"')
+
+ # Parameter name can be any key from the tools
+ param_names = set()
+ for tool in tools:
+ if not (tool.function and tool.function.parameters and isinstance(tool.function.parameters, dict)):
+ continue
+
+ properties = tool.function.parameters.get("properties", {})
+ for param_name in properties.keys():
+ param_names.add(param_name)
+
+ param_name_production = " | ".join([f'"{name}"' for name in param_names]) if param_names else '"param"'
+ composer.add_production("parameter_name", param_name_production)
+
+ # Parameter value can be any string (simplified)
+ composer.add_production("parameter_value", 'string | number | "true" | "false" | "null" | object | array')
+ composer.add_production("string", r'"\"" [^"]* "\""')
+ composer.add_production("number", r"[0-9]+ (\.[0-9]+)?")
+ composer.add_production("object", r'"{" (string ":" value ("," string ":" value)*)? "}"')
+ composer.add_production("array", r'"\[" (value ("," value)*)? "\]"')
+ composer.add_production("value", "string | number | object | array | \"true\" | \"false\" | \"null\"")
+
+ return composer.compose()
\ No newline at end of file
diff --git a/python/sglang/srt/models/seed_oss.py b/python/sglang/srt/models/seed_oss.py
new file mode 100644
index 00000000000..f7e9bcf8f1e
--- /dev/null
+++ b/python/sglang/srt/models/seed_oss.py
@@ -0,0 +1,589 @@
+# Adapted from qwen2.py, with modifications to support SeedOss.
+"""Inference-only SeedOss model compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from transformers import SeedOssConfig
+
+from sglang.srt.distributed import (
+ get_pp_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+ default_weight_loader,
+ kv_cache_scales_loader,
+)
+from sglang.srt.utils import add_prefix, make_layers
+
+
+logger = logging.getLogger(__name__)
+
+
+class SeedOssMLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size,
+ [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=add_prefix("gate_up_proj", prefix),
+ )
+ self.down_proj = RowParallelLinear(
+ intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=add_prefix("down_proj", prefix),
+ )
+ if hidden_act != "silu":
+ raise ValueError(
+ f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
+ self.act_fn = SiluAndMul()
+
+ def forward(self, x):
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+class SeedOssAttention(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: Optional[int] = None,
+ layer_id: int = 0,
+ rope_theta: float = 10000,
+ rope_scaling: Optional[Dict[str, Any]] = None,
+ max_position_embeddings: int = 4096 * 32,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than or equal to TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ if head_dim is not None:
+ self.head_dim = head_dim
+ else:
+ self.head_dim = hidden_size // self.total_num_heads
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim ** -0.5
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=add_prefix("qkv_proj", prefix),
+ )
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=add_prefix("o_proj", prefix),
+ )
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embeddings,
+ base=rope_theta,
+ rope_scaling=rope_scaling,
+ )
+ self.attn = RadixAttention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ layer_id=layer_id,
+ quant_config=quant_config,
+ prefix=add_prefix("attn", prefix),
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v, forward_batch)
+ output, _ = self.o_proj(attn_output)
+
+ return output
+
+
+class SeedOssDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ config: SeedOssConfig,
+ layer_id: int = 0,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ alt_stream: Optional[torch.cuda.Stream] = None,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ rope_theta = getattr(config, "rope_theta", 1000000)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+ head_dim = getattr(config, "head_dim", None)
+
+ self.self_attn = SeedOssAttention(
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ head_dim=head_dim,
+ layer_id=layer_id,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ max_position_embeddings=max_position_embeddings,
+ quant_config=quant_config,
+ prefix=add_prefix("self_attn", prefix),
+ )
+
+ self.mlp = SeedOssMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ prefix=add_prefix("mlp", prefix),
+ )
+
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ residual: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ forward_batch=forward_batch,
+ )
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+
+ return hidden_states, residual
+
+
+class SeedOssModel(nn.Module):
+ def __init__(
+ self,
+ config: SeedOssConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer,
+ alt_stream: Optional[torch.cuda.Stream] = None,
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.pp_group = get_pp_group()
+
+ if self.pp_group.is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ enable_tp=not is_dp_attention_enabled(),
+ prefix=add_prefix("embed_tokens", prefix),
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ # Use the provided decoder layer type or default to SeedOssDecoderLayer
+ decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer
+ self.layers, self.start_layer, self.end_layer = make_layers(
+ config.num_hidden_layers,
+ lambda idx, prefix: decoder_layer_type(
+ layer_id=idx,
+ config=config,
+ quant_config=quant_config,
+ prefix=prefix,
+ alt_stream=alt_stream,
+ ),
+ pp_rank=self.pp_group.rank_in_group,
+ pp_size=self.pp_group.world_size,
+ prefix=add_prefix("layers", prefix),
+ )
+
+ if self.pp_group.is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer(return_tuple=True)
+ # TODO: Add support for EAGLE3
+
+ def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.get_input_embeddings()(input_ids)
+
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.embed_tokens
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ input_embeds: torch.Tensor = None,
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
+ ) -> Union[torch.Tensor, PPProxyTensors]:
+ if self.pp_group.is_first_rank:
+ if input_embeds is None:
+ hidden_states = self.embed_tokens(input_ids)
+ else:
+ hidden_states = input_embeds
+ residual = None
+ else:
+ assert pp_proxy_tensors is not None
+ hidden_states = pp_proxy_tensors["hidden_states"]
+ residual = pp_proxy_tensors["residual"]
+
+ for i in range(self.start_layer, self.end_layer):
+ layer = self.layers[i]
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ forward_batch,
+ residual,
+ )
+
+ if not self.pp_group.is_last_rank:
+ return PPProxyTensors(
+ {
+ "hidden_states": hidden_states,
+ "residual": residual,
+ }
+ )
+ else:
+ if hidden_states.shape[0] != 0:
+ if residual is None:
+ hidden_states = self.norm(hidden_states)
+ else:
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+ # If this function is called, it should always initialize KV cache scale
+ # factors (or else raise an exception). Thus, handled exceptions should
+ # make sure to leave KV cache scale factors in a known good (dummy) state
+ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+ tp_size = get_tensor_model_parallel_world_size()
+ tp_rank = get_tensor_model_parallel_rank()
+ for layer_idx, scaling_factor in kv_cache_scales_loader(
+ quantization_param_path,
+ tp_rank,
+ tp_size,
+ self.config.num_hidden_layers,
+ self.config.__class__.model_type,
+ ):
+ if not isinstance(self.layers[layer_idx], nn.Identity):
+ layer_self_attn = self.layers[layer_idx].self_attn
+ if hasattr(layer_self_attn.attn, "k_scale"):
+ layer_self_attn.attn.k_scale = scaling_factor
+ layer_self_attn.attn.v_scale = scaling_factor
+ else:
+ raise RuntimeError(
+ "Self attention has no KV cache scaling " "factor attribute!"
+ )
+
+
+class SeedOssForCausalLM(nn.Module):
+ # BitandBytes specific attributes
+ default_bitsandbytes_target_modules = [
+ ".gate_proj.",
+ ".down_proj.",
+ ".up_proj.",
+ ".q_proj.",
+ ".k_proj.",
+ ".v_proj.",
+ ".o_proj.",
+ ]
+ bitsandbytes_stacked_params_mapping = {
+ # shard_name, weight_name, index
+ "q_proj": ("qkv_proj", 0),
+ "k_proj": ("qkv_proj", 1),
+ "v_proj": ("qkv_proj", 2),
+ "gate_proj": ("gate_up_proj", 0),
+ "up_proj": ("gate_up_proj", 1),
+ }
+
+ def __init__(
+ self,
+ config: SeedOssConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.pp_group = get_pp_group()
+ self.config = config
+ self.quant_config = quant_config
+ self.model = SeedOssModel(
+ config,
+ quant_config=quant_config,
+ prefix=add_prefix("model", prefix),
+ )
+
+ # handle the lm head on different pp ranks
+ if self.pp_group.is_last_rank:
+ if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
+ )
+ else:
+ # ranks other than the last rank will have a placeholder layer
+ self.lm_head = PPMissingLayer()
+
+ # perform weight tying for PP
+ if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+ if self.pp_group.is_first_rank:
+ self.pp_group.send(
+ self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+ )
+ else:
+ emb_token_weight = self.pp_group.recv(
+ size=(config.vocab_size, config.hidden_size),
+ dtype=next(self.model.parameters()).dtype,
+ src=self.pp_group.first_rank,
+ )
+ self.lm_head.weight.copy_(emb_token_weight)
+ self.logits_processor = LogitsProcessor(config)
+ self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+ def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embedding(input_ids)
+
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.model.embed_tokens
+
+ @torch.no_grad()
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ input_embeds: torch.Tensor = None,
+ get_embedding: bool = False,
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids,
+ positions,
+ forward_batch,
+ input_embeds,
+ pp_proxy_tensors,
+ )
+ if self.pp_group.is_last_rank:
+ if not get_embedding:
+ return self.logits_processor(
+ input_ids, hidden_states, self.lm_head, forward_batch
+ )
+ else:
+ return self.pooler(hidden_states, forward_batch)
+ else:
+ return hidden_states
+
+ @torch.no_grad()
+ def forward_split_prefill(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ split_interval: Tuple[int, int], # [start, end) 0-based
+ input_embeds: torch.Tensor = None,
+ ):
+ start, end = split_interval
+ # embed
+ if start == 0:
+ if input_embeds is None:
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+ else:
+ forward_batch.hidden_states = input_embeds
+ # decoder layer
+ for i in range(start, end):
+ layer = self.model.layers[i]
+ forward_batch.hidden_states, forward_batch.residual = layer(
+ positions,
+ forward_batch.hidden_states,
+ forward_batch,
+ forward_batch.residual,
+ )
+ if end == self.model.config.num_hidden_layers:
+ # norm
+ hidden_states, _ = self.model.norm(
+ forward_batch.hidden_states, forward_batch.residual
+ )
+ forward_batch.hidden_states = hidden_states
+ # logits process
+ result = self.logits_processor(
+ input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+ )
+ else:
+ result = None
+ return result
+
+ @property
+ def start_layer(self):
+ return self.model.start_layer
+
+ @property
+ def end_layer(self):
+ return self.model.end_layer
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ layer_id = get_layer_id(name)
+ if (
+ layer_id is not None
+ and hasattr(self.model, "start_layer")
+ and (
+ layer_id < self.model.start_layer
+ or layer_id >= self.model.end_layer
+ )
+ ):
+ continue
+ if "rotary_emb.inv_freq" in name or "projector" in name:
+ continue
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+ # Models trained using ColossalAI may include these tensors in
+ # the checkpoint. Skip them.
+ continue
+ if self.config.tie_word_embeddings and "lm_head.weight" in name:
+ if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+ # Handle pp weight tying here
+ # find the embed_tokens.weight in the weights
+ embed_token_weights = next(
+ filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+ )[1]
+ loaded_weight = embed_token_weights
+ else:
+ continue
+ if name.startswith("model.vision_tower") and name not in params_dict:
+ continue
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+                if name in params_dict:
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ else:
+ logger.warning(f"Parameter {name} not found in params_dict")
+
+ def get_embed_and_head(self):
+ return self.model.embed_tokens.weight, self.lm_head.weight
+
+ def set_embed_and_head(self, embed, head):
+ del self.model.embed_tokens.weight
+ del self.lm_head.weight
+ self.model.embed_tokens.weight = embed
+ self.lm_head.weight = head
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+ self.model.load_kv_cache_scales(quantization_param_path)
+
+EntryClass = SeedOssForCausalLM
diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py
index fd9ce55084f..6e6f12a92a9 100644
--- a/python/sglang/srt/reasoning_parser.py
+++ b/python/sglang/srt/reasoning_parser.py
@@ -186,6 +186,23 @@ def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False)
)
+class SeedOssDetector(BaseReasoningFormatDetector):
+ """
+ Detector for SeedOSS model.
+ Args:
+ stream_reasoning (bool): If False, accumulates reasoning content until the end tag.
+ If True, streams reasoning content as it arrives.
+ """
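+
+    # Illustrative output shape: "<seed:think>...reasoning...</seed:think>final answer"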
+ def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = True):
+        # SeedOSS is assumed to be reasoning until the `</seed:think>` token
+        super().__init__(
+            "<seed:think>",
+            "</seed:think>",
+ force_reasoning=force_reasoning,
+ stream_reasoning=stream_reasoning,
+ )
+
+
class GptOssDetector(BaseReasoningFormatDetector):
"""
Detector for T4-style reasoning format.
@@ -520,6 +537,7 @@ class ReasoningParser:
"qwen3": Qwen3Detector,
"qwen3-thinking": Qwen3Detector,
"step3": DeepSeekR1Detector,
+ "seed_oss": SeedOssDetector,
}
def __init__(