1
- from functools import partial
1
+ from functools import partial , reduce
2
2
from inspect import isfunction
3
- from itertools import chain
4
3
5
- from typing import Callable , Iterator , Dict , Tuple , Any , Iterable , Optional , cast
4
+ from typing import Callable , Iterator , Dict , Tuple , Any , Optional
6
5
7
6
__all__ = ["MiddlewareManager" ]
8
7
@@ -41,8 +40,10 @@ def get_field_resolver(
41
40
if self ._middleware_resolvers is None :
42
41
return field_resolver
43
42
if field_resolver not in self ._cached_resolvers :
44
- self ._cached_resolvers [field_resolver ] = middleware_chain (
45
- field_resolver , self ._middleware_resolvers
43
+ self ._cached_resolvers [field_resolver ] = reduce (
44
+ lambda chained_fns , next_fn : partial (next_fn , chained_fns ),
45
+ self ._middleware_resolvers ,
46
+ field_resolver ,
46
47
)
47
48
return self ._cached_resolvers [field_resolver ]
48
49
@@ -56,19 +57,3 @@ def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]
56
57
resolver_func = getattr (middleware , "resolve" , None )
57
58
if resolver_func is not None :
58
59
yield resolver_func
59
-
60
-
61
- def middleware_chain (
62
- func : GraphQLFieldResolver , middlewares : Iterable [Callable ]
63
- ) -> GraphQLFieldResolver :
64
- """Chain the given function with the provided middlewares.
65
-
66
- Returns a new resolver function that is the chain of both.
67
- """
68
- if not middlewares :
69
- return func
70
- middlewares = chain ((func ,), middlewares )
71
- last_func : Optional [GraphQLFieldResolver ] = None
72
- for middleware in middlewares :
73
- last_func = partial (middleware , last_func ) if last_func else middleware
74
- return cast (GraphQLFieldResolver , last_func )
0 commit comments