diff --git a/gql/dsl.py b/gql/dsl.py index 6407f267..80361240 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -84,7 +84,7 @@ def select(self, *fields): added_selections = selections(*fields) if selection_set: selection_set.selections = FrozenList( - selection_set.selections + added_selections + selection_set.selections + list(added_selections) ) else: self.ast_field.selection_set = SelectionSetNode( diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index 54d71613..18cd2982 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -1,5 +1,5 @@ from inspect import isawaitable -from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Coroutine, cast +from typing import AsyncGenerator, Awaitable, cast from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe @@ -50,27 +50,16 @@ async def execute( async def subscribe( self, document: DocumentNode, *args, **kwargs, ) -> AsyncGenerator[ExecutionResult, None]: - """Send a query and receive the results using an async generator - - The query can be a graphql query, mutation or subscription + """Send a subscription and receive the results using an async generator The results are sent as an ExecutionResult object """ - subscribe_result = subscribe(self.schema, document, *args, **kwargs) + subscribe_result = await subscribe(self.schema, document, *args, **kwargs) if isinstance(subscribe_result, ExecutionResult): - yield ExecutionResult + yield subscribe_result else: - # if we don't get an ExecutionResult, then we should receive - # a Coroutine returning an AsyncIterator[ExecutionResult] - - subscribe_coro = cast( - Coroutine[Any, Any, AsyncIterator[ExecutionResult]], subscribe_result - ) - - subscribe_generator = await subscribe_coro - - async for result in subscribe_generator: + async for result in subscribe_result: yield result diff --git a/gql/transport/transport.py b/gql/transport/transport.py index 05bb063c..56d882f4 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -20,7 +20,7 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: def connect(self): """Establish a session with the transport. """ - pass + pass # pragma: no cover def close(self): """Close the transport diff --git a/tests/starwars/fixtures.py b/tests/starwars/fixtures.py index 409ca1bd..7bc31037 100644 --- a/tests/starwars/fixtures.py +++ b/tests/starwars/fixtures.py @@ -1,3 +1,4 @@ +import asyncio from collections import namedtuple Human = namedtuple("Human", "id name friends appearsIn homePlanet") @@ -94,6 +95,11 @@ def getHero(episode): return artoo +async def getHeroAsync(episode): + await asyncio.sleep(0.001) + return getHero(episode) + + def getHuman(id): return humanData.get(id) diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index e23822f3..95320ffe 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -24,7 +24,7 @@ getCharacters, getDroid, getFriends, - getHero, + getHeroAsync, getHuman, reviews, ) @@ -146,7 +146,7 @@ type_=episodeEnum, # type: ignore ) }, - resolve=lambda root, info, **args: getHero(args.get("episode")), + resolve=lambda root, info, **args: getHeroAsync(args.get("episode")), ), "human": GraphQLField( humanType, diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 7501e84f..9047f352 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -53,6 +53,18 @@ def test_hero_name_and_friends_query(ds): assert query == str(query_dsl) +def test_hero_id_and_name(ds): + query = """ +hero { + id + name +} + """.strip() + query_dsl = ds.Query.hero.select(ds.Character.id) + query_dsl = query_dsl.select(ds.Character.name) + assert query == str(query_dsl) + + def test_nested_query(ds): query = """ hero { diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index 7a76b3bf..3753ab2f 100644 --- a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -1,5 +1,5 @@ import pytest -from graphql import subscribe +from graphql import ExecutionResult, GraphQLError, subscribe from gql import Client, gql @@ -57,3 +57,38 @@ async def test_subscription_support_using_client(): ] assert results == expected + + +subscription_invalid_str = """ + subscription ListenEpisodeReviews($ep: Episode!) { + qsdfqsdfqsdf + } +""" + + +@pytest.mark.asyncio +async def test_subscription_support_using_client_invalid_field(): + + subs = gql(subscription_invalid_str) + + params = {"ep": "JEDI"} + + async with Client(schema=StarWarsSchema) as session: + + # We subscribe directly from the transport to avoid local validation + results = [ + result + async for result in session.transport.subscribe( + subs, variable_values=params + ) + ] + + assert len(results) == 1 + result = results[0] + assert isinstance(result, ExecutionResult) + assert result.data is None + assert isinstance(result.errors, list) + assert len(result.errors) == 1 + error = result.errors[0] + assert isinstance(error, GraphQLError) + assert error.message == "The subscription field 'qsdfqsdfqsdf' is not defined."