Skip to content

Commit b0e56f2

Browse files
committed
align return types from execution and subscription
with respect to possible promises
1 parent 540bb38 commit b0e56f2

File tree

2 files changed

+136
-72
lines changed

2 files changed

+136
-72
lines changed

src/execution/__tests__/subscribe-test.ts

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@ import { expectJSON } from '../../__testUtils__/expectJSON';
55
import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick';
66

77
import { isAsyncIterable } from '../../jsutils/isAsyncIterable';
8+
import { isPromise } from '../../jsutils/isPromise';
9+
import type { PromiseOrValue } from '../../jsutils/PromiseOrValue';
810

911
import { parse } from '../../language/parser';
1012

1113
import { GraphQLList, GraphQLObjectType } from '../../type/definition';
1214
import { GraphQLBoolean, GraphQLInt, GraphQLString } from '../../type/scalars';
1315
import { GraphQLSchema } from '../../type/schema';
1416

17+
import type { ExecutionResult } from '../execute';
1518
import { createSourceEventStream, subscribe } from '../subscribe';
1619

1720
import { SimplePubSub } from './simplePubSub';
@@ -134,9 +137,6 @@ async function expectPromise(promise: Promise<unknown>) {
134137
}
135138

136139
return {
137-
toReject() {
138-
expect(caughtError).to.be.an.instanceOf(Error);
139-
},
140140
toRejectWith(message: string) {
141141
expect(caughtError).to.be.an.instanceOf(Error);
142142
expect(caughtError).to.have.property('message', message);
@@ -151,6 +151,34 @@ const DummyQueryType = new GraphQLObjectType({
151151
},
152152
});
153153

154+
function subscribeWithFn(
155+
subscribeFn: () => unknown,
156+
): PromiseOrValue<ExecutionResult | AsyncGenerator<ExecutionResult>> {
157+
const schema = new GraphQLSchema({
158+
query: DummyQueryType,
159+
subscription: new GraphQLObjectType({
160+
name: 'Subscription',
161+
fields: {
162+
foo: { type: GraphQLString, subscribe: subscribeFn },
163+
},
164+
}),
165+
});
166+
const document = parse('subscription { foo }');
167+
const result = subscribe({ schema, document });
168+
169+
if (isPromise(result)) {
170+
return result.then(async (resolvedResult) => {
171+
expectJSON(await createSourceEventStream(schema, document)).toDeepEqual(
172+
resolvedResult,
173+
);
174+
return result;
175+
});
176+
}
177+
178+
expectJSON(createSourceEventStream(schema, document)).toDeepEqual(result);
179+
return result;
180+
}
181+
154182
/* eslint-disable @typescript-eslint/require-await */
155183
// Check all error cases when initializing the subscription.
156184
describe('Subscription Initialization Phase', () => {
@@ -356,24 +384,22 @@ describe('Subscription Initialization Phase', () => {
356384
});
357385

358386
// @ts-expect-error (schema must not be null)
359-
(await expectPromise(subscribe({ schema: null, document }))).toRejectWith(
387+
expect(() => subscribe({ schema: null, document })).to.throw(
360388
'Expected null to be a GraphQL schema.',
361389
);
362390

363391
// @ts-expect-error
364-
(await expectPromise(subscribe({ document }))).toRejectWith(
392+
expect(() => subscribe({ document })).to.throw(
365393
'Expected undefined to be a GraphQL schema.',
366394
);
367395

368396
// @ts-expect-error (document must not be null)
369-
(await expectPromise(subscribe({ schema, document: null }))).toRejectWith(
397+
expect(() => subscribe({ schema, document: null })).to.throw(
370398
'Must provide document.',
371399
);
372400

373401
// @ts-expect-error
374-
(await expectPromise(subscribe({ schema }))).toRejectWith(
375-
'Must provide document.',
376-
);
402+
expect(() => subscribe({ schema })).to.throw('Must provide document.');
377403
});
378404

379405
it('resolves to an error if schema does not support subscriptions', async () => {
@@ -427,50 +453,22 @@ describe('Subscription Initialization Phase', () => {
427453
});
428454

429455
// @ts-expect-error
430-
(await expectPromise(subscribe({ schema, document: {} }))).toReject();
456+
expect(() => subscribe({ schema, document: {} })).to.throw();
431457
});
432458

433459
it('throws an error if subscribe does not return an iterator', async () => {
434-
const schema = new GraphQLSchema({
435-
query: DummyQueryType,
436-
subscription: new GraphQLObjectType({
437-
name: 'Subscription',
438-
fields: {
439-
foo: {
440-
type: GraphQLString,
441-
subscribe: () => 'test',
442-
},
443-
},
444-
}),
445-
});
446-
447-
const document = parse('subscription { foo }');
460+
expect(() => subscribeWithFn(() => 'test')).to.throw(
461+
'Subscription field must return Async Iterable. Received: "test".',
462+
);
448463

449-
(await expectPromise(subscribe({ schema, document }))).toRejectWith(
464+
const result = subscribeWithFn(() => Promise.resolve('test'));
465+
assert(isPromise(result));
466+
(await expectPromise(result)).toRejectWith(
450467
'Subscription field must return Async Iterable. Received: "test".',
451468
);
452469
});
453470

454471
it('resolves to an error for subscription resolver errors', async () => {
455-
async function subscribeWithFn(subscribeFn: () => unknown) {
456-
const schema = new GraphQLSchema({
457-
query: DummyQueryType,
458-
subscription: new GraphQLObjectType({
459-
name: 'Subscription',
460-
fields: {
461-
foo: { type: GraphQLString, subscribe: subscribeFn },
462-
},
463-
}),
464-
});
465-
const document = parse('subscription { foo }');
466-
const result = await subscribe({ schema, document });
467-
468-
expectJSON(await createSourceEventStream(schema, document)).toDeepEqual(
469-
result,
470-
);
471-
return result;
472-
}
473-
474472
const expectedResult = {
475473
errors: [
476474
{
@@ -483,12 +481,12 @@ describe('Subscription Initialization Phase', () => {
483481

484482
expectJSON(
485483
// Returning an error
486-
await subscribeWithFn(() => new Error('test error')),
484+
subscribeWithFn(() => new Error('test error')),
487485
).toDeepEqual(expectedResult);
488486

489487
expectJSON(
490488
// Throwing an error
491-
await subscribeWithFn(() => {
489+
subscribeWithFn(() => {
492490
throw new Error('test error');
493491
}),
494492
).toDeepEqual(expectedResult);

src/execution/subscribe.ts

Lines changed: 92 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import { inspect } from '../jsutils/inspect';
22
import { isAsyncIterable } from '../jsutils/isAsyncIterable';
3+
import { isPromise } from '../jsutils/isPromise';
34
import type { Maybe } from '../jsutils/Maybe';
45
import { addPath, pathToArray } from '../jsutils/Path';
6+
import type { PromiseOrValue } from '../jsutils/PromiseOrValue';
57

68
import { GraphQLError } from '../error/GraphQLError';
79
import { locatedError } from '../error/locatedError';
@@ -47,9 +49,11 @@ import { getArgumentValues } from './values';
4749
*
4850
* Accepts either an object with named arguments, or individual arguments.
4951
*/
50-
export async function subscribe(
52+
export function subscribe(
5153
args: ExecutionArgs,
52-
): Promise<AsyncGenerator<ExecutionResult, void, void> | ExecutionResult> {
54+
): PromiseOrValue<
55+
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
56+
> {
5357
const {
5458
schema,
5559
document,
@@ -61,7 +65,7 @@ export async function subscribe(
6165
subscribeFieldResolver,
6266
} = args;
6367

64-
const resultOrStream = await createSourceEventStream(
68+
const resultOrStream = createSourceEventStream(
6569
schema,
6670
document,
6771
rootValue,
@@ -71,6 +75,42 @@ export async function subscribe(
7175
subscribeFieldResolver,
7276
);
7377

78+
if (isPromise(resultOrStream)) {
79+
return resultOrStream.then((resolvedResultOrStream) =>
80+
mapSourceToResponse(
81+
schema,
82+
document,
83+
resolvedResultOrStream,
84+
contextValue,
85+
variableValues,
86+
operationName,
87+
fieldResolver,
88+
),
89+
);
90+
}
91+
92+
return mapSourceToResponse(
93+
schema,
94+
document,
95+
resultOrStream,
96+
contextValue,
97+
variableValues,
98+
operationName,
99+
fieldResolver,
100+
);
101+
}
102+
103+
function mapSourceToResponse(
104+
schema: GraphQLSchema,
105+
document: DocumentNode,
106+
resultOrStream: ExecutionResult | AsyncIterable<unknown>,
107+
contextValue?: unknown,
108+
variableValues?: Maybe<{ readonly [variable: string]: unknown }>,
109+
operationName?: Maybe<string>,
110+
fieldResolver?: Maybe<GraphQLFieldResolver<any, any>>,
111+
): PromiseOrValue<
112+
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
113+
> {
74114
if (!isAsyncIterable(resultOrStream)) {
75115
return resultOrStream;
76116
}
@@ -81,7 +121,7 @@ export async function subscribe(
81121
// the GraphQL specification. The `execute` function provides the
82122
// "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
83123
// "ExecuteQuery" algorithm, for which `execute` is also used.
84-
const mapSourceToResponse = (payload: unknown) =>
124+
return mapAsyncIterator(resultOrStream, (payload: unknown) =>
85125
execute({
86126
schema,
87127
document,
@@ -90,10 +130,8 @@ export async function subscribe(
90130
variableValues,
91131
operationName,
92132
fieldResolver,
93-
});
94-
95-
// Map every source value to a ExecutionResult value as described above.
96-
return mapAsyncIterator(resultOrStream, mapSourceToResponse);
133+
}),
134+
);
97135
}
98136

99137
/**
@@ -124,15 +162,15 @@ export async function subscribe(
124162
* or otherwise separating these two steps. For more on this, see the
125163
* "Supporting Subscriptions at Scale" information in the GraphQL specification.
126164
*/
127-
export async function createSourceEventStream(
165+
export function createSourceEventStream(
128166
schema: GraphQLSchema,
129167
document: DocumentNode,
130168
rootValue?: unknown,
131169
contextValue?: unknown,
132170
variableValues?: Maybe<{ readonly [variable: string]: unknown }>,
133171
operationName?: Maybe<string>,
134172
subscribeFieldResolver?: Maybe<GraphQLFieldResolver<any, any>>,
135-
): Promise<AsyncIterable<unknown> | ExecutionResult> {
173+
): PromiseOrValue<AsyncIterable<unknown> | ExecutionResult> {
136174
// If arguments are missing or incorrectly typed, this is an internal
137175
// developer mistake which should throw an early error.
138176
assertValidExecutionArguments(schema, document, variableValues);
@@ -155,17 +193,22 @@ export async function createSourceEventStream(
155193
}
156194

157195
try {
158-
const eventStream = await executeSubscription(exeContext);
159-
160-
// Assert field returned an event stream, otherwise yield an error.
161-
if (!isAsyncIterable(eventStream)) {
162-
throw new Error(
163-
'Subscription field must return Async Iterable. ' +
164-
`Received: ${inspect(eventStream)}.`,
165-
);
196+
const eventStream = executeSubscription(exeContext);
197+
198+
if (isPromise(eventStream)) {
199+
return eventStream
200+
.then((resolvedEventStream) => ensureAsyncIterable(resolvedEventStream))
201+
.then(undefined, (error) => {
202+
// If it GraphQLError, report it as an ExecutionResult, containing only errors and no data.
203+
// Otherwise treat the error as a system-class error and re-throw it.
204+
if (error instanceof GraphQLError) {
205+
return { errors: [error] };
206+
}
207+
throw error;
208+
});
166209
}
167210

168-
return eventStream;
211+
return ensureAsyncIterable(eventStream);
169212
} catch (error) {
170213
// If it GraphQLError, report it as an ExecutionResult, containing only errors and no data.
171214
// Otherwise treat the error as a system-class error and re-throw it.
@@ -176,9 +219,19 @@ export async function createSourceEventStream(
176219
}
177220
}
178221

179-
async function executeSubscription(
180-
exeContext: ExecutionContext,
181-
): Promise<unknown> {
222+
function ensureAsyncIterable(eventStream: unknown): AsyncIterable<unknown> {
223+
// Assert field returned an event stream, otherwise yield an error.
224+
if (!isAsyncIterable(eventStream)) {
225+
throw new Error(
226+
'Subscription field must return Async Iterable. ' +
227+
`Received: ${inspect(eventStream)}.`,
228+
);
229+
}
230+
231+
return eventStream;
232+
}
233+
234+
function executeSubscription(exeContext: ExecutionContext): unknown {
182235
const { schema, fragments, operation, variableValues, rootValue } =
183236
exeContext;
184237

@@ -233,13 +286,26 @@ async function executeSubscription(
233286
// Call the `subscribe()` resolver or the default resolver to produce an
234287
// AsyncIterable yielding raw payloads.
235288
const resolveFn = fieldDef.subscribe ?? exeContext.subscribeFieldResolver;
236-
const eventStream = await resolveFn(rootValue, args, contextValue, info);
237289

238-
if (eventStream instanceof Error) {
239-
throw eventStream;
290+
const eventStream = resolveFn(rootValue, args, contextValue, info);
291+
292+
if (isPromise(eventStream)) {
293+
return eventStream
294+
.then((resolvedEventStream) => throwReturnedError(resolvedEventStream))
295+
.then(undefined, (error) => {
296+
throw locatedError(error, fieldNodes, pathToArray(path));
297+
});
240298
}
241-
return eventStream;
299+
300+
return throwReturnedError(eventStream);
242301
} catch (error) {
243302
throw locatedError(error, fieldNodes, pathToArray(path));
244303
}
245304
}
305+
306+
function throwReturnedError(eventStream: unknown): unknown {
307+
if (eventStream instanceof Error) {
308+
throw eventStream;
309+
}
310+
return eventStream;
311+
}

0 commit comments

Comments
 (0)