|
1 | 1 | import logging
|
2 | 2 | from collections.abc import Iterable
|
3 |
| -from typing import Any, Callable, List, Optional, Union, cast |
| 3 | +from typing import List, Optional, Union |
4 | 4 |
|
5 | 5 | from graphql import (
|
6 | 6 | ArgumentNode,
|
7 | 7 | DocumentNode,
|
8 |
| - EnumValueNode, |
9 | 8 | FieldNode,
|
10 |
| - GraphQLEnumType, |
| 9 | + GraphQLArgument, |
11 | 10 | GraphQLField,
|
12 |
| - GraphQLInputObjectType, |
13 |
| - GraphQLInputType, |
14 | 11 | GraphQLInterfaceType,
|
15 |
| - GraphQLList, |
16 | 12 | GraphQLNamedType,
|
17 |
| - GraphQLNonNull, |
18 | 13 | GraphQLObjectType,
|
19 |
| - GraphQLScalarType, |
20 | 14 | GraphQLSchema,
|
21 |
| - ListValueNode, |
22 | 15 | NameNode,
|
23 |
| - ObjectFieldNode, |
24 |
| - ObjectValueNode, |
25 | 16 | OperationDefinitionNode,
|
26 | 17 | OperationType,
|
27 | 18 | SelectionSetNode,
|
28 |
| - ValueNode, |
29 | 19 | ast_from_value,
|
30 | 20 | print_ast,
|
31 | 21 | )
|
|
35 | 25 |
|
36 | 26 | log = logging.getLogger(__name__)
|
37 | 27 |
|
38 |
| -Serializer = Callable[[Any], Optional[ValueNode]] |
39 |
| - |
40 | 28 |
|
41 | 29 | def dsl_gql(*fields: "DSLField") -> DocumentNode:
|
42 | 30 | """Given arguments of type :class:`DSLField` containing GraphQL requests,
|
@@ -311,67 +299,36 @@ def args(self, **kwargs) -> "DSLField":
|
311 | 299 | for this field.
|
312 | 300 | """
|
313 | 301 |
|
314 |
| - added_args = [] |
315 |
| - for name, value in kwargs.items(): |
316 |
| - arg = self.field.args.get(name) |
317 |
| - if not arg: |
318 |
| - raise KeyError(f"Argument {name} does not exist in {self.field}.") |
319 |
| - arg_type_serializer = self._get_arg_serializer(arg.type) |
320 |
| - serialized_value = arg_type_serializer(value) |
321 |
| - added_args.append( |
322 |
| - ArgumentNode(name=NameNode(value=name), value=serialized_value) |
323 |
| - ) |
324 |
| - self.ast_field.arguments = FrozenList(self.ast_field.arguments + added_args) |
325 |
| - log.debug(f"Added arguments {kwargs} in field {self!r})") |
326 |
| - return self |
| 302 | + assert self.ast_field.arguments is not None |
327 | 303 |
|
328 |
| - def _get_arg_serializer(self, arg_type: GraphQLInputType) -> Serializer: |
329 |
| - """Recursive function used to get a argument serializer function |
330 |
| - for a specific GraphQL input type. |
| 304 | + self.ast_field.arguments = FrozenList( |
| 305 | + self.ast_field.arguments |
| 306 | + + [ |
| 307 | + ArgumentNode( |
| 308 | + name=NameNode(value=name), |
| 309 | + value=ast_from_value(value, self._get_argument(name).type), |
| 310 | + ) |
| 311 | + for name, value in kwargs.items() |
| 312 | + ] |
| 313 | + ) |
331 | 314 |
|
332 |
| - The only possible sort of types are: |
333 |
| - GraphQLScalarType, GraphQLEnumType, GraphQLInputObjectType, GraphQLWrappingType |
334 |
| - GraphQLWrappingType can be GraphQLList or GraphQLNonNull |
335 |
| - """ |
| 315 | + log.debug(f"Added arguments {kwargs} in field {self!r})") |
336 | 316 |
|
337 |
| - log.debug(f"_get_arg_serializer({arg_type!r})") |
338 |
| - |
339 |
| - if isinstance(arg_type, GraphQLNonNull): |
340 |
| - return self._get_arg_serializer(arg_type.of_type) |
341 |
| - |
342 |
| - elif isinstance(arg_type, GraphQLInputObjectType): |
343 |
| - return lambda value: ObjectValueNode( |
344 |
| - fields=FrozenList( |
345 |
| - ObjectFieldNode( |
346 |
| - name=NameNode(value=k), |
347 |
| - value=( |
348 |
| - self._get_arg_serializer( |
349 |
| - cast(GraphQLInputObjectType, arg_type).fields[k].type |
350 |
| - ) |
351 |
| - )(v), |
352 |
| - ) |
353 |
| - for k, v in value.items() |
354 |
| - ) |
355 |
| - ) |
| 317 | + return self |
356 | 318 |
|
357 |
| - elif isinstance(arg_type, GraphQLList): |
358 |
| - inner_serializer = self._get_arg_serializer(arg_type.of_type) |
359 |
| - return lambda list_values: ListValueNode( |
360 |
| - values=FrozenList(inner_serializer(v) for v in list_values) |
361 |
| - ) |
| 319 | + def _get_argument(self, name: str) -> GraphQLArgument: |
| 320 | + """Method used to return the GraphQLArgument definition |
| 321 | + of an argument from its name. |
362 | 322 |
|
363 |
| - elif isinstance(arg_type, GraphQLEnumType): |
364 |
| - return lambda value: EnumValueNode( |
365 |
| - value=cast(GraphQLEnumType, arg_type).serialize(value) |
366 |
| - ) |
| 323 | + :raises KeyError: if the provided argument does not exist |
| 324 | + for this field. |
| 325 | + """ |
| 326 | + arg = self.field.args.get(name) |
367 | 327 |
|
368 |
| - # Impossible to be another type here |
369 |
| - assert isinstance(arg_type, GraphQLScalarType) |
| 328 | + if arg is None: |
| 329 | + raise KeyError(f"Argument {name} does not exist in {self.field}.") |
370 | 330 |
|
371 |
| - return lambda value: ast_from_value( |
372 |
| - cast(GraphQLScalarType, arg_type).serialize(value), |
373 |
| - cast(GraphQLScalarType, arg_type), |
374 |
| - ) |
| 331 | + return arg |
375 | 332 |
|
376 | 333 | @property
|
377 | 334 | def type_name(self):
|
|
0 commit comments