Skip to content

Commit 17536fc

Browse files
committed
address @wyattanderson comments
1 parent 1e3817f commit 17536fc

File tree

3 files changed

+86
-24
lines changed

3 files changed

+86
-24
lines changed

graphene_sqlalchemy/converter.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,9 @@ def inner(fn):
102102

103103
def convert_sqlalchemy_column(column_prop, registry, **field_kwargs):
104104
column = column_prop.columns[0]
105-
if 'type' not in field_kwargs:
106-
field_kwargs['type'] = convert_sqlalchemy_type(getattr(column, "type", None), column, registry)
107-
108-
if 'required' not in field_kwargs:
109-
field_kwargs['required'] = not is_column_nullable(column)
110-
111-
if 'description' not in field_kwargs:
112-
field_kwargs['description'] = get_column_doc(column)
105+
field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry))
106+
field_kwargs.setdefault('required', not is_column_nullable(column))
107+
field_kwargs.setdefault('description', get_column_doc(column))
113108

114109
return Field(
115110
resolver=_get_attr_resolver(column_prop.key),

graphene_sqlalchemy/tests/test_types.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,17 @@ def test_sqlalchemy_override_fields():
114114
def convert_composite_class(composite, registry):
115115
return String()
116116

117-
class ReporterType(SQLAlchemyObjectType):
117+
class ReporterMixin(object):
118+
# columns
119+
first_name = ORMField(required=True)
120+
last_name = ORMField(description='Overridden')
121+
122+
class ReporterType(SQLAlchemyObjectType, ReporterMixin):
118123
class Meta:
119124
model = Reporter
120125
interfaces = (Node,)
121126

122127
# columns
123-
first_name = ORMField(required=True)
124-
last_name = ORMField(description='Overridden')
125128
email = ORMField(deprecation_reason='Overridden')
126129
email_v2 = ORMField(prop_name='email', type=Int)
127130

@@ -151,9 +154,10 @@ class Meta:
151154
use_connection = False
152155

153156
assert list(ReporterType._meta.fields.keys()) == [
154-
# First the ORMField in the order they were defined
157+
# Fields from ReporterMixin
155158
"first_name",
156159
"last_name",
160+
# Fields from ReporterType
157161
"email",
158162
"email_v2",
159163
"column_prop",

graphene_sqlalchemy/types.py

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,59 @@
2727
class ORMField(OrderedType):
2828
def __init__(
2929
self,
30-
type=None,
3130
prop_name=None,
31+
type=None,
32+
required=None,
3233
description=None,
3334
deprecation_reason=None,
34-
required=None,
3535
_creation_counter=None,
3636
**field_kwargs
3737
):
38+
"""
39+
Use this to override fields automatically generated by SQLAlchemyObjectType.
40+
Unless specified, options will default to SQLAlchemyObjectType usual behavior
41+
for the given SQLAlchemy model property.
42+
43+
Usage:
44+
class MyModel(Base):
45+
id = Column(Integer(), primary_key=True)
46+
name = Column(String)
47+
48+
class MyType(SQLAlchemyObjectType):
49+
class Meta:
50+
model = MyModel
51+
52+
id = ORMField(type=graphene.Int)
53+
name = ORMField(required=True)
54+
55+
-> MyType.id will be of type Int (vs ID).
56+
-> MyType.name will be of type NonNull(String) (vs String).
57+
58+
Parameters
59+
- prop_name : str, optional
60+
Name of the SQLAlchemy property used to resolve this field.
61+
Default to the name of the attribute referencing the ORMField.
62+
- type: optional
63+
Default to the type mapping in converter.py.
64+
- description: str, optional
65+
Default to the `doc` attribute of the SQLAlchemy column property.
66+
- required: bool, optional
67+
Default to the opposite of the `nullable` attribute of the SQLAlchemy column property.
68+
- description: str, optional
69+
Same behavior as in graphene.Field. Defaults to None.
70+
- deprecation_reason: str, optional
71+
Same behavior as in graphene.Field. Defaults to None.
72+
- _creation_counter: int, optional
73+
Same behavior as in graphene.Field.
74+
"""
3875
super(ORMField, self).__init__(_creation_counter=_creation_counter)
3976
# The is only useful for documentation and auto-completion
4077
common_kwargs = {
41-
'type': type,
42-
'prop_name': prop_name,
43-
'description': description,
44-
'deprecation_reason': deprecation_reason,
45-
'required': required,
78+
'prop_name': prop_name,
79+
'type': type,
80+
'required': required,
81+
'description': description,
82+
'deprecation_reason': deprecation_reason,
4683
}
4784
common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None}
4885
self.kwargs = field_kwargs
@@ -52,7 +89,27 @@ def __init__(
5289
def construct_fields(
5390
obj_type, model, registry, only_fields, exclude_fields, connection_field_factory
5491
):
92+
"""
93+
Construct all the fields for a SQLAlchemyObjectType.
94+
The main steps are:
95+
- Gather all the relevant attributes from the SQLAlchemy model
96+
- Gather all the ORM fields defined on the type
97+
- Merge in overrides and build up all the fields
98+
99+
Parameters
100+
- obj_type : SQLAlchemyObjectType
101+
- model : the SQLAlchemy model
102+
- registry : Registry
103+
- only_fields : tuple[string]
104+
- exclude_fields : tuple[string]
105+
- connection_field_factory : function
106+
107+
Returns
108+
- fields
109+
An OrderedDict of field names to graphene.Field
110+
"""
55111
inspected_model = sqlalchemyinspect(model)
112+
# Gather all the relevant attributes from the SQLAlchemy model
56113
all_model_props = OrderedDict(
57114
inspected_model.column_attrs.items() +
58115
inspected_model.composites.items() +
@@ -61,31 +118,37 @@ def construct_fields(
61118
inspected_model.relationships.items()
62119
)
63120

121+
# Filter out excluded fields
64122
auto_orm_field_names = []
65123
for prop_name, prop in all_model_props.items():
66124
if (only_fields and prop_name not in only_fields) or (prop_name in exclude_fields):
67125
continue
68126
auto_orm_field_names.append(prop_name)
69127

70-
# TODO Get ORMField fields defined on parent classes
71-
custom_orm_fields_items = []
72-
for attname, value in list(obj_type.__dict__.items()):
73-
if isinstance(value, ORMField):
74-
custom_orm_fields_items.append((attname, value))
128+
# Gather all the ORM fields defined on the type
129+
custom_orm_fields_items = [
130+
(attname, value)
131+
for base in reversed(obj_type.__mro__)
132+
for attname, value in base.__dict__.items()
133+
if isinstance(value, ORMField)
134+
]
75135
custom_orm_fields_items = sorted(custom_orm_fields_items, key=lambda item: item[1])
76136

137+
# Set the prop_name if not set
77138
for orm_field_name, orm_field in custom_orm_fields_items:
78139
prop_name = orm_field.kwargs.get('prop_name', orm_field_name)
79140
if prop_name not in all_model_props:
80141
raise Exception('Cannot map ORMField "{}" to SQLAlchemy model property'.format(orm_field_name))
81142
orm_field.kwargs['prop_name'] = prop_name
82143

144+
# Merge automatic fields with custom ORM fields
83145
orm_fields = OrderedDict(custom_orm_fields_items)
84146
for orm_field_name in auto_orm_field_names:
85147
if orm_field_name in orm_fields:
86148
continue
87149
orm_fields[orm_field_name] = ORMField(prop_name=orm_field_name)
88150

151+
# Build all the field dictionary
89152
fields = OrderedDict()
90153
for orm_field_name, orm_field in orm_fields.items():
91154
prop_name = orm_field.kwargs.pop('prop_name')

0 commit comments

Comments
 (0)