|
6 | 6 | ApiGatewayResolver,
|
7 | 7 | APIGatewayRestResolver,
|
8 | 8 | BaseRouter,
|
9 |
| - NextMiddlewareCallback, |
10 | 9 | ProxyEventType,
|
11 | 10 | Response,
|
12 | 11 | Router,
|
13 | 12 | )
|
14 | 13 | from aws_lambda_powertools.event_handler.exceptions import BadRequestError
|
15 |
| -from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, SchemaValidationMiddleware |
| 14 | +from aws_lambda_powertools.event_handler.middlewares import ( |
| 15 | + BaseMiddlewareHandler, |
| 16 | + NextMiddleware, |
| 17 | + SchemaValidationMiddleware, |
| 18 | +) |
16 | 19 | from tests.functional.utils import load_event
|
17 | 20 |
|
18 | 21 | API_REST_EVENT = load_event("apiGatewayProxyEvent.json")
|
|
29 | 32 | )
|
30 | 33 | def test_route_with_middleware(app: BaseRouter, event):
|
31 | 34 | # define custom middleware to inject new argument - "custom"
|
32 |
| - def middleware_1(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 35 | + def middleware_1(app: BaseRouter, next_middleware: NextMiddleware): |
33 | 36 | # add additional data to Router Context
|
34 | 37 | app.append_context(custom="custom")
|
35 | 38 | response = next_middleware(app)
|
36 | 39 |
|
37 | 40 | return response
|
38 | 41 |
|
39 | 42 | # define custom middleware to inject new argument - "another_one"
|
40 |
| - def middleware_2(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 43 | + def middleware_2(app: BaseRouter, next_middleware: NextMiddleware): |
41 | 44 | # add additional data to Router Context
|
42 | 45 | app.append_context(another_one=6)
|
43 | 46 | response = next_middleware(app)
|
@@ -84,15 +87,15 @@ def get_lambda() -> Response:
|
84 | 87 | )
|
85 | 88 | def test_with_router_middleware(app: BaseRouter, event, other_event):
|
86 | 89 | # define custom middleware to inject new argument - "custom"
|
87 |
| - def global_middleware(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 90 | + def global_middleware(app: BaseRouter, next_middleware: NextMiddleware): |
88 | 91 | # add custom data to context
|
89 | 92 | app.append_context(custom="custom")
|
90 | 93 | response = next_middleware(app)
|
91 | 94 |
|
92 | 95 | return response
|
93 | 96 |
|
94 | 97 | # define custom middleware to inject new argument - "another_one"
|
95 |
| - def middleware_2(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 98 | + def middleware_2(app: BaseRouter, next_middleware: NextMiddleware): |
96 | 99 | # add data to resolver context
|
97 | 100 | app.append_context(another_one=6)
|
98 | 101 | response = next_middleware(app)
|
@@ -143,7 +146,7 @@ def get_other_lambda() -> Response:
|
143 | 146 | ],
|
144 | 147 | )
|
145 | 148 | def test_dynamic_route_with_middleware(app: BaseRouter, event):
|
146 |
| - def middleware_one(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 149 | + def middleware_one(app: BaseRouter, next_middleware: NextMiddleware): |
147 | 150 | # inject data into the resolver context
|
148 | 151 | app.append_context(injected="injected_value")
|
149 | 152 | response = next_middleware(app)
|
@@ -182,12 +185,12 @@ def middleware_one(app: BaseRouter, next_middleware):
|
182 | 185 |
|
183 | 186 | return response
|
184 | 187 |
|
185 |
| - def early_return_middleware(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 188 | + def early_return_middleware(app: BaseRouter, next_middleware: NextMiddleware): |
186 | 189 | assert app.context.get("injected") == "injected_value"
|
187 | 190 |
|
188 | 191 | return Response(400, content_types.TEXT_HTML, "bad_response")
|
189 | 192 |
|
190 |
| - def not_executed_middleware(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 193 | + def not_executed_middleware(app: BaseRouter, next_middleware: NextMiddleware): |
191 | 194 | # This should never be executed - if it is an excpetion will be raised
|
192 | 195 | raise NotImplementedError()
|
193 | 196 |
|
@@ -310,19 +313,19 @@ def post_lambda() -> Response:
|
310 | 313 | ],
|
311 | 314 | )
|
312 | 315 | def test_middleware_short_circuit_via_httperrors(app: BaseRouter, event):
|
313 |
| - def middleware_one(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 316 | + def middleware_one(app: BaseRouter, next_middleware: NextMiddleware): |
314 | 317 | # inject a variable into the kwargs of the middleware chain
|
315 | 318 | app.append_context(injected="injected_value")
|
316 | 319 | response = next_middleware(app)
|
317 | 320 |
|
318 | 321 | return response
|
319 | 322 |
|
320 |
| - def early_return_middleware(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 323 | + def early_return_middleware(app: BaseRouter, next_middleware: NextMiddleware): |
321 | 324 | # ensure "injected" context variable is passed in by middleware_one
|
322 | 325 | assert app.context.get("injected") == "injected_value"
|
323 | 326 | raise BadRequestError("bad_response")
|
324 | 327 |
|
325 |
| - def not_executed_middleware(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 328 | + def not_executed_middleware(app: BaseRouter, next_middleware: NextMiddleware): |
326 | 329 | # This should never be executed - if it is an excpetion will be raised
|
327 | 330 | raise NotImplementedError()
|
328 | 331 |
|
@@ -373,7 +376,7 @@ def router_middleware(app: BaseRouter, next_middleware):
|
373 | 376 |
|
374 | 377 | to_inject: str = "injected_value"
|
375 | 378 |
|
376 |
| - def middleware_one(app: BaseRouter, next_middleware: NextMiddlewareCallback): |
| 379 | + def middleware_one(app: BaseRouter, next_middleware: NextMiddleware): |
377 | 380 | # inject a variable into the kwargs of the middleware chain
|
378 | 381 | app.append_context(injected=to_inject)
|
379 | 382 | response = next_middleware(app)
|
@@ -406,7 +409,7 @@ def __init__(self, header: str):
|
406 | 409 | super().__init__()
|
407 | 410 | self.header = header
|
408 | 411 |
|
409 |
| - def handler(self, app: ApiGatewayResolver, get_response: NextMiddlewareCallback, **kwargs) -> Response: |
| 412 | + def handler(self, app: ApiGatewayResolver, get_response: NextMiddleware, **kwargs) -> Response: |
410 | 413 | request_id = app.current_event.request_context.request_id # type: ignore[attr-defined] # using REST event in a base Resolver # noqa: E501
|
411 | 414 | correlation_id = app.current_event.get_header_value(
|
412 | 415 | name=self.header,
|
|
0 commit comments