Skip to content

align return types from execution and subscription #3620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 40 additions & 22 deletions src/execution/__tests__/subscribe-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { expectJSON } from '../../__testUtils__/expectJSON';
import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick';

import { isAsyncIterable } from '../../jsutils/isAsyncIterable';
import { isPromise } from '../../jsutils/isPromise';
import type { PromiseOrValue } from '../../jsutils/PromiseOrValue';

import { parse } from '../../language/parser';

Expand Down Expand Up @@ -135,9 +137,6 @@ async function expectPromise(promise: Promise<unknown>) {
}

return {
toReject() {
expect(caughtError).to.be.an.instanceOf(Error);
},
toRejectWith(message: string) {
expect(caughtError).to.be.an.instanceOf(Error);
expect(caughtError).to.have.property('message', message);
Expand All @@ -152,9 +151,9 @@ const DummyQueryType = new GraphQLObjectType({
},
});

async function subscribeWithBadFn(
function subscribeWithBadFn(
subscribeFn: () => unknown,
): Promise<ExecutionResult> {
): PromiseOrValue<ExecutionResult> {
const schema = new GraphQLSchema({
query: DummyQueryType,
subscription: new GraphQLObjectType({
Expand All @@ -165,13 +164,28 @@ async function subscribeWithBadFn(
}),
});
const document = parse('subscription { foo }');
const result = await subscribe({ schema, document });

assert(!isAsyncIterable(result));
expectJSON(await createSourceEventStream(schema, document)).toDeepEqual(
result,
);
return result;
const subscribeResult = subscribe({ schema, document });
const streamResult = createSourceEventStream(schema, document);

if (isPromise(subscribeResult)) {
assert(isPromise(streamResult));
return Promise.all([subscribeResult, streamResult]).then((resolved) =>
expectEquivalentStreamErrors(resolved[0], resolved[1]),
);
}

assert(!isPromise(streamResult));
return expectEquivalentStreamErrors(subscribeResult, streamResult);
}

function expectEquivalentStreamErrors(
subscribeResult: ExecutionResult | AsyncGenerator<ExecutionResult>,
createSourceEventStreamResult: ExecutionResult | AsyncIterable<unknown>,
): ExecutionResult {
assert(!isAsyncIterable(subscribeResult));
expectJSON(createSourceEventStreamResult).toDeepEqual(subscribeResult);
return subscribeResult;
}

/* eslint-disable @typescript-eslint/require-await */
Expand Down Expand Up @@ -379,24 +393,22 @@ describe('Subscription Initialization Phase', () => {
});

// @ts-expect-error (schema must not be null)
(await expectPromise(subscribe({ schema: null, document }))).toRejectWith(
expect(() => subscribe({ schema: null, document })).to.throw(
'Expected null to be a GraphQL schema.',
);

// @ts-expect-error
(await expectPromise(subscribe({ document }))).toRejectWith(
expect(() => subscribe({ document })).to.throw(
'Expected undefined to be a GraphQL schema.',
);

// @ts-expect-error (document must not be null)
(await expectPromise(subscribe({ schema, document: null }))).toRejectWith(
expect(() => subscribe({ schema, document: null })).to.throw(
'Must provide document.',
);

// @ts-expect-error
(await expectPromise(subscribe({ schema }))).toRejectWith(
'Must provide document.',
);
expect(() => subscribe({ schema })).to.throw('Must provide document.');
});

it('resolves to an error if schema does not support subscriptions', async () => {
Expand Down Expand Up @@ -450,11 +462,11 @@ describe('Subscription Initialization Phase', () => {
});

// @ts-expect-error
(await expectPromise(subscribe({ schema, document: {} }))).toReject();
expect(() => subscribe({ schema, document: {} })).to.throw();
});

it('throws an error if subscribe does not return an iterator', async () => {
expectJSON(await subscribeWithBadFn(() => 'test')).toDeepEqual({
const expectedResult = {
errors: [
{
message:
Expand All @@ -463,7 +475,13 @@ describe('Subscription Initialization Phase', () => {
path: ['foo'],
},
],
});
};

expectJSON(subscribeWithBadFn(() => 'test')).toDeepEqual(expectedResult);

const result = subscribeWithBadFn(() => Promise.resolve('test'));
assert(isPromise(result));
expectJSON(await result).toDeepEqual(expectedResult);
});

it('resolves to an error for subscription resolver errors', async () => {
Expand All @@ -479,12 +497,12 @@ describe('Subscription Initialization Phase', () => {

expectJSON(
// Returning an error
await subscribeWithBadFn(() => new Error('test error')),
subscribeWithBadFn(() => new Error('test error')),
).toDeepEqual(expectedResult);

expectJSON(
// Throwing an error
await subscribeWithBadFn(() => {
subscribeWithBadFn(() => {
throw new Error('test error');
}),
).toDeepEqual(expectedResult);
Expand Down
101 changes: 76 additions & 25 deletions src/execution/subscribe.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { inspect } from '../jsutils/inspect';
import { isAsyncIterable } from '../jsutils/isAsyncIterable';
import { isPromise } from '../jsutils/isPromise';
import type { Maybe } from '../jsutils/Maybe';
import { addPath, pathToArray } from '../jsutils/Path';
import type { PromiseOrValue } from '../jsutils/PromiseOrValue';

import { GraphQLError } from '../error/GraphQLError';
import { locatedError } from '../error/locatedError';
Expand Down Expand Up @@ -47,9 +49,11 @@ import { getArgumentValues } from './values';
*
* Accepts either an object with named arguments, or individual arguments.
*/
export async function subscribe(
export function subscribe(
args: ExecutionArgs,
): Promise<AsyncGenerator<ExecutionResult, void, void> | ExecutionResult> {
): PromiseOrValue<
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
> {
const {
schema,
document,
Expand All @@ -61,7 +65,7 @@ export async function subscribe(
subscribeFieldResolver,
} = args;

const resultOrStream = await createSourceEventStream(
const resultOrStream = createSourceEventStream(
schema,
document,
rootValue,
Expand All @@ -71,6 +75,42 @@ export async function subscribe(
subscribeFieldResolver,
);

if (isPromise(resultOrStream)) {
return resultOrStream.then((resolvedResultOrStream) =>
mapSourceToResponse(
schema,
document,
resolvedResultOrStream,
contextValue,
variableValues,
operationName,
fieldResolver,
),
);
}

return mapSourceToResponse(
schema,
document,
resultOrStream,
contextValue,
variableValues,
operationName,
fieldResolver,
);
}

function mapSourceToResponse(
schema: GraphQLSchema,
document: DocumentNode,
resultOrStream: ExecutionResult | AsyncIterable<unknown>,
contextValue?: unknown,
variableValues?: Maybe<{ readonly [variable: string]: unknown }>,
operationName?: Maybe<string>,
fieldResolver?: Maybe<GraphQLFieldResolver<any, any>>,
): PromiseOrValue<
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
> {
if (!isAsyncIterable(resultOrStream)) {
return resultOrStream;
}
Expand All @@ -81,7 +121,7 @@ export async function subscribe(
// the GraphQL specification. The `execute` function provides the
// "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
// "ExecuteQuery" algorithm, for which `execute` is also used.
const mapSourceToResponse = (payload: unknown) =>
return mapAsyncIterator(resultOrStream, (payload: unknown) =>
execute({
schema,
document,
Expand All @@ -90,10 +130,8 @@ export async function subscribe(
variableValues,
operationName,
fieldResolver,
});

// Map every source value to a ExecutionResult value as described above.
return mapAsyncIterator(resultOrStream, mapSourceToResponse);
}),
);
}

/**
Expand Down Expand Up @@ -124,15 +162,15 @@ export async function subscribe(
* or otherwise separating these two steps. For more on this, see the
* "Supporting Subscriptions at Scale" information in the GraphQL specification.
*/
export async function createSourceEventStream(
export function createSourceEventStream(
schema: GraphQLSchema,
document: DocumentNode,
rootValue?: unknown,
contextValue?: unknown,
variableValues?: Maybe<{ readonly [variable: string]: unknown }>,
operationName?: Maybe<string>,
subscribeFieldResolver?: Maybe<GraphQLFieldResolver<any, any>>,
): Promise<AsyncIterable<unknown> | ExecutionResult> {
): PromiseOrValue<AsyncIterable<unknown> | ExecutionResult> {
// If arguments are missing or incorrectly typed, this is an internal
// developer mistake which should throw an early error.
assertValidExecutionArguments(schema, document, variableValues);
Expand All @@ -155,17 +193,20 @@ export async function createSourceEventStream(
}

try {
const eventStream = await executeSubscription(exeContext);
const eventStream = executeSubscription(exeContext);
if (isPromise(eventStream)) {
return eventStream.then(undefined, (error) => ({ errors: [error] }));
}

return eventStream;
} catch (error) {
return { errors: [error] };
}
}

async function executeSubscription(
function executeSubscription(
exeContext: ExecutionContext,
): Promise<AsyncIterable<unknown>> {
): PromiseOrValue<AsyncIterable<unknown>> {
const { schema, fragments, operation, variableValues, rootValue } =
exeContext;

Expand Down Expand Up @@ -220,22 +261,32 @@ async function executeSubscription(
// Call the `subscribe()` resolver or the default resolver to produce an
// AsyncIterable yielding raw payloads.
const resolveFn = fieldDef.subscribe ?? exeContext.subscribeFieldResolver;
const eventStream = await resolveFn(rootValue, args, contextValue, info);
const eventStream = resolveFn(rootValue, args, contextValue, info);

if (eventStream instanceof Error) {
throw eventStream;
if (isPromise(eventStream)) {
return eventStream.then(assertEventStream).then(undefined, (error) => {
throw locatedError(error, fieldNodes, pathToArray(path));
});
}

// Assert field returned an event stream, otherwise yield an error.
if (!isAsyncIterable(eventStream)) {
throw new GraphQLError(
'Subscription field must return Async Iterable. ' +
`Received: ${inspect(eventStream)}.`,
);
}

return eventStream;
return assertEventStream(eventStream);
} catch (error) {
throw locatedError(error, fieldNodes, pathToArray(path));
}
}

function assertEventStream(eventStream: unknown): AsyncIterable<unknown> {
if (eventStream instanceof Error) {
throw eventStream;
}

// Assert field returned an event stream, otherwise yield an error.
if (!isAsyncIterable(eventStream)) {
throw new GraphQLError(
'Subscription field must return Async Iterable. ' +
`Received: ${inspect(eventStream)}.`,
);
}

return eventStream;
}