Skip to content

Commit 2f2eb91

Browse files
committed
chore: enforce protocol type checking for app instance
1 parent f5a613a commit 2f2eb91

File tree

8 files changed

+53
-57
lines changed

8 files changed

+53
-57
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,13 @@
99
from enum import Enum
1010
from functools import partial
1111
from http import HTTPStatus
12-
from typing import (
13-
Any,
14-
Callable,
15-
Dict,
16-
List,
17-
Match,
18-
Optional,
19-
Pattern,
20-
Set,
21-
Tuple,
22-
Type,
23-
Union,
24-
)
12+
from typing import Any, Callable, Dict, List, Match, Optional, Pattern, Set, Tuple, Type, Union
2513

2614
from aws_lambda_powertools.event_handler import content_types
2715
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
2816
from aws_lambda_powertools.shared.cookies import Cookie
2917
from aws_lambda_powertools.shared.functions import powertools_dev_is_set
3018
from aws_lambda_powertools.shared.json_encoder import Encoder
31-
from aws_lambda_powertools.shared.types import Protocol
3219
from aws_lambda_powertools.utilities.data_classes import (
3320
ALBEvent,
3421
APIGatewayProxyEvent,
@@ -1331,11 +1318,4 @@ def __init__(
13311318
super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes)
13321319

13331320

1334-
class NextMiddlewareCallback(Protocol):
1335-
def __call__(self, app: ApiGatewayResolver) -> Response:
1336-
"""Protocol for callback regardless of next_middleware(app), next(app)"""
1337-
...
1338-
1339-
def __name__(self) -> str: # noqa A003
1340-
"""Protocol for name of the Middleware"""
1341-
...
1321+
# Specialized types defined here due to circular dependency in `.types`
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from aws_lambda_powertools.event_handler.middlewares.base import BaseMiddlewareHandler
1+
from aws_lambda_powertools.event_handler.middlewares.base import BaseMiddlewareHandler, NextMiddleware
22
from aws_lambda_powertools.event_handler.middlewares.schema_validation import SchemaValidationMiddleware
33

4-
__all__ = ["BaseMiddlewareHandler", "SchemaValidationMiddleware"]
4+
__all__ = ["BaseMiddlewareHandler", "SchemaValidationMiddleware", "NextMiddleware"]

aws_lambda_powertools/event_handler/middlewares/base.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Callable, Generic, TypeVar
2+
from typing import Generic
33

4-
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, Response
4+
from typing_extensions import Protocol
55

6-
EventHandlerResolver = TypeVar("EventHandlerResolver", bound=ApiGatewayResolver)
6+
from aws_lambda_powertools.event_handler.api_gateway import Response
7+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
78

89

9-
class BaseMiddlewareHandler(Generic[EventHandlerResolver], ABC):
10+
class NextMiddleware(Protocol):
11+
def __call__(self, app: EventHandlerInstance) -> Response:
12+
"""Protocol for callback regardless of next_middleware(app), get_response(app) etc"""
13+
...
14+
15+
def __name__(self) -> str: # noqa A003
16+
"""Protocol for name of the Middleware"""
17+
...
18+
19+
20+
class BaseMiddlewareHandler(Generic[EventHandlerInstance], ABC):
1021
"""
1122
Base class for Middleware Handlers
1223
@@ -69,15 +80,15 @@ class BaseMiddlewareHandler(Generic[EventHandlerResolver], ABC):
6980
"""
7081

7182
@abstractmethod
72-
def handler(self, app: EventHandlerResolver, next_middleware: Callable[..., Any]) -> Response:
83+
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
7384
"""
7485
The Middleware Handler
7586
7687
Parameters
7788
----------
78-
app: EventHandlerResolver
89+
app: EventHandlerInstance
7990
An instance of an Event Handler that implements ApiGatewayResolver
80-
next_middleware: Callable[..., Any]
91+
next_middleware: NextMiddleware
8192
The next middleware handler in the chain
8293
8394
Returns
@@ -92,7 +103,7 @@ def handler(self, app: EventHandlerResolver, next_middleware: Callable[..., Any]
92103
def __name__(self) -> str: # noqa A003
93104
return str(self.__class__.__name__)
94105

95-
def __call__(self, app: EventHandlerResolver, next_middleware: Callable[..., Any]) -> Response:
106+
def __call__(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
96107
"""
97108
The Middleware handler function.
98109
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
from aws_lambda_powertools.event_handler.api_gateway import NextMiddlewareCallback
1+
from typing import TypeVar
22

3-
__all__ = ["NextMiddlewareCallback"]
3+
from aws_lambda_powertools.event_handler import ApiGatewayResolver
4+
5+
EventHandlerInstance = TypeVar("EventHandlerInstance", bound=ApiGatewayResolver)

examples/event_handler_rest/src/custom_middlewares.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from aws_lambda_powertools import Logger
22
from aws_lambda_powertools.event_handler import ApiGatewayResolver, Response
3-
from aws_lambda_powertools.event_handler.api_gateway import NextMiddlewareCallback
43
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError, ServiceError
4+
from aws_lambda_powertools.event_handler.middlewares import NextMiddleware
55

66
logger = Logger()
77

88

9-
def validate_correlation_id(app: ApiGatewayResolver, next_middleware: NextMiddlewareCallback) -> Response:
9+
def validate_correlation_id(app: ApiGatewayResolver, next_middleware: NextMiddleware) -> Response:
1010
# If missing mandatory header raise an error
1111
if not app.current_event.headers.get("x-correlation-id", None):
1212
raise BadRequestError("No [x-correlation-id] header provided. All requests must include this header.")
@@ -15,7 +15,7 @@ def validate_correlation_id(app: ApiGatewayResolver, next_middleware: NextMiddle
1515
return next_middleware(app)
1616

1717

18-
def sanitise_exceptions(app: ApiGatewayResolver, next_middleware: NextMiddlewareCallback) -> Response:
18+
def sanitise_exceptions(app: ApiGatewayResolver, next_middleware: NextMiddleware) -> Response:
1919
try:
2020
# Get the Result from the next middleware
2121
result = next_middleware(app)

examples/event_handler_rest/src/middleware_getting_started.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from aws_lambda_powertools import Logger
44
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
5-
from aws_lambda_powertools.event_handler.types import NextMiddlewareCallback
5+
from aws_lambda_powertools.event_handler.middlewares import NextMiddleware
66

77
app = APIGatewayRestResolver()
88
logger = Logger()
99

1010

11-
def inject_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddlewareCallback) -> Response:
11+
def inject_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
1212
request_id = app.current_event.request_context.request_id # (1)!
1313

1414
# Use API Gateway REST API request ID if caller didn't include a correlation ID

examples/event_handler_rest/src/middleware_global_middlewares_module.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from aws_lambda_powertools import Logger
22
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
3-
from aws_lambda_powertools.event_handler.types import NextMiddlewareCallback
3+
from aws_lambda_powertools.event_handler.middlewares import NextMiddleware
44

55
logger = Logger()
66

77

8-
def log_request_response(app: APIGatewayRestResolver, next_middleware: NextMiddlewareCallback) -> Response:
8+
def log_request_response(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
99
logger.info("Incoming request", path=app.current_event.path, request=app.current_event.raw_event)
1010

1111
result = next_middleware(app)
@@ -14,7 +14,7 @@ def log_request_response(app: APIGatewayRestResolver, next_middleware: NextMiddl
1414
return result
1515

1616

17-
def inject_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddlewareCallback) -> Response:
17+
def inject_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
1818
request_id = app.current_event.request_context.request_id
1919

2020
# Use API Gateway REST API request ID if caller didn't include a correlation ID
@@ -32,7 +32,7 @@ def inject_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMidd
3232
return result
3333

3434

35-
def enforce_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddlewareCallback) -> Response:
35+
def enforce_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
3636
# If missing mandatory header raise an error
3737
if not app.current_event.get_header_value("x-correlation-id", case_sensitive=False):
3838
return Response(status_code=400, body="Correlation ID header is now mandatory.") # (1)!

tests/functional/event_handler/test_api_middlewares.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@
66
ApiGatewayResolver,
77
APIGatewayRestResolver,
88
BaseRouter,
9-
NextMiddlewareCallback,
109
ProxyEventType,
1110
Response,
1211
Router,
1312
)
1413
from aws_lambda_powertools.event_handler.exceptions import BadRequestError
15-
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, SchemaValidationMiddleware
14+
from aws_lambda_powertools.event_handler.middlewares import (
15+
BaseMiddlewareHandler,
16+
NextMiddleware,
17+
SchemaValidationMiddleware,
18+
)
1619
from tests.functional.utils import load_event
1720

