Skip to content

Commit e1f7cf5

Browse files
authored
[router] additional llama32 parser unit test and multi json support (sgl-project#9732)
1 parent 2bb9d45 commit e1f7cf5

File tree

2 files changed

+296
-4
lines changed

2 files changed

+296
-4
lines changed

sgl-router/src/tool_parser/parsers/json_parser.rs

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,13 +242,92 @@ impl Default for JsonParser {
242242
#[async_trait]
243243
impl ToolParser for JsonParser {
244244
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
245+
// Check if we have multiple start tokens (e.g., multiple <|python_tag|> markers)
246+
if !self.token_config.start_tokens.is_empty() {
247+
let start_token = &self.token_config.start_tokens[0];
248+
if !start_token.is_empty() && text.matches(start_token).count() > 1 {
249+
// We have multiple occurrences of the start token
250+
let mut all_tools = Vec::new();
251+
let mut remaining = text;
252+
253+
while let Some(start_pos) = remaining.find(start_token.as_str()) {
254+
// Extract content after this start token
255+
let after_token = &remaining[start_pos + start_token.len()..];
256+
257+
// Find where this JSON ends (look for the next start token or end of string)
258+
let end_pos = if let Some(next_start) = after_token.find(start_token.as_str()) {
259+
next_start
260+
} else {
261+
after_token.len()
262+
};
263+
264+
let json_content = &after_token[..end_pos];
265+
266+
// Try to extract and parse JSON from this segment
267+
if let Some(extracted) = self.extract_json_from_text(json_content) {
268+
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
269+
if let Ok(tools) = self.parse_json_value(&value) {
270+
all_tools.extend(tools);
271+
}
272+
}
273+
}
274+
275+
// Move to the next segment
276+
remaining = &remaining[start_pos + start_token.len() + end_pos..];
277+
if remaining.is_empty() {
278+
break;
279+
}
280+
}
281+
282+
if !all_tools.is_empty() {
283+
return Ok(all_tools);
284+
}
285+
}
286+
}
287+
245288
// Extract JSON content from wrapper tokens if present
246289
let json_content = self.extract_json_content(text);
247290

248-
// Try to parse as JSON
291+
// Try to parse as JSON first
249292
match serde_json::from_str::<Value>(json_content) {
250293
Ok(value) => self.parse_json_value(&value),
251294
Err(_) => {
295+
// If parse failed, check if we have multiple JSON objects separated by the configured separator
296+
// This handles cases like: {"name": "func1", ...};{"name": "func2", ...}
297+
if !self.token_config.separator.is_empty()
298+
&& json_content.contains(&self.token_config.separator)
299+
{
300+
let mut all_tools = Vec::new();
301+
302+
// Split by separator and try to parse each part
303+
let parts: Vec<&str> =
304+
json_content.split(&self.token_config.separator).collect();
305+
for part in parts {
306+
let trimmed = part.trim();
307+
if trimmed.is_empty() {
308+
continue;
309+
}
310+
311+
// Try to parse this part as JSON
312+
if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
313+
if let Ok(tools) = self.parse_json_value(&value) {
314+
all_tools.extend(tools);
315+
}
316+
} else if let Some(extracted) = self.extract_json_from_text(trimmed) {
317+
// Try extracting JSON from this part
318+
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
319+
if let Ok(tools) = self.parse_json_value(&value) {
320+
all_tools.extend(tools);
321+
}
322+
}
323+
}
324+
}
325+
326+
if !all_tools.is_empty() {
327+
return Ok(all_tools);
328+
}
329+
}
330+
252331
// If no wrapper tokens configured and parse failed,
253332
// try to extract JSON from mixed text
254333
if self.token_config.start_tokens.is_empty() {
@@ -350,9 +429,11 @@ impl ToolParser for JsonParser {
350429
Value::Array(ref arr) => {
351430
// Check if array contains tool-like objects
352431
arr.iter().any(|v| {
353-
v.as_object().is_some_and(|o| {
354-
o.contains_key("name") || o.contains_key("function")
355-
})
432+
if let Some(obj) = v.as_object() {
433+
obj.contains_key("name") || obj.contains_key("function")
434+
} else {
435+
false
436+
}
356437
})
357438
}
358439
_ => false,

sgl-router/tests/tool_parser_llama.rs

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,214 @@ async fn test_llama_json_array_format() {
141141
// Current implementation might handle this through JSON fallback
142142
assert!(!result.is_empty());
143143
}
144+
145+
#[tokio::test]
146+
async fn test_single_json() {
147+
// Test parsing plain JSON without python_tag
148+
let parser = LlamaParser::new();
149+
let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
150+
151+
let result = parser.parse_complete(text).await.unwrap();
152+
assert_eq!(result.len(), 1);
153+
assert_eq!(result[0].function.name, "get_weather");
154+
155+
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
156+
assert_eq!(args["city"], "Paris");
157+
}
158+
159+
#[tokio::test]
160+
async fn test_multiple_json_with_separator() {
161+
// Test multiple JSON objects with semicolon separator
162+
let parser = LlamaParser::new();
163+
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#;
164+
165+
let result = parser.parse_complete(text).await.unwrap();
166+
// Note: Current implementation may only parse the first one due to semicolon handling
167+
assert!(!result.is_empty());
168+
assert_eq!(result[0].function.name, "get_weather");
169+
}
170+
171+
#[tokio::test]
172+
async fn test_multiple_json_with_separator_customized() {
173+
// Test multiple JSON objects with python_tag repeated
174+
let parser = LlamaParser::new();
175+
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#;
176+
177+
let result = parser.parse_complete(text).await.unwrap();
178+
// Current implementation may handle this differently
179+
assert!(!result.is_empty());
180+
assert_eq!(result[0].function.name, "get_weather");
181+
}
182+
183+
#[tokio::test]
184+
async fn test_json_with_trailing_text() {
185+
// Test JSON with trailing text after
186+
let parser = LlamaParser::new();
187+
let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#;
188+
189+
let result = parser.parse_complete(text).await.unwrap();
190+
assert_eq!(result.len(), 1);
191+
assert_eq!(result[0].function.name, "get_weather");
192+
}
193+
194+
#[tokio::test]
195+
async fn test_invalid_then_valid_json() {
196+
// Test error recovery - invalid JSON followed by valid JSON
197+
let parser = LlamaParser::new();
198+
let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#;
199+
200+
let result = parser.parse_complete(text).await.unwrap();
201+
// Should parse at least one valid JSON
202+
if !result.is_empty() {
203+
assert_eq!(result[0].function.name, "get_weather");
204+
}
205+
}
206+
207+
#[tokio::test]
208+
async fn test_plain_text_only() {
209+
// Test plain text with no tool calls
210+
let parser = LlamaParser::new();
211+
let text = "This is just plain explanation text.";
212+
213+
let result = parser.parse_complete(text).await.unwrap();
214+
assert_eq!(result.len(), 0);
215+
}
216+
217+
#[tokio::test]
218+
async fn test_with_python_tag_prefix() {
219+
// Test text before python_tag
220+
let parser = LlamaParser::new();
221+
let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#;
222+
223+
let result = parser.parse_complete(text).await.unwrap();
224+
assert_eq!(result.len(), 1);
225+
assert_eq!(result[0].function.name, "get_weather");
226+
}
227+
228+
// ============================================================================
229+
// STREAMING TESTS
230+
// ============================================================================
231+
232+
#[tokio::test]
233+
async fn test_llama_streaming_simple() {
234+
let parser = LlamaParser::new();
235+
let mut state = sglang_router_rs::tool_parser::ParseState::new();
236+
237+
// Send complete JSON at once
238+
let full_json = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
239+
240+
let result = parser
241+
.parse_incremental(full_json, &mut state)
242+
.await
243+
.unwrap();
244+
245+
match result {
246+
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
247+
assert_eq!(tool.function.name, "search");
248+
}
249+
_ => panic!("Expected ToolComplete for complete JSON input"),
250+
}
251+
}
252+
253+
#[tokio::test]
254+
async fn test_llama_streaming_partial() {
255+
let parser = LlamaParser::new();
256+
let mut state = sglang_router_rs::tool_parser::ParseState::new();
257+
258+
// Stream in chunks
259+
let chunks = vec![
260+
r#"<|python"#,
261+
r#"_tag|>{"name": "#,
262+
r#""calculate", "#,
263+
r#""arguments": {"x": 10}"#,
264+
r#"}"#,
265+
];
266+
267+
let mut got_complete = false;
268+
269+
for chunk in chunks {
270+
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
271+
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
272+
assert_eq!(tool.function.name, "calculate");
273+
got_complete = true;
274+
}
275+
}
276+
277+
assert!(got_complete, "Should have completed parsing");
278+
}
279+
280+
#[tokio::test]
281+
async fn test_llama_streaming_plain_json() {
282+
let parser = LlamaParser::new();
283+
let mut state = sglang_router_rs::tool_parser::ParseState::new();
284+
285+
// Stream plain JSON without python_tag
286+
let chunks = vec![
287+
r#"{"name": "#,
288+
r#""search", "#,
289+
r#""arguments": "#,
290+
r#"{"query": "#,
291+
r#""test"}}"#,
292+
];
293+
294+
let mut got_complete = false;
295+
296+
for chunk in chunks {
297+
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
298+
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
299+
assert_eq!(tool.function.name, "search");
300+
got_complete = true;
301+
}
302+
}
303+
304+
assert!(got_complete, "Should have completed parsing");
305+
}
306+
307+
#[tokio::test]
308+
async fn test_llama_streaming_with_text_before() {
309+
let parser = LlamaParser::new();
310+
let mut state = sglang_router_rs::tool_parser::ParseState::new();
311+
312+
let chunks = vec![
313+
r#"Let me help you. "#,
314+
r#"<|python_tag|>"#,
315+
r#"{"name": "get_time","#,
316+
r#" "arguments": {"#,
317+
r#""timezone": "UTC"}}"#,
318+
];
319+
320+
let mut got_complete = false;
321+
322+
for chunk in chunks {
323+
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
324+
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
325+
assert_eq!(tool.function.name, "get_time");
326+
got_complete = true;
327+
}
328+
}
329+
330+
assert!(got_complete, "Should have completed parsing");
331+
}
332+
333+
#[tokio::test]
334+
async fn test_llama_streaming_multiple_tools() {
335+
// Test streaming multiple tool calls with semicolon separator
336+
let parser = LlamaParser::new();
337+
let mut state = sglang_router_rs::tool_parser::ParseState::new();
338+
339+
let text =
340+
r#"<|python_tag|>{"name": "func1", "arguments": {}};{"name": "func2", "arguments": {}}"#;
341+
342+
let result = parser.parse_incremental(text, &mut state).await.unwrap();
343+
344+
// Current implementation may handle this differently
345+
match result {
346+
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
347+
// At minimum should get first tool
348+
assert_eq!(tool.function.name, "func1");
349+
}
350+
_ => {
351+
// Also acceptable if waiting for more
352+
}
353+
}
354+
}

0 commit comments

Comments
 (0)