Skip to content

Improve enum compatibility #1153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/types/enums.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ you can add description etc. to your enum without changing the original:

graphene.Enum.from_enum(
AlreadyExistingPyEnum,
description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar')
description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar'
)


Notes
Expand All @@ -76,6 +77,7 @@ In the Python ``Enum`` implementation you can access a member by initing the Enu
.. code:: python

from enum import Enum

class Color(Enum):
RED = 1
GREEN = 2
Expand All @@ -89,6 +91,7 @@ However, in Graphene ``Enum`` you need to call get to have the same effect:
.. code:: python

from graphene import Enum

class Color(Enum):
RED = 1
GREEN = 2
Expand Down
2 changes: 1 addition & 1 deletion graphene/relay/tests/test_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from graphql_relay import to_global_id

from graphql.pyutils import dedent
from graphene.tests.utils import dedent

from ...types import ObjectType, Schema, String
from ..node import Node, is_node
Expand Down
3 changes: 2 additions & 1 deletion graphene/relay/tests/test_node_custom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from graphql import graphql_sync
from graphql.pyutils import dedent

from graphene.tests.utils import dedent

from ...types import Interface, ObjectType, Schema
from ...types.scalars import Int, String
Expand Down
9 changes: 9 additions & 0 deletions graphene/tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from textwrap import dedent as _dedent


def dedent(text: str) -> str:
"""Fix indentation of given text by removing leading spaces and tabs.
Also removes leading newlines and trailing spaces and tabs, but keeps trailing
newlines.
"""
return _dedent(text.lstrip("\n").rstrip(" \t"))
17 changes: 16 additions & 1 deletion graphene/types/definitions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from enum import Enum as PyEnum

from graphql import (
GraphQLEnumType,
GraphQLInputObjectType,
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLScalarType,
GraphQLUnionType,
Undefined,
)


Expand Down Expand Up @@ -36,7 +39,19 @@ class GrapheneScalarType(GrapheneGraphQLType, GraphQLScalarType):


class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType):
pass
def serialize(self, value):
if not isinstance(value, PyEnum):
enum = self.graphene_type._meta.enum
try:
# Try and get enum by value
value = enum(value)
except ValueError:
# Try and get enum by name
try:
value = enum[value]
except KeyError:
return Undefined
return super(GrapheneEnumType, self).serialize(value)


class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType):
Expand Down
2 changes: 1 addition & 1 deletion graphene/types/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def create_enum(graphene_type):
deprecation_reason = graphene_type._meta.deprecation_reason(value)

values[name] = GraphQLEnumValue(
value=value.value,
value=value,
description=description,
deprecation_reason=deprecation_reason,
)
Expand Down
247 changes: 247 additions & 0 deletions graphene/types/tests/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from textwrap import dedent

from ..argument import Argument
from ..enum import Enum, PyEnum
from ..field import Field
from ..inputfield import InputField
from ..inputobjecttype import InputObjectType
from ..mutation import Mutation
from ..scalars import String
from ..schema import ObjectType, Schema


Expand Down Expand Up @@ -224,3 +229,245 @@ class Meta:
"GREEN": RGB1.GREEN,
"BLUE": RGB1.BLUE,
}


def test_enum_types():
from enum import Enum as PyEnum

class Color(PyEnum):
"""Primary colors"""

RED = 1
YELLOW = 2
BLUE = 3

GColor = Enum.from_enum(Color)

class Query(ObjectType):
color = GColor(required=True)

def resolve_color(_, info):
return Color.RED

schema = Schema(query=Query)

assert str(schema) == dedent(
'''\
type Query {
color: Color!
}

"""Primary colors"""
enum Color {
RED
YELLOW
BLUE
}
'''
)


def test_enum_resolver():
from enum import Enum as PyEnum

class Color(PyEnum):
RED = 1
GREEN = 2
BLUE = 3

GColor = Enum.from_enum(Color)

class Query(ObjectType):
color = GColor(required=True)

def resolve_color(_, info):
return Color.RED

schema = Schema(query=Query)

results = schema.execute("query { color }")
assert not results.errors

assert results.data["color"] == Color.RED.name


def test_enum_resolver_compat():
from enum import Enum as PyEnum

class Color(PyEnum):
RED = 1
GREEN = 2
BLUE = 3