1821
API_REST_EVENT = load_event("apiGatewayProxyEvent.json")
@@ -29,15 +32,15 @@
2932
)
3033
def test_route_with_middleware(app: BaseRouter, event):
3134
# define custom middleware to inject new argument - "custom"
32-
def middleware_1(app: BaseRouter, next_middleware: NextMiddlewareCallback):
35+
def middleware_1(app: BaseRouter, next_middleware: NextMiddleware):
3336
# add additional data to Router Context
3437
app.append_context(custom="custom")
3538
response = next_middleware(app)
3639

3740
return response
3841

3942
# define custom middleware to inject new argument - "another_one"
40-
def middleware_2(app: BaseRouter, next_middleware: NextMiddlewareCallback):
43+
def middleware_2(app: BaseRouter, next_middleware: NextMiddleware):
4144
# add additional data to Router Context
4245
app.append_context(another_one=6)
4346
response = next_middleware(app)
@@ -84,15 +87,15 @@ def get_lambda() -> Response:
8487
)
8588
def test_with_router_middleware(app: BaseRouter, event, other_event):
8689
# define custom middleware to inject new argument - "custom"
87-
def global_middleware(app: BaseRouter, next_middleware: NextMiddlewareCallback):
90+
def global_middleware(app: BaseRouter, next_middleware: NextMiddleware):
8891
# add custom data to context
8992
app.append_context(custom="custom")
9093
response = next_middleware(app)
9194

