Skip to content

Commit 3a8bf4f

Browse files
committed
fix: handle SSE error events from providers like Groq
- Parse error events that use explicit SSE event types (event: error) - Extract only JSON data from error events, skipping event type lines - Reset error accumulator after each error to prevent data corruption - Add comprehensive test coverage for both OpenAI and Groq error formats - Maintain backward compatibility with OpenAI's simpler format Fixes parsing of error events from Groq API when invalid tool calls occur
1 parent c4273cb commit 3a8bf4f

File tree

2 files changed

+159
-10
lines changed

2 files changed

+159
-10
lines changed

stream_reader.go

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai
33
import (
44
"bufio"
55
"bytes"
6+
"encoding/json"
67
"fmt"
78
"io"
89
"net/http"
@@ -40,6 +41,12 @@ func (stream *streamReader[T]) Recv() (response T, err error) {
4041

4142
err = stream.unmarshaler.Unmarshal(rawLine, &response)
4243
if err != nil {
44+
// If we get a JSON parsing error, it might be because we got an error event
45+
// Check if we have accumulated error data
46+
if _, ok := err.(*json.SyntaxError); ok && len(stream.errAccumulator.Bytes()) > 0 {
47+
// We have error data, return a more informative error
48+
return response, fmt.Errorf("failed to parse response (error event received): %s", string(stream.errAccumulator.Bytes()))
49+
}
4350
return
4451
}
4552
return response, nil
@@ -65,28 +72,43 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
6572
if readErr != nil || hasErrorPrefix {
6673
respErr := stream.unmarshalError()
6774
if respErr != nil {
68-
return nil, fmt.Errorf("error, %w", respErr.Error)
75+
return nil, respErr.Error
76+
}
77+
// If we detected an error event but couldn't parse it, and the stream ended,
78+
// return a more informative error. This handles cases where providers send
79+
// error events that don't match the expected format and immediately close.
80+
if hasErrorPrefix && readErr == io.EOF {
81+
// Check if we have error data that failed to parse
82+
errBytes := stream.errAccumulator.Bytes()
83+
if len(errBytes) > 0 {
84+
return nil, fmt.Errorf("failed to parse error event: %s", string(errBytes))
85+
}
86+
return nil, fmt.Errorf("stream ended after error event")
6987
}
7088
return nil, readErr
7189
}
7290

7391
noSpaceLine := bytes.TrimSpace(rawLine)
7492
if errorPrefix.Match(noSpaceLine) {
7593
hasErrorPrefix = true
76-
}
77-
if !headerData.Match(noSpaceLine) || hasErrorPrefix {
78-
if hasErrorPrefix {
79-
noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil)
80-
}
81-
writeErr := stream.errAccumulator.Write(noSpaceLine)
94+
// Extract just the JSON part after "data: " prefix
95+
// This handles both OpenAI format (data: {"error": ...}) and
96+
// Groq format (event: error\ndata: {"error": ...})
97+
jsonData := headerData.ReplaceAll(noSpaceLine, nil)
98+
writeErr := stream.errAccumulator.Write(jsonData)
8299
if writeErr != nil {
83100
return nil, writeErr
84101
}
102+
continue
103+
}
104+
105+
// Skip non-data lines (e.g., "event: error" from Groq)
106+
// This allows us to handle SSE streams that use explicit event types
107+
if !headerData.Match(noSpaceLine) {
85108
emptyMessagesCount++
86109
if emptyMessagesCount > stream.emptyMessagesLimit {
87110
return nil, ErrTooManyEmptyStreamMessages
88111
}
89-
90112
continue
91113
}
92114

@@ -111,6 +133,10 @@ func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) {
111133
errResp = nil
112134
}
113135

136+
// Reset the error accumulator for future error events
137+
// A new accumulator is created to avoid potential interface issues
138+
stream.errAccumulator = utils.NewErrorAccumulator()
139+
114140
return
115141
}
116142

