Skip to content

Commit 7e79d70

Browse files
committed
Improve enum compatibility by supporting return enum as well as values and names
1 parent 88f79b2 commit 7e79d70

File tree

3 files changed

+156
-2
lines changed

3 files changed

+156
-2
lines changed

graphene/types/definitions.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from enum import Enum as PyEnum
2+
13
from graphql import (
24
GraphQLEnumType,
35
GraphQLInputObjectType,
@@ -36,7 +38,16 @@ class GrapheneScalarType(GrapheneGraphQLType, GraphQLScalarType):
3638

3739

3840
class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType):
39-
pass
41+
def serialize(self, value):
42+
if not isinstance(value, PyEnum):
43+
enum = self.graphene_type._meta.enum
44+
try:
45+
# Try and get enum by value
46+
value = enum(value)
47+
except ValueError:
48+
# Try ang get enum by name
49+
value = enum[value]
50+
return super(GrapheneEnumType, self).serialize(value)
4051

4152

4253
class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType):

graphene/types/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def create_enum(graphene_type):
174174
deprecation_reason = graphene_type._meta.deprecation_reason(value)
175175

176176
values[name] = GraphQLEnumValue(
177-
value=value.value,
177+
value=value,
178178
description=description,
179179
deprecation_reason=deprecation_reason,
180180
)

graphene/types/tests/test_enum.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from textwrap import dedent
2+
13
from ..argument import Argument
24
from ..enum import Enum, PyEnum
35
from ..field import Field
46
from ..inputfield import InputField
57
from ..schema import ObjectType, Schema
8+
from ..mutation import Mutation
69

710

811
def test_enum_construction():
@@ -224,3 +227,143 @@ class Meta:
224227
"GREEN": RGB1.GREEN,
225228
"BLUE": RGB1.BLUE,
226229
}
230+
231+
232+
def test_enum_types():
233+
from enum import Enum as PyEnum
234+
235+
class Color(PyEnum):
236+
RED = 1
237+
GREEN = 2
238+
BLUE = 3
239+
240+
GColor = Enum.from_enum(Color)
241+
242+
class Query(ObjectType):
243+
color = GColor(required=True)
244+
245+
def resolve_color(_, info):
246+
return Color.RED.value
247+
248+
schema = Schema(query=Query)
249+
250+
assert str(schema) == dedent(
251+
'''\
252+
"""An enumeration."""
253+
enum Color {
254+
RED
255+
GREEN
256+
BLUE
257+
}
258+
259+
type Query {
260+
color: Color!
261+
}
262+
'''
263+
)
264+
265+
266+
def test_enum_resolver():
267+
from enum import Enum as PyEnum
268+
269+
class Color(PyEnum):
270+
RED = 1
271+
GREEN = 2
272+
BLUE = 3
273+
274+
GColor = Enum.from_enum(Color)
275+
276+
class Query(ObjectType):
277+
color = GColor(required=True)
278+
279+
def resolve_color(_, info):
280+
return Color.RED
281+
282+
schema = Schema(query=Query)
283+
284+
results = schema.execute("query { color }")
285+
assert not results.errors
286+
287+
assert results.data["color"] == Color.RED.name
288+
289+
290+
def test_enum_resolver_compat():
291+
from enum import Enum as PyEnum
292+
293+
class Color(PyEnum):
294+
RED = 1
295+
GREEN = 2
296+
BLUE = 3
297+
298+
GColor = Enum.from_enum(Color)
299+
300+
class Query(ObjectType):
301+
color = GColor(required=True)
302+
color_by_name = GColor(required=True)
303+
304+
def resolve_color(_, info):
305+
return Color.RED.value
306+
307+
def resolve_color_by_name(_, info):
308+
return Color.RED.name
309+
310+
schema = Schema(query=Query)
311+
312+
results = schema.execute(
313+
"""query {
314+
color
315+
colorByName
316+
}"""
317+
)
318+
assert not results.errors
319+
320+
assert results.data["color"] == Color.RED.name
321+
assert results.data["colorByName"] == Color.RED.name
322+
323+
324+
def test_enum_mutation():
325+
from enum import Enum as PyEnum
326+
327+
class Color(PyEnum):
328+
RED = 1
329+
GREEN = 2
330+
BLUE = 3
331+
332+
GColor = Enum.from_enum(Color)
333+
334+
my_fav_color = None
335+
336+
class Query(ObjectType):
337+
fav_color = GColor(required=True)
338+
339+
def resolve_fav_color(_, info):
340+
return my_fav_color
341+
342+
class SetFavColor(Mutation):
343+
class Arguments:
344+
fav_color = Argument(GColor, required=True)
345+
346+
Output = Query
347+
348+
def mutate(self, info, fav_color):
349+
nonlocal my_fav_color
350+
my_fav_color = fav_color
351+
return Query()
352+
353+
class MyMutations(ObjectType):
354+
set_fav_color = SetFavColor.Field()
355+
356+
schema = Schema(query=Query, mutation=MyMutations)
357+
358+
results = schema.execute(
359+
"""mutation {
360+
setFavColor(favColor: RED) {
361+
favColor
362+
}
363+
}"""
364+
)
365+
assert not results.errors
366+
367+
assert my_fav_color == Color.RED
368+
369+
assert results.data["setFavColor"]["favColor"] == Color.RED.name

0 commit comments

Comments
 (0)