Skip to content

Commit 81fff0f

Browse files
authored
Improve enum compatibility (graphql-python#1153)
* Improve enum compatibility by supporting return enum as well as values and names * Handle invalid enum values * Rough implementation of compat middleware * Move enum middleware into compat module * Fix tests * Tweak enum examples * Add some tests for the middleware * Clean up tests * Add missing imports * Remove enum compat middleware * Use custom dedent function and pin graphql-core to >3.1.2
1 parent d042d5e commit 81fff0f

File tree

10 files changed

+290
-8
lines changed

10 files changed

+290
-8
lines changed

docs/types/enums.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ you can add description etc. to your enum without changing the original:
6161
6262
graphene.Enum.from_enum(
6363
AlreadyExistingPyEnum,
64-
description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar')
64+
description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar'
65+
)
6566
6667
6768
Notes
@@ -76,6 +77,7 @@ In the Python ``Enum`` implementation you can access a member by initing the Enu
7677
.. code:: python
7778
7879
from enum import Enum
80+
7981
class Color(Enum):
8082
RED = 1
8183
GREEN = 2
@@ -89,6 +91,7 @@ However, in Graphene ``Enum`` you need to call get to have the same effect:
8991
.. code:: python
9092
9193
from graphene import Enum
94+
9295
class Color(Enum):
9396
RED = 1
9497
GREEN = 2

graphene/relay/tests/test_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
from graphql_relay import to_global_id
33

4-
from graphql.pyutils import dedent
4+
from graphene.tests.utils import dedent
55

66
from ...types import ObjectType, Schema, String
77
from ..node import Node, is_node

graphene/relay/tests/test_node_custom.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from graphql import graphql_sync
2-
from graphql.pyutils import dedent
2+
3+
from graphene.tests.utils import dedent
34

45
from ...types import Interface, ObjectType, Schema
56
from ...types.scalars import Int, String

graphene/tests/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from textwrap import dedent as _dedent
2+
3+
4+
def dedent(text: str) -> str:
5+
"""Fix indentation of given text by removing leading spaces and tabs.
6+
Also removes leading newlines and trailing spaces and tabs, but keeps trailing
7+
newlines.
8+
"""
9+
return _dedent(text.lstrip("\n").rstrip(" \t"))

graphene/types/definitions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
from enum import Enum as PyEnum
2+
13
from graphql import (
24
GraphQLEnumType,
35
GraphQLInputObjectType,
46
GraphQLInterfaceType,
57
GraphQLObjectType,
68
GraphQLScalarType,
79
GraphQLUnionType,
10+
Undefined,
811
)
912

1013

@@ -36,7 +39,19 @@ class GrapheneScalarType(GrapheneGraphQLType, GraphQLScalarType):
3639

3740

3841
class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType):
39-
pass
42+
def serialize(self, value):
43+
if not isinstance(value, PyEnum):
44+
enum = self.graphene_type._meta.enum
45+
try:
46+
# Try and get enum by value
47+
value = enum(value)
48+
except ValueError:
49+
# Try and get enum by name
50+
try:
51+
value = enum[value]
52+
except KeyError:
53+
return Undefined
54+
return super(GrapheneEnumType, self).serialize(value)
4055

4156

4257
class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType):

graphene/types/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def create_enum(graphene_type):
172172
deprecation_reason = graphene_type._meta.deprecation_reason(value)
173173

174174
values[name] = GraphQLEnumValue(
175-
value=value.value,
175+
value=value,
176176
description=description,
177177
deprecation_reason=deprecation_reason,
178178
)

graphene/types/tests/test_enum.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
from textwrap import dedent
2+
13
from ..argument import Argument
24
from ..enum import Enum, PyEnum
35
from ..field import Field
46
from ..inputfield import InputField
7+
from ..inputobjecttype import InputObjectType
8+
from ..mutation import Mutation
9+
from ..scalars import String
510
from ..schema import ObjectType, Schema
611

712