GColor = Enum.from_enum(Color)

class Query(ObjectType):
color = GColor(required=True)
color_by_name = GColor(required=True)

def resolve_color(_, info):
return Color.RED.value

def resolve_color_by_name(_, info):
return Color.RED.name

schema = Schema(query=Query)

results = schema.execute(
"""query {
color
colorByName
}"""
)
assert not results.errors

assert results.data["color"] == Color.RED.name
assert results.data["colorByName"] == Color.RED.name


def test_enum_resolver_invalid():
from enum import Enum as PyEnum

class Color(PyEnum):
RED = 1
GREEN = 2
BLUE = 3

GColor = Enum.from_enum(Color)

class Query(ObjectType):
color = GColor(required=True)

def resolve_color(_, info):
return "BLACK"

schema = Schema(query=Query)

results = schema.execute("query { color }")
assert results.errors
assert (
results.errors[0].message
== "Expected a value of type 'Color' but received: 'BLACK'"
)


def test_field_enum_argument():
class Color(Enum):
RED = 1
GREEN = 2
BLUE = 3

class Brick(ObjectType):
color = Color(required=True)

color_filter = None

class Query(ObjectType):
bricks_by_color = Field(Brick, color=Color(required=True))

def resolve_bricks_by_color(_, info, color):
nonlocal color_filter
color_filter = color
return Brick(color=color)

schema = Schema(query=Query)

results = schema.execute(
"""
query {
bricksByColor(color: RED) {
color
}
}
"""
)
assert not results.errors
assert results.data == {"bricksByColor": {"color": "RED"}}
assert color_filter == Color.RED


def test_mutation_enum_input():
class RGB(Enum):
"""Available colors"""

RED = 1
GREEN = 2
BLUE = 3

color_input = None

class CreatePaint(Mutation):
class Arguments:
color = RGB(required=True)

color = RGB(required=True)

def mutate(_, info, color):
nonlocal color_input
color_input = color
return CreatePaint(color=color)

class MyMutation(ObjectType):
create_paint = CreatePaint.Field()

class Query(ObjectType):
a = String()

schema = Schema(query=Query, mutation=MyMutation)
result = schema.execute(
""" mutation MyMutation {
createPaint(color: RED) {
color
}
}
"""
)
assert not result.errors
assert result.data == {"createPaint": {"color": "RED"}}

assert color_input == RGB.RED


def test_mutation_enum_input_type():
class RGB(Enum):
"""Available colors"""

RED = 1
GREEN = 2
BLUE = 3

class ColorInput(InputObjectType):
color = RGB(required=True)

color_input_value = None

class CreatePaint(Mutation):
class Arguments:
color_input = ColorInput(required=True)

color = RGB(required=True)

def mutate(_, info, color_input):
nonlocal color_input_value
color_input_value = color_input.color
return CreatePaint(color=color_input.color)

class MyMutation(ObjectType):
create_paint = CreatePaint.Field()

class Query(ObjectType):
a = String()

schema = Schema(query=Query, mutation=MyMutation)
result = schema.execute(
""" mutation MyMutation {
createPaint(colorInput: { color: RED }) {
color
}
}
""",
)
assert not result.errors
assert result.data == {"createPaint": {"color": "RED"}}

assert color_input_value == RGB.RED
4 changes: 2 additions & 2 deletions graphene/types/tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from graphql.type import GraphQLObjectType, GraphQLSchema
from pytest import raises

from graphql.type import GraphQLObjectType, GraphQLSchema
from graphql.pyutils import dedent
from graphene.tests.utils import dedent

from ..field import Field
from ..objecttype import ObjectType
Expand Down
7 changes: 7 additions & 0 deletions graphene/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,10 @@ def get_type(_type):
if inspect.isfunction(_type) or isinstance(_type, partial):
return _type()
return _type


def get_underlying_type(_type):
"""Get the underlying type even if it is wrapped in structures like NonNull"""
while hasattr(_type, "of_type"):
_type = _type.of_type
return _type
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def run_tests(self):
keywords="api graphql protocol rest relay graphene",
packages=find_packages(exclude=["examples*"]),
install_requires=[
"graphql-core>=3.1.1,<4",
"graphql-core>=3.1.2,<4",
"graphql-relay>=3.0,<4",
"aniso8601>=8,<9",
],
Expand Down