|
1 | 1 | # type: ignore
|
| 2 | +from __future__ import print_function |
| 3 | + |
2 | 4 | import json
|
3 | 5 |
|
4 | 6 | from pytest import raises
|
5 |
| - |
6 | 7 | from graphql.error import GraphQLError
|
7 | 8 | from graphql.execution import MiddlewareManager, execute
|
| 9 | +from graphql.execution.middleware import get_middleware_resolvers, middleware_chain |
8 | 10 | from graphql.language.parser import parse
|
9 | 11 | from graphql.type import (
|
10 | 12 | GraphQLArgument,
|
@@ -138,3 +140,52 @@ def resolve(self, next, *args, **kwargs):
|
138 | 140 | "ok": "ok",
|
139 | 141 | "not_ok": "not_ok",
|
140 | 142 | }
|
| 143 | + |
| 144 | + |
| 145 | +def test_middleware_chain(capsys): |
| 146 | + # type: (Any) -> None |
| 147 | + class CharPrintingMiddleware(object): |
| 148 | + def __init__(self, char): |
| 149 | + # type: (str) -> None |
| 150 | + self.char = char |
| 151 | + |
| 152 | + def resolve(self, next, *args, **kwargs): |
| 153 | + # type: (Callable, *Any, **Any) -> str |
| 154 | + print("resolve() called for middleware {}".format(self.char)) |
| 155 | + return next(*args, **kwargs).then( |
| 156 | + lambda x: print("then() for {}".format(self.char)) |
| 157 | + ) |
| 158 | + |
| 159 | + middlewares = [ |
| 160 | + CharPrintingMiddleware("a"), |
| 161 | + CharPrintingMiddleware("b"), |
| 162 | + CharPrintingMiddleware("c"), |
| 163 | + ] |
| 164 | + |
| 165 | + middlewares_resolvers = get_middleware_resolvers(middlewares) |
| 166 | + |
| 167 | + def func(): |
| 168 | + # type: () -> None |
| 169 | + return |
| 170 | + |
| 171 | + chain_iter = middleware_chain(func, middlewares_resolvers, wrap_in_promise=True) |
| 172 | + |
| 173 | + assert_stdout(capsys, "") |
| 174 | + |
| 175 | + chain_iter() |
| 176 | + |
| 177 | + expected_stdout = ( |
| 178 | + "resolve() called for middleware c\n" |
| 179 | + "resolve() called for middleware b\n" |
| 180 | + "resolve() called for middleware a\n" |
| 181 | + "then() for a\n" |
| 182 | + "then() for b\n" |
| 183 | + "then() for c\n" |
| 184 | + ) |
| 185 | + assert_stdout(capsys, expected_stdout) |
| 186 | + |
| 187 | + |
| 188 | +def assert_stdout(capsys, expected_stdout): |
| 189 | + # type: (Any, str) -> None |
| 190 | + captured = capsys.readouterr() |
| 191 | + assert captured.out == expected_stdout |
0 commit comments