diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py
index 97e9814bfba..9e325478ca5 100644
--- a/python/sglang/srt/function_call/function_call_parser.py
+++ b/python/sglang/srt/function_call/function_call_parser.py
@@ -19,6 +19,7 @@
 from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
 from sglang.srt.function_call.qwen25_detector import Qwen25Detector
 from sglang.srt.function_call.step3_detector import Step3Detector
+from sglang.srt.function_call.seed_oss_detector import SeedOssDetector
 
 
 logger = logging.getLogger(__name__)
@@ -43,6 +44,7 @@ class FunctionCallParser:
         "glm45": Glm4MoeDetector,
         "step3": Step3Detector,
         "gpt-oss": GptOssDetector,
+        "seed_oss": SeedOssDetector,
     }
 
     def __init__(self, tools: List[Tool], tool_call_parser: str):
diff --git a/python/sglang/srt/function_call/seed_oss_detector.py b/python/sglang/srt/function_call/seed_oss_detector.py
new file mode 100644
index 00000000000..23dc0891645
--- /dev/null
+++ b/python/sglang/srt/function_call/seed_oss_detector.py
@@ -0,0 +1,471 @@
+import json
+import logging
+import re
+from typing import Any, List
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    StructureInfo,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+from sglang.srt.function_call.utils import _is_complete_json
+
+logger = logging.getLogger(__name__)
+
+
+class SeedOssDetector(BaseFormatDetector):
+    """Detector for the Seed OSS (Seed Open Source) tool call format, which uses XML-like tags."""
+
+    def __init__(self):
+        super().__init__()
+        self._buffer = ""
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: list[dict] = []
+        self.current_tool_id: int = -1
+        # Arguments streamed so far for each tool call, indexed by tool id
+        self.streamed_args_for_tool: list[str] = []
+
+        # Format-specific tokens
+        self.tool_call_start_token: str = "<seed:tool_call>"
+        self.tool_call_end_token: str = "</seed:tool_call>"
+        self.bot_token: str = self.tool_call_start_token  # For base class compatibility
+        self.eot_token: str = self.tool_call_end_token  # For base class compatibility
+
+        # Sentinel tokens for streaming mode
+        self.function_prefix: str = "<function="
+        self.function_end_token: str = "</function>"
+        self.parameter_prefix: str = "<parameter="
+        self.parameter_end_token: str = "</parameter>"
+
+        # Regexes for one-shot parsing; the `$` alternatives tolerate truncated output
+        self.tool_call_regex = re.compile(
+            r"<seed:tool_call>(.*?)</seed:tool_call>|<seed:tool_call>(.*?)$", re.DOTALL
+        )
+        self.tool_call_function_regex = re.compile(
+            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
+        )
+        self.tool_call_parameter_regex = re.compile(
+            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
+        )
+
+        # Streaming state
+        self._last_arguments = ""
+        self.is_function: bool = False
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a Seed OSS format tool call."""
+        return self.tool_call_start_token in text
+
+    def _parse_xml_function_call(
+        self, function_call_str: str, tools: List[Tool]
+    ) -> ToolCallItem:
+        """
+        Parse a single function call from the XML-like format.
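+
+        A single call inside ``<seed:tool_call>`` is expected to look roughly like
+        the sketch below (tool and parameter names here are illustrative only, and
+        the exact whitespace depends on the chat template); ``function_call_str``
+        is the portion after the ``<function=`` prefix:
+
+            <function=get_weather>
+            <parameter=city>Paris</parameter>
+            <parameter=unit>celsius</parameter>
+            </function>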
+
+        Args:
+            function_call_str: The function call string in XML format
+            tools: List of available tools
+
+        Returns:
+            A ToolCallItem representing the parsed function call
+        """
+
+        def get_arguments_config(func_name: str) -> dict:
+            if not tools:
+                return {}
+            for i, tool in enumerate(tools):
+                if tool.function and tool.function.name == func_name:
+                    if not tool.function.parameters:
+                        return {}
+                    params = tool.function.parameters
+                    if isinstance(params, dict) and "properties" in params:
+                        return params["properties"]
+                    elif isinstance(params, dict):
+                        return params
+                    else:
+                        return {}
+            logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
+            return {}
+
+        def convert_param_value(
+            param_value: str, param_name: str, param_config: dict, func_name: str
+        ) -> Any:
+            # Handle null value for any type
+            if param_value.lower() == "null":
+                return None
+
+            if param_name not in param_config:
+                if param_config != {}:
+                    logger.warning(
+                        f"Parsed parameter '{param_name}' is not defined in the tool "
+                        f"parameters for tool '{func_name}', directly returning the string value."
+                    )
+                return param_value
+
+            if isinstance(param_config[param_name], dict) and "type" in param_config[param_name]:
+                param_type = str(param_config[param_name]["type"]).strip().lower()
+            else:
+                param_type = "string"
+
+            if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
+                return param_value
+            elif (
+                param_type.startswith("int")
+                or param_type.startswith("uint")
+                or param_type.startswith("long")
+                or param_type.startswith("short")
+                or param_type.startswith("unsigned")
+            ):
+                try:
+                    param_value = int(param_value)
+                except:
+                    logger.warning(
+                        f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer "
+                        f"in tool '{func_name}', degenerating to string."
+                    )
+                return param_value
+            elif param_type.startswith("num") or param_type.startswith("float"):
+                try:
+                    float_param_value = float(param_value)
+                    param_value = (
+                        float_param_value
+                        if float_param_value - int(float_param_value) != 0
+                        else int(float_param_value)
+                    )
+                except:
+                    logger.warning(
+                        f"Parsed value '{param_value}' of parameter '{param_name}' is not a float "
+                        f"in tool '{func_name}', degenerating to string."
+                    )
+                return param_value
+            elif param_type in ["boolean", "bool", "binary"]:
+                param_value = param_value.lower()
+                if param_value not in ["true", "false"]:
+                    logger.warning(
+                        f"Parsed value '{param_value}' of parameter '{param_name}' is not a boolean "
+                        f"(`true` or `false`) in tool '{func_name}', degenerating to false."
+                    )
+                return param_value == "true"
+            else:
+                if param_type == "object" or param_type.startswith("dict"):
+                    try:
+                        param_value = json.loads(param_value)
+                        return param_value
+                    except:
+                        logger.warning(
+                            f"Parsed value '{param_value}' of parameter '{param_name}' is not a valid "
+                            f"JSON object in tool '{func_name}', will try other methods to parse it."
+                        )
+                try:
+                    param_value = eval(param_value)
+                except:
+                    logger.warning(
+                        f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted "
+                        f"via Python `eval()` in tool '{func_name}', degenerating to string."
+                    )
+                return param_value
+
+        # Extract the function name
+        end_index = function_call_str.index(">")
+        function_name = function_call_str[:end_index]
+        param_config = get_arguments_config(function_name)
+        parameters = function_call_str[end_index + 1:]
+        param_dict = {}
+
+        for match in self.tool_call_parameter_regex.findall(parameters):
+            match_text = match[0] if match[0] else match[1]
+            idx = match_text.index(">")
+            param_name = match_text[:idx]
+            param_value = str(match_text[idx + 1:])
+            # Strip a single leading and trailing newline around the value
+            if param_value.startswith("\n"):
+                param_value = param_value[1:]
+            if param_value.endswith("\n"):
+                param_value = param_value[:-1]
+
+            param_dict[param_name] = convert_param_value(
+                param_value, param_name, param_config, function_name
+            )
+
+        return ToolCallItem(
+            tool_index=-1,  # To be updated by the caller based on the tools list
+            name=function_name,
+            parameters=json.dumps(param_dict, ensure_ascii=False),
+        )
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: detects and parses all tool calls in the provided text.
+
+        Args:
+            text: The complete text to parse.
+            tools: List of available tools.
+
+        Returns:
+            A StreamingParseResult with the normal text that precedes the tool calls
+            and the parsed calls.
+        """
+        if not self.has_tool_call(text):
+            return StreamingParseResult(normal_text=text, calls=[])
+
+        try:
+            # Extract the content before the tool call
+            content_index = text.find(self.tool_call_start_token)
+            content = text[:content_index] if content_index >= 0 else ""
+
+            # Find all tool calls
+            matched_ranges = self.tool_call_regex.findall(text)
+            raw_tool_calls = [match[0] if match[0] else match[1] for match in matched_ranges]
+
+            # Back off if no tool_call tags were found
+            if len(raw_tool_calls) == 0:
+                return StreamingParseResult(normal_text=text, calls=[])
+
+            # Extract function calls from tool calls
+            raw_function_calls = []
+            for tool_call in raw_tool_calls:
+                raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))
+
+            function_calls = [match[0] if match[0] else match[1] for match in raw_function_calls]
+
+            if len(function_calls) == 0:
+                return StreamingParseResult(normal_text=text, calls=[])
+
+            # Parse each function call
+            tool_calls = []
+            for idx, function_call_str in enumerate(function_calls):
+                tool_call = self._parse_xml_function_call(function_call_str, tools)
+                if tool_call:
+                    # Update the tool index based on its position in the response
+                    tool_call.tool_index = idx
+                    tool_calls.append(tool_call)
+
+                    # Store in prev_tool_call_arr for later use
+                    if idx >= len(self.prev_tool_call_arr):
+                        self.prev_tool_call_arr.append({})
+                    self.prev_tool_call_arr[idx] = {
+                        "name": tool_call.name,
+                        "arguments": tool_call.parameters,
+                    }
+
+            return StreamingParseResult(normal_text=content, calls=tool_calls)
+
+        except Exception as e:
+            logger.error(f"Error in detect_and_parse: {e}")
+            # Return the normal text if parsing fails
+            return StreamingParseResult(normal_text=text)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing for the Seed OSS format.
+
+        Args:
+            new_text: The new text increment to parse.
+            tools: List of available tools.
+
+        Returns:
+            StreamingParseResult with parsed calls or normal text.
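+
+        Note (a behavioral sketch of this implementation, not a stable contract):
+        the first ToolCallItem emitted for a call carries the function name with
+        empty ``parameters``; subsequent items carry argument fragments such as
+        ``{"city": "Paris"}`` (illustrative names), and the same ``tool_index``
+        is reused until ``</function>`` closes the call.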
+ """ + self._buffer += new_text + current_text = self._buffer + + # Check if we have a tool call + has_tool_call = ( + self.tool_call_start_token in current_text or + self.function_prefix in current_text + ) + + if not has_tool_call: + self._buffer = "" + # Clean up any end tokens in the normal text + for e_token in [self.tool_call_end_token, self.function_end_token, self.parameter_end_token]: + if e_token in new_text: + new_text = new_text.replace(e_token, "") + return StreamingParseResult(normal_text=new_text) + + # Initialize tool indices if not already done + if not hasattr(self, "_tool_indices"): + self._tool_indices = { + tool.function.name: i + for i, tool in enumerate(tools) + if tool.function and tool.function.name + } + + calls = [] + try: + # Check for function start + if self.function_prefix in current_text and not self.current_tool_name_sent: + # Extract function name + func_start = current_text.find(self.function_prefix) + len(self.function_prefix) + func_end = current_text.find(">", func_start) + + if func_end != -1: + function_name = current_text[func_start:func_end] + + # Initialize state if this is the first tool call + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=function_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.is_function = True + + # Store the tool call info + self.prev_tool_call_arr[self.current_tool_id] = { + "name": function_name, + "arguments": {}, + } + + # Check for parameter + elif self.is_function and self.parameter_prefix in current_text: + # Handle parameter + param_matches = self.tool_call_parameter_regex.findall(current_text) + if param_matches: + # Process each parameter match + for match in param_matches: + match_text = match[0] if match[0] else match[1] + if not match_text: + continue + + idx = match_text.find(">") + if idx == -1: + continue + + param_name = match_text[:idx] + param_value = str(match_text[idx + 1:]) + + # Clean up parameter value + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Check if parameter is complete + is_complete = self.parameter_end_token in current_text + + # Extract the part we haven't sent yet + if param_value: + arguments_diff = json.dumps({param_name: param_value}) + if self._last_arguments: + # Only send what's new + last_args_obj = json.loads(self._last_arguments) + if param_name in last_args_obj and last_args_obj[param_name] == param_value: + # No change, skip + continue + + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=arguments_diff, + ) + ) + self._last_arguments = arguments_diff + + # Update tracked parameters + if is_complete: + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = json.loads(arguments_diff) + + # Parameter is complete, clean up + self._last_arguments = "" + + # Check if function is also complete + if self.function_end_token in current_text: + self.is_function = False + self.current_tool_name_sent = False + self.current_tool_id += 1 + + # Remove the processed function from buffer + func_end_pos = 
+                                    self._buffer = current_text[func_end_pos:]
+
+            # Check if a function has ended without us catching a specific parameter
+            elif self.is_function and self.function_end_token in current_text:
+                self.is_function = False
+                self.current_tool_name_sent = False
+                self.current_tool_id += 1
+
+                # Remove the processed function from the buffer
+                func_end_pos = current_text.find(self.function_end_token) + len(self.function_end_token)
+                self._buffer = current_text[func_end_pos:]
+
+            # Check if the entire tool call section has ended
+            if self.tool_call_end_token in current_text:
+                # Clean up the buffer
+                tool_end_pos = current_text.find(self.tool_call_end_token) + len(self.tool_call_end_token)
+                self._buffer = current_text[tool_end_pos:]
+
+                # Reset state for the next potential tool call
+                if not self.is_function:
+                    self.current_tool_id = -1
+
+            return StreamingParseResult(normal_text="", calls=calls)
+
+        except Exception as e:
+            logger.error(f"Error in parse_streaming_increment: {e}")
+            return StreamingParseResult(normal_text="")
+
+    def structure_info(self) -> _GetInfoFunc:
+        """Return metadata about the structure of Seed OSS tool calls."""
+
+        def _get_info() -> StructureInfo:
+            return StructureInfo(
+                start_phrase=self.tool_call_start_token,
+                end_phrase=self.tool_call_end_token,
+                likely_content_before_start=True,
+                graceful_recovery_possible=True,
+                can_generate_content_after_end=True,
+                additional_escape_phrases=[
+                    self.function_prefix,
+                    self.function_end_token,
+                    self.parameter_prefix,
+                    self.parameter_end_token,
+                ],
+            )
+
+        return _get_info
+
+    def build_ebnf(self, tools: List[Tool]) -> str:
+        """Build an EBNF grammar for the Seed OSS format."""
+        composer = EBNFComposer()
+
+        # Define the overall structure
+        composer.add_production(
+            "tool_call",
+            f'"{self.tool_call_start_token}" function "{self.tool_call_end_token}"',
+        )
+
+        # Define what a function looks like
+        composer.add_production(
+            "function",
+            f'"{self.function_prefix}" function_name ">" parameters "{self.function_end_token}"',
+        )
+
+        # The function name is any of the available tool names
+        function_names = " | ".join(
+            [f'"{tool.function.name}"' for tool in tools if tool.function and tool.function.name]
+        )
+        composer.add_production("function_name", function_names if function_names else '"unknown_function"')
+
+        # Parameters are a sequence of parameter definitions
+        composer.add_production("parameters", "parameter*")
+
+        # Each parameter has a name and a value
+        composer.add_production(
+            "parameter",
+            f'"{self.parameter_prefix}" parameter_name ">" parameter_value "{self.parameter_end_token}"',
+        )
+
+        # The parameter name can be any key from the tools
+        param_names = set()
+        for tool in tools:
+            if not (tool.function and tool.function.parameters and isinstance(tool.function.parameters, dict)):
+                continue
+
+            properties = tool.function.parameters.get("properties", {})
+            for param_name in properties.keys():
+                param_names.add(param_name)
+
+        param_name_production = " | ".join([f'"{name}"' for name in param_names]) if param_names else '"param"'
+        composer.add_production("parameter_name", param_name_production)
+
+        # The parameter value can be any JSON-like value (simplified)
+        composer.add_production(
+            "parameter_value", 'string | number | "true" | "false" | "null" | object | array'
+        )
+        composer.add_production("string", r'"\"" [^"]* "\""')
+        composer.add_production("number", r"[0-9]+ (\.[0-9]+)?")
+        composer.add_production("object", r'"{" (string ":" value ("," string ":" value)*)? "}"')
"}"') + composer.add_production("array", r'"\[" (value ("," value)*)? "\]"') + composer.add_production("value", "string | number | object | array | \"true\" | \"false\" | \"null\"") + + return composer.compose() \ No newline at end of file diff --git a/python/sglang/srt/models/seed_oss.py b/python/sglang/srt/models/seed_oss.py new file mode 100644 index 00000000000..f7e9bcf8f1e --- /dev/null +++ b/python/sglang/srt/models/seed_oss.py @@ -0,0 +1,589 @@ +# Adapted from qwen2.py & modified details to support SeedOss +"""Inference-only SeedOss model compatible with HuggingFace weights.""" + +import logging +from typing import Any, Dict, Iterable, Optional, Tuple, Union + +import torch +from torch import nn + +from transformers import SeedOssConfig + +from sglang.srt.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.dp_attention import is_dp_attention_enabled +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.utils import PPMissingLayer, get_layer_id +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + kv_cache_scales_loader, +) +from sglang.srt.utils import add_prefix, make_layers + + +logger = logging.getLogger(__name__) + + +class SeedOssMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class SeedOssAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: Optional[int] = None, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 4096 * 32, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + if head_dim is not None: + self.head_dim = head_dim + else: + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim ** -0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + + return output + + +class SeedOssDecoderLayer(nn.Module): + def __init__( + self, + config: SeedOssConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 32768) + head_dim = getattr(config, "head_dim", None) + + self.self_attn = SeedOssAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=head_dim, + layer_id=layer_id, + 
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+
+        self.mlp = SeedOssMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+
+        return hidden_states, residual
+
+
+class SeedOssModel(nn.Module):
+    def __init__(
+        self,
+        config: SeedOssConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                enable_tp=not is_dp_attention_enabled(),
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        # Use the provided decoder layer type or default to SeedOssDecoderLayer
+        decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: decoder_layer_type(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=alt_stream,
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
+        )
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+        # TODO: Add support for EAGLE3
+
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.get_input_embeddings()(input_ids)
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+            return hidden_states
+
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state.
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+
+
+class SeedOssForCausalLM(nn.Module):
+    # BitsAndBytes-specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: SeedOssConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = SeedOssModel(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+
+        # Handle the LM head on different PP ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
+        else:
+            # Ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # Perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embedding(input_ids)
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors,
+        )
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # Embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # Decoder layers
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+        if end == self.model.config.num_hidden_layers:
+            # Norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # Logits processing
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+        return result
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle PP weight tying here:
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
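+                # `for ... else`: this branch runs only when no stacked-parameter
+                # mapping matched, so the weight is loaded under its original name.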
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = SeedOssForCausalLM
diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py
index fd9ce55084f..6e6f12a92a9 100644
--- a/python/sglang/srt/reasoning_parser.py
+++ b/python/sglang/srt/reasoning_parser.py
@@ -186,6 +186,23 @@ def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False
         )
 
 
+class SeedOssDetector(BaseReasoningFormatDetector):
+    """
+    Detector for the SeedOSS model.
+    Args:
+        stream_reasoning (bool): If False, accumulates reasoning content until the end tag.
+            If True, streams reasoning content as it arrives.
+    """
+    def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = True):
+        # SeedOSS is assumed to be reasoning until the `</seed:think>` token
+        super().__init__(
+            "<seed:think>",
+            "</seed:think>",
+            force_reasoning=force_reasoning,
+            stream_reasoning=stream_reasoning,
+        )
+
+
 class GptOssDetector(BaseReasoningFormatDetector):
     """
     Detector for T4-style reasoning format.
@@ -520,6 +537,7 @@ class ReasoningParser:
         "qwen3": Qwen3Detector,
         "qwen3-thinking": Qwen3Detector,
         "step3": DeepSeekR1Detector,
+        "seed_oss": SeedOssDetector,
     }
 
     def __init__(
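
Example launch (a sketch): the `seed_oss` parser names registered above plug into the
existing `--tool-call-parser` and `--reasoning-parser` server flags; the model path below
is an assumed Hugging Face identifier for Seed-OSS:

    python -m sglang.launch_server \
        --model-path ByteDance-Seed/Seed-OSS-36B-Instruct \
        --tool-call-parser seed_oss \
        --reasoning-parser seed_oss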