9295
return response
9396

9497
# define custom middleware to inject new argument - "another_one"
95-
def middleware_2(app: BaseRouter, next_middleware: NextMiddlewareCallback):
98+
def middleware_2(app: BaseRouter, next_middleware: NextMiddleware):
9699
# add data to resolver context
97100
app.append_context(another_one=6)
98101
response = next_middleware(app)
@@ -143,7 +146,7 @@ def get_other_lambda() -> Response:
143146
],
144147
)
145148
def test_dynamic_route_with_middleware(app: BaseRouter, event):
146-
def middleware_one(app: BaseRouter, next_middleware: NextMiddlewareCallback):
149+
def middleware_one(app: BaseRouter, next_middleware: NextMiddleware):
147150
# inject data into the resolver context
148151
app.append_context(injected="injected_value")
149152
response = next_middleware(app)
@@ -182,12 +185,12 @@ def middleware_one(app: BaseRouter, next_middleware):
182185

183186
return response
184187

185-
def early_return_middleware(app: BaseRouter, next_middleware: NextMiddlewareCallback):
188+
def early_return_middleware(app: BaseRouter, next_middleware: NextMiddleware):
186189
assert app.context.get("injected") == "injected_value"
187190

188191
return Response(400, content_types.TEXT_HTML, "bad_response")
189192

190-
def not_executed_middleware(app: BaseRouter, next_middleware: NextMiddlewareCallback):
193+
def not_executed_middleware(app: BaseRouter, next_middleware: NextMiddleware):
191194
# This should never be executed - if it is an excpetion will be raised
192195
raise NotImplementedError()
193196

@@ -310,19 +313,19 @@ def post_lambda() -> Response:
310313
],
311314
)
312315
def test_middleware_short_circuit_via_httperrors(app: BaseRouter, event):
313-
def middleware_one(app: BaseRouter, next_middleware: NextMiddlewareCallback):
316+
def middleware_one(app: BaseRouter, next_middleware: NextMiddleware):
314317
# inject a variable into the kwargs of the middleware chain
315318
app.append_context(injected="injected_value")
316319
response = next_middleware(app)
317320

318321
return response
319322

320-
def early_return_middleware(app: BaseRouter, next_middleware: NextMiddlewareCallback):
323+
def early_return_middleware(app: BaseRouter, next_middleware: NextMiddleware):
321324
# ensure "injected" context variable is passed in by middleware_one
322325
assert app.context.get("injected") == "injected_value"
323326
raise BadRequestError("bad_response")
324327

325-
def not_executed_middleware(app: BaseRouter, next_middleware: NextMiddlewareCallback):
328+
def not_executed_middleware(app: BaseRouter, next_middleware: NextMiddleware):
326329
# This should never be executed - if it is an excpetion will be raised
327330
raise NotImplementedError()
328331

@@ -373,7 +376,7 @@ def router_middleware(app: BaseRouter, next_middleware):
373376

374377
to_inject: str = "injected_value"
375378

376-
def middleware_one(app: BaseRouter, next_middleware: NextMiddlewareCallback):
379+
def middleware_one(app: BaseRouter, next_middleware: NextMiddleware):
377380
# inject a variable into the kwargs of the middleware chain
378381
app.append_context(injected=to_inject)
379382
response = next_middleware(app)
@@ -406,7 +409,7 @@ def __init__(self, header: str):
406409
super().__init__()
407410
self.header = header
408411

409-
def handler(self, app: ApiGatewayResolver, get_response: NextMiddlewareCallback, **kwargs) -> Response:
412+
def handler(self, app: ApiGatewayResolver, get_response: NextMiddleware, **kwargs) -> Response:
410413
request_id = app.current_event.request_context.request_id # type: ignore[attr-defined] # using REST event in a base Resolver # noqa: E501
411414
correlation_id = app.current_event.get_header_value(
412415
name=self.header,

0 commit comments

Comments
 (0)