86
86
87
87
__all__ = [
88
88
"extend_schema" ,
89
+ "extend_schema_impl" ,
89
90
"get_description" ,
90
91
"ASTDefinitionBuilder" ,
91
92
]
@@ -125,6 +126,18 @@ def extend_schema(
125
126
126
127
assert_valid_sdl_extension (document_ast , schema )
127
128
129
+ schema_kwargs = schema .to_kwargs ()
130
+ extended_kwargs = extend_schema_impl (schema_kwargs , document_ast , assume_valid )
131
+ return (
132
+ schema if schema_kwargs is extended_kwargs else GraphQLSchema (** extended_kwargs )
133
+ )
134
+
135
+
136
+ def extend_schema_impl (
137
+ schema_kwargs : Dict [str , Any ], document_ast : DocumentNode , assume_valid = False
138
+ ) -> Dict [str , Any ]:
139
+ # Note: schema_kwargs should become a TypedDict once we require Python 3.8
140
+
128
141
# Collect the type definitions and extensions found in the document.
129
142
type_defs : List [TypeDefinitionNode ] = []
130
143
type_extensions_map : DefaultDict [str , Any ] = defaultdict (list )
@@ -159,7 +172,7 @@ def extend_schema(
159
172
and not schema_extensions
160
173
and not schema_def
161
174
):
162
- return schema
175
+ return schema_kwargs
163
176
164
177
# Below are functions used for producing this schema that have closed over this
165
178
# scope and have access to the schema, cache, and newly defined types.
@@ -359,23 +372,15 @@ def resolve_type(type_name: str) -> GraphQLNamedType:
359
372
ast_builder = ASTDefinitionBuilder (resolve_type )
360
373
361
374
type_map = ast_builder .build_type_map (type_defs , type_extensions_map )
362
- for existing_type_name , existing_type in schema . type_map . items () :
363
- type_map [existing_type_name ] = extend_named_type (existing_type )
375
+ for existing_type in schema_kwargs [ "types" ] :
376
+ type_map [existing_type . name ] = extend_named_type (existing_type )
364
377
365
378
# Get the extended root operation types.
366
- operation_types : Dict [OperationType , GraphQLObjectType ] = {}
367
- if schema .query_type :
368
- operation_types [OperationType .QUERY ] = cast (
369
- GraphQLObjectType , replace_named_type (schema .query_type )
370
- )
371
- if schema .mutation_type :
372
- operation_types [OperationType .MUTATION ] = cast (
373
- GraphQLObjectType , replace_named_type (schema .mutation_type )
374
- )
375
- if schema .subscription_type :
376
- operation_types [OperationType .SUBSCRIPTION ] = cast (
377
- GraphQLObjectType , replace_named_type (schema .subscription_type )
378
- )
379
+ operation_types : Dict [OperationType , GraphQLNamedType ] = {}
380
+ for operation_type in OperationType :
381
+ original_type = schema_kwargs [operation_type .value ]
382
+ if original_type :
383
+ operation_types [operation_type ] = replace_named_type (original_type )
379
384
# Then, incorporate schema definition and all schema extensions.
380
385
if schema_def :
381
386
operation_types .update (ast_builder .get_operation_types ([schema_def ]))
@@ -384,26 +389,27 @@ def resolve_type(type_name: str) -> GraphQLNamedType:
384
389
385
390
# Then produce and return a Schema with these types.
386
391
get_operation = operation_types .get
387
- return GraphQLSchema (
388
- # Note: While this could make early assertions to get the correctly
389
- # typed values, that would throw immediately while type system
390
- # validation with validateSchema() will produce more actionable results.
391
- query = get_operation (OperationType .QUERY ),
392
- mutation = get_operation (OperationType .MUTATION ),
393
- subscription = get_operation (OperationType .SUBSCRIPTION ),
394
- types = type_map .values (),
395
- directives = [replace_directive (directive ) for directive in schema .directives ]
392
+ return {
393
+ "query" : get_operation (OperationType .QUERY ),
394
+ "mutation" : get_operation (OperationType .MUTATION ),
395
+ "subscription" : get_operation (OperationType .SUBSCRIPTION ),
396
+ "types" : type_map .values (),
397
+ "directives" : [
398
+ replace_directive (directive ) for directive in schema_kwargs ["directives" ]
399
+ ]
396
400
+ ast_builder .build_directives (directive_defs ),
397
- ast_node = schema_def or schema .ast_node ,
398
- extension_ast_nodes = (
401
+ "extensions" : None ,
402
+ "ast_node" : schema_def or schema_kwargs ["ast_node" ],
403
+ "extension_ast_nodes" : (
399
404
(
400
- schema . extension_ast_nodes
405
+ schema_kwargs [ " extension_ast_nodes" ]
401
406
or cast (FrozenList [SchemaExtensionNode ], FrozenList ())
402
407
)
403
408
+ schema_extensions
404
409
)
405
410
or None ,
406
- )
411
+ "assume_valid" : assume_valid ,
412
+ }
407
413
408
414
409
415
def default_type_resolver (type_name : str , * _args ) -> NoReturn :
@@ -427,15 +433,14 @@ def get_operation_types(
427
433
# Note: While this could make early assertions to get the correctly
428
434
# typed values below, that would throw immediately while type system
429
435
# validation with validate_schema() will produce more actionable results.
430
- op_types : Dict [OperationType , GraphQLObjectType ] = {}
431
- for node in nodes :
432
- if node .operation_types :
433
- for operation_type in node .operation_types :
434
- type_name = operation_type .type .name .value
435
- op_types [operation_type .operation ] = cast (
436
- GraphQLObjectType , self ._resolve_type (type_name )
437
- )
438
- return op_types
436
+ return {
437
+ operation_type .operation : cast (
438
+ GraphQLObjectType , self ._resolve_type (operation_type .type .name .value )
439
+ )
440
+ for node in nodes
441
+ if node .operation_types
442
+ for operation_type in node .operation_types
443
+ }
439
444
440
445
def get_named_type (self , node : NamedTypeNode ) -> GraphQLNamedType :
441
446
name = node .name .value
0 commit comments