stream_reader_test.go

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) {
5454

5555
func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
5656
stream := &streamReader[ChatCompletionStreamResponse]{
57-
reader: bufio.NewReader(bytes.NewReader([]byte("\n"))),
57+
reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"error\": {\"message\": \"test error\"}}\n"))),
5858
errAccumulator: &utils.DefaultErrorAccumulator{
5959
Buffer: &test.FailingErrorBuffer{},
6060
},
61-
unmarshaler: &utils.JSONUnmarshaler{},
61+
unmarshaler: &utils.JSONUnmarshaler{},
62+
emptyMessagesLimit: 5,
6263
}
6364
_, err := stream.Recv()
6465
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
@@ -76,3 +77,125 @@ func TestStreamReaderRecvRaw(t *testing.T) {
7677
t.Fatalf("Did not return raw line: %v", string(rawLine))
7778
}
7879
}
80+
81+
func TestStreamReaderParsesErrorEvents(t *testing.T) {
82+
// Test case simulating Groq's error event format
83+
errorEvent := `event: error
84+
data: {"error":{"message":"Invalid tool_call: tool \"name_unknown\" does not exist.","type":"invalid_request_error","code":"invalid_tool_call"}}
85+
86+
`
87+
stream := &streamReader[ChatCompletionStreamResponse]{
88+
reader: bufio.NewReader(bytes.NewReader([]byte(errorEvent))),
89+
errAccumulator: utils.NewErrorAccumulator(),
90+
unmarshaler: &utils.JSONUnmarshaler{},
91+
emptyMessagesLimit: 5,
92+
}
93+
94+
// Process the error event
95+
_, err := stream.Recv()
96+
if err == nil {
97+
t.Fatal("Expected error but got nil")
98+
}
99+
100+
// Verify it's an APIError
101+
apiErr, ok := err.(*APIError)
102+
if !ok {
103+
t.Fatalf("Expected APIError type but got %T: %v", err, err)
104+
}
105+
106+
// Verify the error fields are correctly parsed
107+
if apiErr.Message != "Invalid tool_call: tool \"name_unknown\" does not exist." {
108+
t.Errorf("Unexpected error message: %s", apiErr.Message)
109+
}
110+
if apiErr.Type != "invalid_request_error" {
111+
t.Errorf("Unexpected error type: %s", apiErr.Type)
112+
}
113+
if apiErr.Code != "invalid_tool_call" {
114+
t.Errorf("Unexpected error code: %v", apiErr.Code)
115+
}
116+
}
117+
118+
func TestStreamReaderHandlesErrorEventWithExtraData(t *testing.T) {
119+
// Test case with error event followed by more data
120+
errorEvent := `data: {"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]}
121+
event: error
122+
data: {"error":{"message":"Stream interrupted","type":"server_error"}}
123+
data: [DONE]
124+
`
125+
stream := &streamReader[ChatCompletionStreamResponse]{
126+
reader: bufio.NewReader(bytes.NewReader([]byte(errorEvent))),
127+
errAccumulator: utils.NewErrorAccumulator(),
128+
unmarshaler: &utils.JSONUnmarshaler{},
129+
emptyMessagesLimit: 5,
130+
}
131+
132+
// First recv should return the chat completion
133+
resp, err := stream.Recv()
134+
if err != nil {
135+
t.Fatalf("First recv failed: %v", err)
136+
}
137+
if resp.ID != "chatcmpl-123" {
138+
t.Errorf("Unexpected response ID: %s", resp.ID)
139+
}
140+
141+
// Second recv should return the error
142+
_, err = stream.Recv()
143+
if err == nil {
144+
t.Fatal("Expected error but got nil")
145+
}
146+
147+
// Verify it's an APIError
148+
apiErr, ok := err.(*APIError)
149+
if !ok {
150+
t.Fatalf("Expected APIError type but got %T: %v", err, err)
151+
}
152+
153+
if apiErr.Message != "Stream interrupted" {
154+
t.Errorf("Unexpected error message: %s", apiErr.Message)
155+
}
156+
}
157+
158+
func TestStreamReaderResetsErrorAccumulator(t *testing.T) {
159+
// Test that error accumulator is reset after processing an error
160+
multipleErrors := `event: error
161+
data: {"error":{"message":"First error","type":"error_type_1"}}
162+
163+
event: error
164+
data: {"error":{"message":"Second error","type":"error_type_2"}}
165+
`
166+
stream := &streamReader[ChatCompletionStreamResponse]{
167+
reader: bufio.NewReader(bytes.NewReader([]byte(multipleErrors))),
168+
errAccumulator: utils.NewErrorAccumulator(),
169+
unmarshaler: &utils.JSONUnmarshaler{},
170+
emptyMessagesLimit: 5,
171+
}
172+
173+
// First recv should return the first error
174+
_, err1 := stream.Recv()
175+
if err1 == nil {
176+
t.Fatal("Expected first error but got nil")
177+
}
178+
apiErr1, ok := err1.(*APIError)
179+
if !ok {
180+
t.Fatalf("Expected APIError type but got %T: %v", err1, err1)
181+
}
182+
if apiErr1.Message != "First error" {
183+
t.Errorf("Unexpected first error message: %s", apiErr1.Message)
184+
}
185+
186+
// Second recv should return the second error (not a concatenation)
187+
_, err2 := stream.Recv()
188+
if err2 == nil {
189+
t.Fatal("Expected second error but got nil")
190+
}
191+
apiErr2, ok := err2.(*APIError)
192+
if !ok {
193+
t.Fatalf("Expected APIError type but got %T: %v", err2, err2)
194+
}
195+
if apiErr2.Message != "Second error" {
196+
t.Errorf("Unexpected second error message: %s", apiErr2.Message)
197+
}
198+
if apiErr2.Type != "error_type_2" {
199+
t.Errorf("Unexpected second error type: %s", apiErr2.Type)
200+
}
201+
}

0 commit comments

Comments
 (0)