Skip to content

Commit 43790e2

Browse files
author
Ryan O'Hara-Reid
committed
Add func MatchReturnResponses
1 parent 139384d commit 43790e2

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

exchanges/stream/stream_types.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ type Connection interface {
3535
SetProxy(string)
3636
GetURL() string
3737
Shutdown() error
38+
// MatchReturnResponses sets up a channel to listen for an expected number of responses. This is used for when a
39+
// request is sent and a response is expected in a different connection. Please see implementation in
40+
// websocket_connection.go
41+
MatchReturnResponses(ctx context.Context, signature any, expected int) (<-chan MatchedResponse, error)
3842
}
3943

4044
// Inspector is a hook that allows for custom message inspection

exchanges/stream/websocket_connection.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,32 @@ func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature an
366366
return resps, nil
367367
}
368368

369+
// MatchedResponse encapsulates the matched responses along with any errors encountered.
370+
type MatchedResponse struct {
371+
Responses [][]byte
372+
Err error
373+
}
374+
375+
// MatchReturnResponses sets up a channel to listen for an expected number of responses. These responses may not
376+
// originate from the same connection as the request, but can come from an alternative connection. It returns a channel
377+
// that will receive a MatchedResponse containing the collected responses or an error.
378+
func (w *WebsocketConnection) MatchReturnResponses(ctx context.Context, signature any, expected int) (<-chan MatchedResponse, error) {
379+
connectionListen, err := w.Match.Set(signature, expected)
380+
if err != nil {
381+
return nil, err
382+
}
383+
384+
out := make(chan MatchedResponse, 1) // buffered so routine below doesn't leak
385+
386+
go func() {
387+
resps, err := w.waitForResponses(ctx, signature, connectionListen, expected)
388+
out <- MatchedResponse{Responses: resps, Err: err}
389+
close(out)
390+
}()
391+
392+
return out, nil
393+
}
394+
369395
func removeURLQueryString(url string) string {
370396
if index := strings.Index(url, "?"); index != -1 {
371397
return url[:index]

exchanges/stream/websocket_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,3 +1525,24 @@ func TestGetConnection(t *testing.T) {
15251525
require.NoError(t, err)
15261526
assert.Same(t, expected, conn)
15271527
}
1528+
1529+
func TestMatchReturnResponses(t *testing.T) {
1530+
t.Parallel()
1531+
1532+
conn := WebsocketConnection{Match: NewMatch()}
1533+
_, err := conn.MatchReturnResponses(context.Background(), nil, 0)
1534+
require.ErrorIs(t, err, errInvalidBufferSize)
1535+
1536+
ch, err := conn.MatchReturnResponses(context.Background(), nil, 1)
1537+
require.NoError(t, err)
1538+
1539+
require.ErrorIs(t, (<-ch).Err, ErrSignatureTimeout)
1540+
conn.ResponseMaxLimit = time.Millisecond
1541+
1542+
ch, err = conn.MatchReturnResponses(context.Background(), nil, 1)
1543+
require.NoError(t, err)
1544+
1545+
exp := []byte("test")
1546+
require.True(t, conn.Match.IncomingWithData(nil, exp))
1547+
require.Equal(t, (<-ch).Responses[0], exp)
1548+
}

0 commit comments

Comments
 (0)