@@ -224,3 +229,245 @@ class Meta:
224229
"GREEN": RGB1.GREEN,
225230
"BLUE": RGB1.BLUE,
226231
}
232+
233+
234+
def test_enum_types():
235+
from enum import Enum as PyEnum
236+
237+
class Color(PyEnum):
238+
"""Primary colors"""
239+
240+
RED = 1
241+
YELLOW = 2
242+
BLUE = 3
243+
244+
GColor = Enum.from_enum(Color)
245+
246+
class Query(ObjectType):
247+
color = GColor(required=True)
248+
249+
def resolve_color(_, info):
250+
return Color.RED
251+
252+
schema = Schema(query=Query)
253+
254+
assert str(schema) == dedent(
255+
'''\
256+
type Query {
257+
color: Color!
258+
}
259+
260+
"""Primary colors"""
261+
enum Color {
262+
RED
263+
YELLOW
264+
BLUE
265+
}
266+
'''
267+
)
268+
269+
270+
def test_enum_resolver():
271+
from enum import Enum as PyEnum
272+
273+
class Color(PyEnum):
274+
RED = 1
275+
GREEN = 2
276+
BLUE = 3
277+
278+
GColor = Enum.from_enum(Color)
279+
280+
class Query(ObjectType):
281+
color = GColor(required=True)
282+
283+
def resolve_color(_, info):
284+
return Color.RED
285+
286+
schema = Schema(query=Query)
287+
288+
results = schema.execute("query { color }")
289+
assert not results.errors
290+
291+
assert results.data["color"] == Color.RED.name
292+
293+
294+
def test_enum_resolver_compat():
295+
from enum import Enum as PyEnum
296+
297+
class Color(PyEnum):
298+
RED = 1
299+
GREEN = 2
300+
BLUE = 3
301+
302+
GColor = Enum.from_enum(Color)
303+
304+
class Query(ObjectType):
305+
color = GColor(required=True)
306+
color_by_name = GColor(required=True)
307+
308+
def resolve_color(_, info):
309+
return Color.RED.value
310+
311+
def resolve_color_by_name(_, info):
312+
return Color.RED.name
313+
314+
schema = Schema(query=Query)
315+
316+
results = schema.execute(
317+
"""query {
318+
color
319+
colorByName
320+
}"""
321+
)
322+
assert not results.errors
323+
324+
assert results.data["color"] == Color.RED.name
325+
assert results.data["colorByName"] == Color.RED.name
326+
327+
328+
def test_enum_resolver_invalid():
329+
from enum import Enum as PyEnum
330+
331+
class Color(PyEnum):
332+
RED = 1
333+
GREEN = 2
334+
BLUE = 3
335+
336+
GColor = Enum.from_enum(Color)
337+
338+
class Query(ObjectType):
339+
color = GColor(required=True)
340+
341+
def resolve_color(_, info):
342+
return "BLACK"
343+
344+
schema = Schema(query=Query)
345+
346+
results = schema.execute("query { color }")
347+
assert results.errors
348+
assert (
349+
results.errors[0].message
350+
== "Expected a value of type 'Color' but received: 'BLACK'"
351+
)
352+
353+
354+
def test_field_enum_argument():
355+
class Color(Enum):
356+
RED = 1
357+
GREEN = 2
358+
BLUE = 3
359+
360+
class Brick(ObjectType):
361+
color = Color(required=True)
362+
363+
color_filter = None
364+
365+
class Query(ObjectType):
366+
bricks_by_color = Field(Brick, color=Color(required=True))
367+
368+
def resolve_bricks_by_color(_, info, color):
369+
nonlocal color_filter
370+
color_filter = color
371+
return Brick(color=color)
372+
373+
schema = Schema(query=Query)
374+
375+
results = schema.execute(
376+
"""
377+
query {
378+
bricksByColor(color: RED) {
379+
color
380+
}
381+
}
382+
"""
383+
)
384+
assert not results.errors
385+
assert results.data == {"bricksByColor": {"color": "RED"}}
386+
assert color_filter == Color.RED
387+
388+
389+
def test_mutation_enum_input():
390+
class RGB(Enum):
391+
"""Available colors"""
392+
393+
RED = 1
394+
GREEN = 2
395+
BLUE = 3
396+
397+
color_input = None
398+
399+
class CreatePaint(Mutation):
400+
class Arguments:
401+
color = RGB(required=True)
402+
403+
color = RGB(required=True)
404+
405+
def mutate(_, info, color):
406+
nonlocal color_input
407+
color_input = color
408+
return CreatePaint(color=color)
409+
410+
class MyMutation(ObjectType):
411+
create_paint = CreatePaint.Field()
412+
413+
class Query(ObjectType):
414+
a = String()
415+
416+
schema = Schema(query=Query, mutation=MyMutation)
417+
result = schema.execute(
418+
""" mutation MyMutation {
419+
createPaint(color: RED) {
420+
color
421+
}
422+
}
423+
"""
424+
)
425+
assert not result.errors
426+
assert result.data == {"createPaint": {"color": "RED"}}
427+
428+
assert color_input == RGB.RED
429+
430+
431+
def test_mutation_enum_input_type():
432+
class RGB(Enum):
433+
"""Available colors"""
434+
435+
RED = 1
436+
GREEN = 2
437+
BLUE = 3
438+
439+
class ColorInput(InputObjectType):
440+
color = RGB(required=True)
441+
442+
color_input_value = None
443+
444+
class CreatePaint(Mutation):
445+
class Arguments:
446+
color_input = ColorInput(required=True)
447+
448+
color = RGB(required=True)
449+
450+
def mutate(_, info, color_input):
451+
nonlocal color_input_value
452+
color_input_value = color_input.color
453+
return CreatePaint(color=color_input.color)
454+
455+
class MyMutation(ObjectType):
456+
create_paint = CreatePaint.Field()
457+
458+
class Query(ObjectType):
459+
a = String()
460+
461+
schema = Schema(query=Query, mutation=MyMutation)
462+
result = schema.execute(
463+
""" mutation MyMutation {
464+
createPaint(colorInput: { color: RED }) {
465+
color
466+
}
467+
}
468+
""",
469+
)
470+
assert not result.errors
471+
assert result.data == {"createPaint": {"color": "RED"}}
472+
473+
assert color_input_value == RGB.RED

graphene/types/tests/test_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from graphql.type import GraphQLObjectType, GraphQLSchema
12
from pytest import raises
23

3-
from graphql.type import GraphQLObjectType, GraphQLSchema
4-
from graphql.pyutils import dedent
4+
from graphene.tests.utils import dedent
55

66
from ..field import Field
77
from ..objecttype import ObjectType

graphene/types/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,10 @@ def get_type(_type):
4141
if inspect.isfunction(_type) or isinstance(_type, partial):
4242
return _type()
4343
return _type
44+
45+
46+
def get_underlying_type(_type):
47+
"""Get the underlying type even if it is wrapped in structures like NonNull"""
48+
while hasattr(_type, "of_type"):
49+
_type = _type.of_type
50+
return _type

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def run_tests(self):
8282
keywords="api graphql protocol rest relay graphene",
8383
packages=find_packages(exclude=["examples*"]),
8484
install_requires=[
85-
"graphql-core>=3.1.1,<4",
85+
"graphql-core>=3.1.2,<4",
8686
"graphql-relay>=3.0,<4",
8787
"aniso8601>=8,<9",
8888
],

0 commit comments

Comments
 (0)