Skip to content

feat(apigateway): ignore trailing slashes in routes (APIGatewayRestResolver) #1609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions tests/e2e/event_handler/handlers/alb_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

app = ALBResolver()

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.


@app.post("/todos")
def todos():
Expand All @@ -22,10 +25,5 @@ def todos():
)


@app.get("/hello")
def hello():
return Response(status_code=200, content_type=content_types.TEXT_PLAIN, body="Hello World")


def lambda_handler(event, context):
return app.resolve(event, context)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

app = APIGatewayHttpResolver()

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.


@app.post("/todos")
def todos():
Expand All @@ -26,10 +29,5 @@ def todos():
)


@app.get("/hello")
def hello():
return Response(status_code=200, content_type=content_types.TEXT_PLAIN, body="Hello World")


def lambda_handler(event, context):
return app.resolve(event, context)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

app = APIGatewayRestResolver()

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.


@app.post("/todos")
def todos():
Expand All @@ -26,10 +29,5 @@ def todos():
)


@app.get("/hello")
def hello():
return Response(status_code=200, content_type=content_types.TEXT_PLAIN, body="Hello World")


def lambda_handler(event, context):
return app.resolve(event, context)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

app = LambdaFunctionUrlResolver()

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.


@app.post("/todos")
def todos():
Expand All @@ -26,10 +29,5 @@ def todos():
)


@app.get("/hello")
def hello():
return Response(status_code=200, content_type=content_types.TEXT_PLAIN, body="Hello World")


def lambda_handler(event, context):
return app.resolve(event, context)
8 changes: 0 additions & 8 deletions tests/e2e/event_handler/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ def _create_api_gateway_http(self, function: Function):
integration=apigwv2integrations.HttpLambdaIntegration("TodosIntegration", function),
)

apigw.add_routes(
path="/hello",
methods=[apigwv2.HttpMethod.GET],
integration=apigwv2integrations.HttpLambdaIntegration("HelloIntegration", function),
)
CfnOutput(self.stack, "APIGatewayHTTPUrl", value=(apigw.url or ""))

def _create_api_gateway_rest(self, function: Function):
Expand All @@ -77,9 +72,6 @@ def _create_api_gateway_rest(self, function: Function):
todos = apigw.root.add_resource("todos")
todos.add_method("POST", apigwv1.LambdaIntegration(function, proxy=True))

hello = apigw.root.add_resource("hello")
hello.add_method("GET", apigwv1.LambdaIntegration(function, proxy=True))

CfnOutput(self.stack, "APIGatewayRestUrl", value=apigw.url)

def _create_lambda_function_url(self, function: Function):
Expand Down
38 changes: 21 additions & 17 deletions tests/e2e/event_handler/test_paths_ending_with_slash.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,62 +34,66 @@ def lambda_function_url_endpoint(infrastructure: dict) -> str:


def test_api_gateway_rest_trailing_slash(apigw_rest_endpoint):
# GIVEN
url = f"{apigw_rest_endpoint}hello/"
# GIVEN API URL ends in a trailing slash
url = f"{apigw_rest_endpoint}todos/"
body = "Hello World"
status_code = 200

# WHEN
response = data_fetcher.get_http_response(
Request(
method="GET",
method="POST",
url=url,
json={"body": body},
)
)

# THEN
assert response.status_code == status_code
# response.content is a binary string, needs to be decoded to compare with the real string
assert response.content.decode("ascii") == body
# THEN expect a HTTP 200 response
assert response.status_code == 200


def test_api_gateway_http_trailing_slash(apigw_http_endpoint):
# GIVEN the URL for the API ends in a trailing slash API gateway should return a 404
url = f"{apigw_http_endpoint}hello/"
url = f"{apigw_http_endpoint}todos/"
body = "Hello World"

# WHEN
# WHEN calling an invalid URL (with trailing slash) expect HTTPError exception from data_fetcher
with pytest.raises(HTTPError):
data_fetcher.get_http_response(
Request(
method="GET",
method="POST",
url=url,
json={"body": body},
)
)


def test_lambda_function_url_trailing_slash(lambda_function_url_endpoint):
# GIVEN the URL for the API ends in a trailing slash it should behave as if there was not one
url = f"{lambda_function_url_endpoint}hello/" # the function url endpoint already has the trailing /
url = f"{lambda_function_url_endpoint}todos/" # the function url endpoint already has the trailing /
body = "Hello World"

# WHEN
# WHEN calling an invalid URL (with trailing slash) expect HTTPError exception from data_fetcher
with pytest.raises(HTTPError):
data_fetcher.get_http_response(
Request(
method="GET",
method="POST",
url=url,
json={"body": body},
)
)


def test_alb_url_trailing_slash(alb_multi_value_header_listener_endpoint):
# GIVEN url has a trailing slash - it should behave as if there was not one
url = f"{alb_multi_value_header_listener_endpoint}/hello/"
url = f"{alb_multi_value_header_listener_endpoint}/todos/"
body = "Hello World"

# WHEN
# WHEN calling an invalid URL (with trailing slash) expect HTTPError exception from data_fetcher
with pytest.raises(HTTPError):
data_fetcher.get_http_response(
Request(
method="GET",
method="POST",
url=url,
json={"body": body},
)
)