Skip to content

feat(event_handler): support to enable or disable compression in custom responses #2544

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def __init__(
body: Union[str, bytes, None] = None,
headers: Optional[Dict[str, Union[str, List[str]]]] = None,
cookies: Optional[List[Cookie]] = None,
compress: Optional[bool] = None,
):
"""

Expand All @@ -199,6 +200,7 @@ def __init__(
self.base64_encoded = False
self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {}
self.cookies = cookies or []
self.compress = compress
if content_type:
self.headers.setdefault("Content-Type", content_type)

Expand Down Expand Up @@ -233,6 +235,38 @@ def _add_cache_control(self, cache_control: str):
cache_control = cache_control if self.response.status_code == 200 else "no-cache"
self.response.headers["Cache-Control"] = cache_control

@staticmethod
def _has_compression_enabled(
route_compression: bool, response_compression: Optional[bool], event: BaseProxyEvent
) -> bool:
"""
Checks if compression is enabled.

NOTE: Response compression takes precedence.

Parameters
----------
route_compression: bool, optional
A boolean indicating whether compression is enabled or not in the route setting.
response_compression: bool, optional
A boolean indicating whether compression is enabled or not in the response setting.
event: BaseProxyEvent
The event object containing the request details.

Returns
-------
bool
True if compression is enabled and the "gzip" encoding is accepted, False otherwise.
"""
encoding: str = event.get_header_value(name="accept-encoding", default_value="", case_sensitive=False) # type: ignore[assignment] # noqa: E501
if "gzip" in encoding:
if response_compression is not None:
return response_compression # e.g., Response(compress=False/True))
if route_compression:
return True # e.g., @app.get(compress=True)

return False

def _compress(self):
"""Compress the response body, but only if `Accept-Encoding` headers includes gzip."""
self.response.headers["Content-Encoding"] = "gzip"
Expand All @@ -250,7 +284,9 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
self._add_cors(event, cors or CORSConfig())
if self.route.cache_control:
self._add_cache_control(self.route.cache_control)
if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""):
if self._has_compression_enabled(
route_compression=self.route.compress, response_compression=self.response.compress, event=event
):
self._compress()

def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
Expand Down
1 change: 1 addition & 0 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
query_string_parameters=self.query_string_parameters, name=name, default_value=default_value
)

# Maintenance: missing @overload to ensure return type is a str when default_value is set
def get_header_value(
self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False
) -> Optional[str]:
Expand Down
15 changes: 12 additions & 3 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,24 @@ You can use the `Response` class to have full control over the response. For exa

### Compress

You can compress with gzip and base64 encode your responses via `compress` parameter.
You can compress with gzip and base64 encode your responses via `compress` parameter. You have the option to pass the `compress` parameter when working with a specific route or using the Response object.

???+ info
The `compress` parameter used in the Response object takes precedence over the one used in the route.

???+ warning
The client must send the `Accept-Encoding` header, otherwise a normal response will be sent.

=== "compressing_responses.py"
=== "compressing_responses_using_route.py"

```python hl_lines="17 27"
--8<-- "examples/event_handler_rest/src/compressing_responses.py"
--8<-- "examples/event_handler_rest/src/compressing_responses_using_route.py"
```

=== "compressing_responses_using_response.py"

```python hl_lines="24"
--8<-- "examples/event_handler_rest/src/compressing_responses_using_response.py"
```

=== "compressing_responses.json"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import requests

from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.event_handler import (
APIGatewayRestResolver,
Response,
content_types,
)
from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.utilities.typing import LambdaContext

tracer = Tracer()
logger = Logger()
app = APIGatewayRestResolver()


@app.get("/todos")
@tracer.capture_method
def get_todos():
todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos")
todos.raise_for_status()

# for brevity, we'll limit to the first 10 only
return Response(status_code=200, content_type=content_types.APPLICATION_JSON, body=todos.json()[:10], compress=True)


# You can continue to use other utilities just as before
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST)
@tracer.capture_lambda_handler
def lambda_handler(event: dict, context: LambdaContext) -> dict:
return app.resolve(event, context)
52 changes: 52 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,58 @@ def test_cors_preflight_body_is_empty_not_null():
assert result["body"] == ""


def test_override_route_compress_parameter():
# GIVEN a function that has compress=True
# AND an event with a "Accept-Encoding" that include gzip
# AND the Response object with compress=False
app = ApiGatewayResolver()
mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}}
expected_value = '{"test": "value"}'

@app.get("/my/request", compress=True)
def with_compression() -> Response:
return Response(200, content_types.APPLICATION_JSON, expected_value, compress=False)

def handler(event, context):
return app.resolve(event, context)

# WHEN calling the event handler
result = handler(mock_event, None)

# THEN then the response is not compressed
assert result["isBase64Encoded"] is False
assert result["body"] == expected_value
assert result["multiValueHeaders"].get("Content-Encoding") is None


def test_response_with_compress_enabled():
# GIVEN a function
# AND an event with a "Accept-Encoding" that include gzip
# AND the Response object with compress=True
app = ApiGatewayResolver()
mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}}
expected_value = '{"test": "value"}'

@app.get("/my/request")
def route_without_compression() -> Response:
return Response(200, content_types.APPLICATION_JSON, expected_value, compress=True)

def handler(event, context):
return app.resolve(event, context)

# WHEN calling the event handler
result = handler(mock_event, None)

# THEN then gzip the response and base64 encode as a string
assert result["isBase64Encoded"] is True
body = result["body"]
assert isinstance(body, str)
decompress = zlib.decompress(base64.b64decode(body), wbits=zlib.MAX_WBITS | 16).decode("UTF-8")
assert decompress == expected_value
headers = result["multiValueHeaders"]
assert headers["Content-Encoding"] == ["gzip"]


def test_compress():
# GIVEN a function that has compress=True
# AND an event with a "Accept-Encoding" that include gzip
Expand Down