diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 446b1eca856..1e6fe2a50bb 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -520,7 +520,7 @@ def __init__( cors: Optional[CORSConfig] = None, debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, - strip_prefixes: Optional[List[str]] = None, + strip_prefixes: Optional[List[Union[str, Pattern]]] = None, ): """ Parameters @@ -534,9 +534,10 @@ def __init__( environment variable serializer : Callable, optional function to serialize `obj` to a JSON formatted `str`, by default json.dumps - strip_prefixes: List[str], optional - optional list of prefixes to be removed from the request path before doing the routing. This is often used - with api gateways with multiple custom mappings. + strip_prefixes: List[Union[str, Pattern]], optional + optional list of prefixes to be removed from the request path before doing the routing. + This is often used with api gateways with multiple custom mappings. + Each prefix can be a static string or a compiled regex pattern """ self._proxy_type = proxy_type self._dynamic_routes: List[Route] = [] @@ -713,10 +714,21 @@ def _remove_prefix(self, path: str) -> str: return path for prefix in self._strip_prefixes: - if path == prefix: - return "/" - if self._path_starts_with(path, prefix): - return path[len(prefix) :] + if isinstance(prefix, str): + if path == prefix: + return "/" + + if self._path_starts_with(path, prefix): + return path[len(prefix) :] + + if isinstance(prefix, Pattern): + path = re.sub(prefix, "", path) + + # When using regexes, we might get into a point where everything is removed + # from the string, so we check if it's empty and return /, since there's nothing + # else to strip anymore. + if not path: + return "/" return path @@ -911,7 +923,7 @@ def __init__( cors: Optional[CORSConfig] = None, debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, - strip_prefixes: Optional[List[str]] = None, + strip_prefixes: Optional[List[Union[str, Pattern]]] = None, ): """Amazon API Gateway REST and HTTP API v1 payload resolver""" super().__init__(ProxyEventType.APIGatewayProxyEvent, cors, debug, serializer, strip_prefixes) @@ -942,7 +954,7 @@ def __init__( cors: Optional[CORSConfig] = None, debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, - strip_prefixes: Optional[List[str]] = None, + strip_prefixes: Optional[List[Union[str, Pattern]]] = None, ): """Amazon API Gateway HTTP API v2 payload resolver""" super().__init__(ProxyEventType.APIGatewayProxyEventV2, cors, debug, serializer, strip_prefixes) @@ -956,7 +968,7 @@ def __init__( cors: Optional[CORSConfig] = None, debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, - strip_prefixes: Optional[List[str]] = None, + strip_prefixes: Optional[List[Union[str, Pattern]]] = None, ): """Amazon Application Load Balancer (ALB) resolver""" super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes) diff --git a/aws_lambda_powertools/event_handler/lambda_function_url.py b/aws_lambda_powertools/event_handler/lambda_function_url.py index 6978b29f451..433a013ab0b 100644 --- a/aws_lambda_powertools/event_handler/lambda_function_url.py +++ b/aws_lambda_powertools/event_handler/lambda_function_url.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Pattern, Union from aws_lambda_powertools.event_handler import CORSConfig from aws_lambda_powertools.event_handler.api_gateway import ( @@ -51,6 +51,6 @@ def __init__( cors: Optional[CORSConfig] = None, debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, - strip_prefixes: Optional[List[str]] = None, + strip_prefixes: Optional[List[Union[str, Pattern]]] = None, ): super().__init__(ProxyEventType.LambdaFunctionUrlEvent, cors, debug, serializer, strip_prefixes) diff --git a/aws_lambda_powertools/event_handler/vpc_lattice.py b/aws_lambda_powertools/event_handler/vpc_lattice.py index 1150f7224fb..b3cb042b40b 100644 --- a/aws_lambda_powertools/event_handler/vpc_lattice.py +++ b/aws_lambda_powertools/event_handler/vpc_lattice.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Pattern, Union from aws_lambda_powertools.event_handler import CORSConfig from aws_lambda_powertools.event_handler.api_gateway import ( @@ -47,7 +47,7 @@ def __init__( cors: Optional[CORSConfig] = None, debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, - strip_prefixes: Optional[List[str]] = None, + strip_prefixes: Optional[List[Union[str, Pattern]]] = None, ): """Amazon VPC Lattice resolver""" super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes) diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 708a9de6855..dcfa38f6f9a 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -272,7 +272,7 @@ When using [Custom Domain API Mappings feature](https://docs.aws.amazon.com/apig **Scenario**: You have a custom domain `api.mydomain.dev`. Then you set `/payment` API Mapping to forward any payment requests to your Payments API. -**Challenge**: This means your `path` value for any API requests will always contain `/payment/`, leading to HTTP 404 as Event Handler is trying to match what's after `payment/`. This gets further complicated with an [arbitrary level of nesting](https://github.com/aws-powertools/powertools-lambda-roadmap/issues/34){target="_blank"}. +**Challenge**: This means your `path` value for any API requests will always contain `/payment/`, leading to HTTP 404 as Event Handler is trying to match what's after `payment/`. This gets further complicated with an [arbitrary level of nesting](https://github.com/aws-powertools/powertools-lambda/issues/34){target="_blank"}. To address this API Gateway behavior, we use `strip_prefixes` parameter to account for these prefixes that are now injected into the path regardless of which type of API Gateway you're using. @@ -293,6 +293,14 @@ To address this API Gateway behavior, we use `strip_prefixes` parameter to accou For example, when using `strip_prefixes` value of `/pay`, there is no difference between a request path of `/pay` and `/pay/`; and the path argument would be defined as `/`. +For added flexibility, you can use regexes to strip a prefix. This is helpful when you have many options due to different combinations of prefixes (e.g: multiple environments, multiple versions). + +=== "strip_route_prefix_regex.py" + + ```python hl_lines="12" + --8<-- "examples/event_handler_rest/src/strip_route_prefix_regex.py" + ``` + ## Advanced ### CORS diff --git a/examples/event_handler_rest/src/strip_route_prefix_regex.py b/examples/event_handler_rest/src/strip_route_prefix_regex.py new file mode 100644 index 00000000000..4ea4b4249f4 --- /dev/null +++ b/examples/event_handler_rest/src/strip_route_prefix_regex.py @@ -0,0 +1,21 @@ +import re + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.utilities.typing import LambdaContext + +# This will support: +# /v1/dev/subscriptions/ +# /v1/stg/subscriptions/ +# /v1/qa/subscriptions/ +# /v2/dev/subscriptions/ +# ... +app = APIGatewayRestResolver(strip_prefixes=[re.compile(r"/v[1-3]+/(dev|stg|qa)")]) + + +@app.get("/subscriptions/") +def get_subscription(subscription): + return {"subscription_id": subscription} + + +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 26c71e1f27d..2afd1241bed 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1,5 +1,6 @@ import base64 import json +import re import zlib from copy import deepcopy from decimal import Decimal @@ -1077,6 +1078,38 @@ def foo(): assert response["statusCode"] == 200 +@pytest.mark.parametrize( + "path", + [ + pytest.param("/stg/foo", id="path matched pay prefix"), + pytest.param("/dev/foo", id="path matched pay prefix with multiple numbers"), + pytest.param("/foo", id="path does not start with any of the prefixes"), + ], +) +def test_remove_prefix_by_regex(path: str): + app = ApiGatewayResolver(strip_prefixes=[re.compile(r"/(dev|stg)")]) + + @app.get("/foo") + def foo(): + ... + + response = app({"httpMethod": "GET", "path": path}, None) + + assert response["statusCode"] == 200 + + +def test_empty_path_when_using_regexes(): + app = ApiGatewayResolver(strip_prefixes=[re.compile(r"/(dev|stg)")]) + + @app.get("/") + def foo(): + ... + + response = app({"httpMethod": "GET", "path": "/dev"}, None) + + assert response["statusCode"] == 200 + + @pytest.mark.parametrize( "prefix", [