Skip to content

Commit 03fed0c

Browse files
authored
fix(server): apply error validation to event iterator (#893)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - New Features - Improved support for streaming responses: consistent error mapping for event iterators and preserved signal propagation during streaming. - Bug Fixes - Ensures errors thrown inside streaming handlers are correctly transformed and surfaced, including both mapped and generic errors. - Preserves iterator instances returned by handlers without unintended modification. - Tests - Added coverage for event iterator behavior, error flows, and signal handling to prevent regressions. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 3b7800f commit 03fed0c

File tree

2 files changed

+71
-12
lines changed

2 files changed

+71
-12
lines changed

packages/server/src/procedure-client.test.ts

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { ORPCError } from '@orpc/client'
2+
import { HibernationEventIterator } from '@orpc/standard-server'
23
import * as z from 'zod'
34
import { createORPCErrorConstructorMap, validateORPCError } from './error'
45
import { isLazy, lazy, unlazy } from './lazy'
@@ -39,6 +40,16 @@ const procedure = new Procedure({
3940
meta: {},
4041
})
4142

43+
const unvalidatedProcedure = new Procedure({
44+
errorMap: baseErrors,
45+
route: {},
46+
handler,
47+
middlewares: [preMid1, preMid2, postMid1, postMid2],
48+
inputValidationIndex: 2,
49+
outputValidationIndex: 2,
50+
meta: {},
51+
})
52+
4253
const procedureCases = [
4354
['without lazy', procedure],
4455
['with lazy', lazy(() => Promise.resolve({ default: procedure }))],
@@ -462,6 +473,39 @@ describe.each(procedureCases)('createProcedureClient - case %s', async (_, proce
462473
expect(validateORPCError).toBeCalledTimes(1)
463474
expect(validateORPCError).toBeCalledWith(baseErrors, e1)
464475
})
476+
477+
describe('event iterator', async () => {
478+
const client = createProcedureClient(unvalidatedProcedure)
479+
480+
it('throw non-ORPCError right away', async () => {
481+
const e1 = new Error('non-ORPC Error')
482+
handler.mockImplementationOnce(async function* () {
483+
throw e1
484+
} as any)
485+
486+
const iterator = await client({ val: '123' }) as any
487+
488+
await expect(iterator.next()).rejects.toBe(e1)
489+
})
490+
491+
it('validate ORPC Error', async () => {
492+
const e1 = new ORPCError('BAD_REQUEST')
493+
const e2 = new ORPCError('BAD_REQUEST', { defined: true })
494+
495+
handler.mockImplementationOnce(async function* () {
496+
throw e1
497+
} as any)
498+
vi.mocked(validateORPCError).mockReturnValueOnce(Promise.resolve(e2))
499+
500+
// signal here for test coverage
501+
const iterator = await client({ val: '123' }, { signal: AbortSignal.timeout(10) }) as any
502+
503+
await expect(iterator.next()).rejects.toBe(e2)
504+
505+
expect(validateORPCError).toBeCalledTimes(1)
506+
expect(validateORPCError).toBeCalledWith(baseErrors, e1)
507+
})
508+
})
465509
})
466510

467511
it('with client context', async () => {
@@ -510,6 +554,13 @@ describe.each(procedureCases)('createProcedureClient - case %s', async (_, proce
510554
expect((handler as any).mock.calls[3][0].context.preMid2).toBe(6)
511555
expect((handler as any).mock.calls[3][0].context.postMid1).toBe(7)
512556
})
557+
558+
it('not modify HibernationEventIterator', async () => {
559+
const client = createProcedureClient(unvalidatedProcedure)
560+
const iterator = new HibernationEventIterator(() => {})
561+
handler.mockResolvedValueOnce(iterator as any)
562+
await expect(client({ val: '123' })).resolves.toBe(iterator)
563+
})
513564
})
514565

515566
it('still work without InputSchema', async () => {

packages/server/src/procedure-client.ts

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import type { Context } from './context'
55
import type { ORPCErrorConstructorMap } from './error'
66
import type { Lazyable } from './lazy'
77
import type { AnyProcedure, Procedure, ProcedureHandlerOptions } from './procedure'
8-
import { ORPCError } from '@orpc/client'
8+
import { mapEventIterator, ORPCError } from '@orpc/client'
99
import { ValidationError } from '@orpc/contract'
1010
import { asyncIteratorWithSpan, intercept, isAsyncIteratorObject, resolveMaybeOptionalOptions, runWithSpan, toArray, value } from '@orpc/shared'
1111
import { HibernationEventIterator } from '@orpc/standard-server'
@@ -98,6 +98,14 @@ export function createProcedureClient<
9898
const context = await value(options.context ?? {} as TInitialContext, clientContext)
9999
const errors = createORPCErrorConstructorMap(procedure['~orpc'].errorMap)
100100

101+
const validateError = async (e: unknown) => {
102+
if (e instanceof ORPCError) {
103+
return await validateORPCError(procedure['~orpc'].errorMap, e)
104+
}
105+
106+
return e
107+
}
108+
101109
try {
102110
const output = await runWithSpan(
103111
{ name: 'call_procedure', signal: callerOptions?.signal },
@@ -129,29 +137,29 @@ export function createProcedureClient<
129137
}
130138

131139
/**
132-
* asyncIteratorWithSpan return AsyncIteratorClass
140+
* asyncIteratorWithSpan/mapEventIterator return AsyncIteratorClass
133141
* which is backwards compatible with Event Iterator & almost async iterator.
134142
*
135143
* @warning
136144
* If remove this return, can be breaking change
137145
* because AsyncIteratorClass convert `.throw` to `.return` (rarely used)
138146
*/
139-
return asyncIteratorWithSpan(
140-
{ name: 'consume_event_iterator_output', signal: callerOptions?.signal },
141-
output,
147+
return mapEventIterator(
148+
asyncIteratorWithSpan(
149+
{ name: 'consume_event_iterator_output', signal: callerOptions?.signal },
150+
output,
151+
),
152+
{
153+
value: v => v,
154+
error: e => validateError(e),
155+
},
142156
) as typeof output
143157
}
144158

145159
return output
146160
}
147161
catch (e) {
148-
if (!(e instanceof ORPCError)) {
149-
throw e
150-
}
151-
152-
const validated = await validateORPCError(procedure['~orpc'].errorMap, e)
153-
154-
throw validated
162+
throw await validateError(e)
155163
}
156164
}
157165
}

0 commit comments

Comments
 (0)