From 937e6fe670be1ed37ddb1f8626ec7825682dd1b7 Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Tue, 8 Mar 2022 09:19:52 +1100 Subject: [PATCH 1/3] Attempt to type lambda_handler_decorator accurately --- .../middleware_factory/factory.py | 62 +++++++++++++++++-- 1 file changed, 57 insertions(+), 5 deletions(-) diff --git a/aws_lambda_powertools/middleware_factory/factory.py b/aws_lambda_powertools/middleware_factory/factory.py index 8ab16c5e8b7..365d2889428 100644 --- a/aws_lambda_powertools/middleware_factory/factory.py +++ b/aws_lambda_powertools/middleware_factory/factory.py @@ -1,18 +1,64 @@ import functools import inspect import logging +import sys import os -from typing import Callable, Optional +from typing import Any, Callable, Dict, Optional, Union, cast, overload + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol from ..shared import constants from ..shared.functions import resolve_truthy_env_var_choice from ..tracing import Tracer +from ..utilities.typing import LambdaContext from .exceptions import MiddlewareInvalidArgumentError logger = logging.getLogger(__name__) +# context: Any to avoid forcing users to type it as context: LambdaContext +_Handler = Callable[[Any, LambdaContext], Any] +_RawHandlerDecorator = Callable[[_Handler], _Handler] + + +class _FactoryDecorator(Protocol): + # it'd be better for this to be using ParamSpec (available from 3.10) + def __call__( + self, handler: _Handler, event: Dict[str, Any], context: LambdaContext, **kwargs: Any + ) -> _RawHandlerDecorator: + ... + + +class _HandlerDecorator(Protocol): + @overload + def __call__(self, decorator: _Handler) -> _Handler: + ... + + @overload + def __call__(self, decorator: None = None, **kwargs: Any) -> _RawHandlerDecorator: + ... + + def __call__(self, decorator: Optional[_Handler] = None, **kwargs: Any) -> Union[_Handler, _RawHandlerDecorator]: + ... + + +@overload +def lambda_handler_decorator(decorator: _FactoryDecorator) -> _HandlerDecorator: + ... + + +@overload +def lambda_handler_decorator( + decorator: None = None, trace_execution: Optional[bool] = None +) -> Callable[[_FactoryDecorator], _HandlerDecorator]: + ... + -def lambda_handler_decorator(decorator: Optional[Callable] = None, trace_execution: Optional[bool] = None): +def lambda_handler_decorator( + decorator: Optional[_FactoryDecorator] = None, trace_execution: Optional[bool] = None +) -> Union[_HandlerDecorator, Callable[[_FactoryDecorator], _HandlerDecorator]]: """Decorator factory for decorating Lambda handlers. You can use lambda_handler_decorator to create your own middlewares, @@ -103,19 +149,25 @@ def lambda_handler(event, context): """ if decorator is None: - return functools.partial(lambda_handler_decorator, trace_execution=trace_execution) + return cast( + Callable[[_FactoryDecorator], _HandlerDecorator], + functools.partial(lambda_handler_decorator, trace_execution=trace_execution), + ) trace_execution = resolve_truthy_env_var_choice( env=os.getenv(constants.MIDDLEWARE_FACTORY_TRACE_ENV, "false"), choice=trace_execution ) @functools.wraps(decorator) - def final_decorator(func: Optional[Callable] = None, **kwargs): + def final_decorator( + func: Optional[_RawHandlerDecorator] = None, **kwargs: Any + ) -> Union[_Handler, _RawHandlerDecorator]: # If called with kwargs return new func with kwargs if func is None: return functools.partial(final_decorator, **kwargs) if not inspect.isfunction(func): + assert decorator is not None # @custom_middleware(True) vs @custom_middleware(log_event=True) raise MiddlewareInvalidArgumentError( f"Only keyword arguments is supported for middlewares: {decorator.__qualname__} received {func}" # type: ignore # noqa: E501 @@ -138,4 +190,4 @@ def wrapper(event, context): return wrapper - return final_decorator + return cast(_HandlerDecorator, final_decorator) From 665e0bb8f76b9d72a51ad690255296257e927cf8 Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Tue, 8 Mar 2022 09:34:44 +1100 Subject: [PATCH 2/3] Simplify to just Callable[..., Any] --- .../middleware_factory/factory.py | 63 ++----------------- 1 file changed, 6 insertions(+), 57 deletions(-) diff --git a/aws_lambda_powertools/middleware_factory/factory.py b/aws_lambda_powertools/middleware_factory/factory.py index 365d2889428..a78e36ab918 100644 --- a/aws_lambda_powertools/middleware_factory/factory.py +++ b/aws_lambda_powertools/middleware_factory/factory.py @@ -1,64 +1,19 @@ import functools import inspect import logging -import sys import os -from typing import Any, Callable, Dict, Optional, Union, cast, overload - -if sys.version_info >= (3, 8): - from typing import Protocol -else: - from typing_extensions import Protocol +from typing import Any, Callable, Optional from ..shared import constants from ..shared.functions import resolve_truthy_env_var_choice from ..tracing import Tracer -from ..utilities.typing import LambdaContext from .exceptions import MiddlewareInvalidArgumentError logger = logging.getLogger(__name__) -# context: Any to avoid forcing users to type it as context: LambdaContext -_Handler = Callable[[Any, LambdaContext], Any] -_RawHandlerDecorator = Callable[[_Handler], _Handler] - - -class _FactoryDecorator(Protocol): - # it'd be better for this to be using ParamSpec (available from 3.10) - def __call__( - self, handler: _Handler, event: Dict[str, Any], context: LambdaContext, **kwargs: Any - ) -> _RawHandlerDecorator: - ... - - -class _HandlerDecorator(Protocol): - @overload - def __call__(self, decorator: _Handler) -> _Handler: - ... - - @overload - def __call__(self, decorator: None = None, **kwargs: Any) -> _RawHandlerDecorator: - ... - - def __call__(self, decorator: Optional[_Handler] = None, **kwargs: Any) -> Union[_Handler, _RawHandlerDecorator]: - ... - - -@overload -def lambda_handler_decorator(decorator: _FactoryDecorator) -> _HandlerDecorator: - ... - - -@overload -def lambda_handler_decorator( - decorator: None = None, trace_execution: Optional[bool] = None -) -> Callable[[_FactoryDecorator], _HandlerDecorator]: - ... - -def lambda_handler_decorator( - decorator: Optional[_FactoryDecorator] = None, trace_execution: Optional[bool] = None -) -> Union[_HandlerDecorator, Callable[[_FactoryDecorator], _HandlerDecorator]]: +# giving this an accurate return type is hard +def lambda_handler_decorator(decorator: Optional[Callable] = None, trace_execution: Optional[bool] = None) -> Callable: """Decorator factory for decorating Lambda handlers. You can use lambda_handler_decorator to create your own middlewares, @@ -149,25 +104,19 @@ def lambda_handler(event, context): """ if decorator is None: - return cast( - Callable[[_FactoryDecorator], _HandlerDecorator], - functools.partial(lambda_handler_decorator, trace_execution=trace_execution), - ) + return functools.partial(lambda_handler_decorator, trace_execution=trace_execution) trace_execution = resolve_truthy_env_var_choice( env=os.getenv(constants.MIDDLEWARE_FACTORY_TRACE_ENV, "false"), choice=trace_execution ) @functools.wraps(decorator) - def final_decorator( - func: Optional[_RawHandlerDecorator] = None, **kwargs: Any - ) -> Union[_Handler, _RawHandlerDecorator]: + def final_decorator(func: Optional[Callable] = None, **kwargs: Any): # If called with kwargs return new func with kwargs if func is None: return functools.partial(final_decorator, **kwargs) if not inspect.isfunction(func): - assert decorator is not None # @custom_middleware(True) vs @custom_middleware(log_event=True) raise MiddlewareInvalidArgumentError( f"Only keyword arguments is supported for middlewares: {decorator.__qualname__} received {func}" # type: ignore # noqa: E501 @@ -190,4 +139,4 @@ def wrapper(event, context): return wrapper - return cast(_HandlerDecorator, final_decorator) + return final_decorator From 58f1ea4dcb9fb77101bb6b45fa41a62263e17018 Mon Sep 17 00:00:00 2001 From: Heitor Lessa Date: Fri, 8 Apr 2022 09:43:54 +0200 Subject: [PATCH 3/3] chore: add a maintenance note on the challenge --- aws_lambda_powertools/middleware_factory/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/middleware_factory/factory.py b/aws_lambda_powertools/middleware_factory/factory.py index a78e36ab918..850983728c2 100644 --- a/aws_lambda_powertools/middleware_factory/factory.py +++ b/aws_lambda_powertools/middleware_factory/factory.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -# giving this an accurate return type is hard +# Maintenance: we can't yet provide an accurate return type without ParamSpec etc. see #1066 def lambda_handler_decorator(decorator: Optional[Callable] = None, trace_execution: Optional[bool] = None) -> Callable: """Decorator factory for decorating Lambda handlers.