Skip to content

Commit 76ed8e6

Browse files
author
Tom McCarthy
committed
fix: allow passing list or tuple for methods to the route function for ApiGatewayResolver
1 parent 566043a commit 76ed8e6

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from enum import Enum
1111
from functools import partial
1212
from http import HTTPStatus
13-
from typing import Any, Callable, Dict, List, Optional, Set, Union
13+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
1414

1515
from aws_lambda_powertools.event_handler import content_types
1616
from aws_lambda_powertools.event_handler.exceptions import ServiceError
@@ -453,27 +453,30 @@ def __init__(
453453
def route(
454454
self,
455455
rule: str,
456-
method: str,
456+
method: Union[str, Union[List[str], Tuple[str]]],
457457
cors: Optional[bool] = None,
458458
compress: bool = False,
459459
cache_control: Optional[str] = None,
460460
):
461461
"""Route decorator includes parameter `method`"""
462462

463463
def register_resolver(func: Callable):
464-
logger.debug(f"Adding route using rule {rule} and method {method.upper()}")
464+
methods = (method,) if isinstance(method, str) else method
465+
logger.debug(f"Adding route using rule {rule} and methods: {','.join((m.upper() for m in methods))}")
465466
if cors is None:
466467
cors_enabled = self._cors_enabled
467468
else:
468469
cors_enabled = cors
469-
self._routes.append(Route(method, self._compile_regex(rule), func, cors_enabled, compress, cache_control))
470-
route_key = method + rule
471-
if route_key in self._route_keys:
472-
warnings.warn(f"A route like this was already registered. method: '{method}' rule: '{rule}'")
473-
self._route_keys.append(route_key)
474-
if cors_enabled:
475-
logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS")
476-
self._cors_methods.add(method.upper())
470+
471+
for item in methods:
472+
self._routes.append(Route(item, self._compile_regex(rule), func, cors_enabled, compress, cache_control))
473+
route_key = item + rule
474+
if route_key in self._route_keys:
475+
warnings.warn(f"A route like this was already registered. method: '{item}' rule: '{rule}'")
476+
self._route_keys.append(route_key)
477+
if cors_enabled:
478+
logger.debug(f"Registering method {item.upper()} to Allow Methods in CORS")
479+
self._cors_methods.add(item.upper())
477480
return func
478481

479482
return register_resolver
@@ -679,14 +682,13 @@ def __init__(self):
679682
def route(
680683
self,
681684
rule: str,
682-
method: Union[str, List[str]],
685+
method: Union[str, Union[List[str], Tuple[str]]],
683686
cors: Optional[bool] = None,
684687
compress: bool = False,
685688
cache_control: Optional[str] = None,
686689
):
687690
def register_route(func: Callable):
688-
methods = method if isinstance(method, list) else [method]
689-
for item in methods:
690-
self._routes[(rule, item, cors, compress, cache_control)] = func
691+
methods = (method,) if isinstance(method, str) else tuple(method)
692+
self._routes[(rule, methods, cors, compress, cache_control)] = func
691693

692694
return register_route

tests/functional/event_handler/test_api_gateway.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,3 +1021,39 @@ def get_func_another_duplicate():
10211021
# THEN only execute the first registered route
10221022
# AND print warnings
10231023
assert result["statusCode"] == 200
1024+
1025+
1026+
def test_route_multiple_methods():
1027+
# GIVEN a function with http methods passed as a list
1028+
app = ApiGatewayResolver()
1029+
req = "foo"
1030+
get_event = deepcopy(LOAD_GW_EVENT)
1031+
get_event["resource"] = "/accounts/{account_id}"
1032+
get_event["path"] = f"/accounts/{req}"
1033+
1034+
post_event = deepcopy(get_event)
1035+
post_event["httpMethod"] = "POST"
1036+
1037+
put_event = deepcopy(get_event)
1038+
put_event["httpMethod"] = "PUT"
1039+
1040+
lambda_context = {}
1041+
1042+
@app.route(rule="/accounts/<account_id>", method=["GET", "POST"])
1043+
def foo(account_id):
1044+
assert app.lambda_context == lambda_context
1045+
assert account_id == f"{req}"
1046+
return {}
1047+
1048+
# WHEN calling the event handler with the supplied methods
1049+
get_result = app(get_event, lambda_context)
1050+
post_result = app(post_event, lambda_context)
1051+
put_result = app(put_event, lambda_context)
1052+
1053+
# THEN events are processed correctly
1054+
assert get_result["statusCode"] == 200
1055+
assert get_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
1056+
assert post_result["statusCode"] == 200
1057+
assert post_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
1058+
assert put_result["statusCode"] == 404
1059+
assert put_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON

0 commit comments

Comments
 (0)