Skip to content

Commit 8494cde

Browse files
author
Michael Brewer
committed
refactor(event-handler): Add ResponseBuilder
Change: - Add new ResponseBuilder to replace the Tuple[Route, Response] - Refactor code to be easier to maintain
1 parent e24a985 commit 8494cde

File tree

2 files changed

+68
-47
lines changed

2 files changed

+68
-47
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 53 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -91,28 +91,46 @@ def __init__(
9191
if content_type:
9292
self.headers.setdefault("Content-Type", content_type)
9393

94-
def add_cors(self, cors: CORSConfig):
95-
self.headers.update(cors.to_dict())
9694

97-
def add_cache_control(self, cache_control: str):
98-
self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache"
95+
class ResponseBuilder:
96+
def __init__(self, response: Response, route: Route = None):
97+
self.response = response
98+
self.route = route
9999

100-
def compress(self):
101-
self.headers["Content-Encoding"] = "gzip"
102-
if isinstance(self.body, str):
103-
self.body = bytes(self.body, "utf-8")
104-
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
105-
self.body = gzip.compress(self.body) + gzip.flush()
100+
def _add_cors(self, cors: CORSConfig):
101+
self.response.headers.update(cors.to_dict())
102+
103+
def _add_cache_control(self, cache_control: str):
104+
self.response.headers["Cache-Control"] = cache_control if self.response.status_code == 200 else "no-cache"
106105

107-
def to_dict(self) -> Dict[str, Any]:
108-
if isinstance(self.body, bytes):
109-
self.base64_encoded = True
110-
self.body = base64.b64encode(self.body).decode()
106+
def _compress(self):
107+
self.response.headers["Content-Encoding"] = "gzip"
108+
if isinstance(self.response.body, str):
109+
self.response.body = bytes(self.response.body, "utf-8")
110+
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
111+
self.response.body = gzip.compress(self.response.body) + gzip.flush()
112+
113+
def _route(self, event: BaseProxyEvent, cors: CORSConfig = None):
114+
if self.route is None:
115+
return
116+
if self.route.cors:
117+
self._add_cors(cors or CORSConfig())
118+
if self.route.cache_control:
119+
self._add_cache_control(self.route.cache_control)
120+
if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""):
121+
self._compress()
122+
123+
def build(self, event: BaseProxyEvent, cors: CORSConfig = None) -> Dict[str, Any]:
124+
self._route(event, cors)
125+
126+
if isinstance(self.response.body, bytes):
127+
self.response.base64_encoded = True
128+
self.response.body = base64.b64encode(self.response.body).decode()
111129
return {
112-
"statusCode": self.status_code,
113-
"headers": self.headers,
114-
"body": self.body,
115-
"isBase64Encoded": self.base64_encoded,
130+
"statusCode": self.response.status_code,
131+
"headers": self.response.headers,
132+
"body": self.response.body,
133+
"isBase64Encoded": self.response.base64_encoded,
116134
}
117135

118136

@@ -153,18 +171,7 @@ def register_resolver(func: Callable):
153171
def resolve(self, event, context) -> Dict[str, Any]:
154172
self.current_event = self._to_data_class(event)
155173
self.lambda_context = context
156-
route, response = self._find_route(self.current_event.http_method.upper(), self.current_event.path)
157-
if route is None: # No matching route was found
158-
return response.to_dict()
159-
160-
if route.cors:
161-
response.add_cors(self._cors or CORSConfig())
162-
if route.cache_control:
163-
response.add_cache_control(route.cache_control)
164-
if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""):
165-
response.compress()
166-
167-
return response.to_dict()
174+
return self._resolve_response().build(self.current_event, self._cors)
168175

169176
@staticmethod
170177
def _compile_regex(rule: str):
@@ -178,30 +185,36 @@ def _to_data_class(self, event: Dict) -> BaseProxyEvent:
178185
return APIGatewayProxyEventV2(event)
179186
return ALBEvent(event)
180187

181-
def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response]:
188+
def _resolve_response(self) -> ResponseBuilder:
189+
method = self.current_event.http_method.upper()
190+
path = self.current_event.path
182191
for route in self._routes:
183192
if method != route.method:
184193
continue
185194
match: Optional[re.Match] = route.rule.match(path)
186195
if match:
187196
return self._call_route(route, match.groupdict())
188197

198+
return self.not_found(method, path)
199+
200+
def not_found(self, method: str, path: str) -> ResponseBuilder:
189201
headers = {}
190202
if self._cors:
191203
headers.update(self._cors.to_dict())
192204
if method == "OPTIONS": # Preflight
193205
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
194-
return None, Response(status_code=204, content_type=None, body=None, headers=headers)
195-
196-
return None, Response(
197-
status_code=404,
198-
content_type="application/json",
199-
body=json.dumps({"message": f"No route found for '{method}.{path}'"}),
200-
headers=headers,
206+
return ResponseBuilder(Response(status_code=204, content_type=None, body=None, headers=headers))
207+
return ResponseBuilder(
208+
Response(
209+
status_code=404,
210+
content_type="application/json",
211+
body=json.dumps({"message": f"No route found for '{method}.{path}'"}),
212+
headers=headers,
213+
)
201214
)
202215

203-
def _call_route(self, route: Route, args: Dict[str, str]) -> Tuple[Route, Response]:
204-
return route, self._to_response(route.func(**args))
216+
def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
217+
return ResponseBuilder(self._to_response(route.func(**args)), route)
205218

206219
@staticmethod
207220
def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response:

tests/functional/event_handler/test_api_gateway.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55
from pathlib import Path
66
from typing import Dict, Tuple
77

8-
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, CORSConfig, ProxyEventType, Response
8+
from aws_lambda_powertools.event_handler.api_gateway import (
9+
ApiGatewayResolver,
10+
CORSConfig,
11+
ProxyEventType,
12+
Response,
13+
ResponseBuilder,
14+
)
915
from aws_lambda_powertools.shared.json_encoder import Encoder
1016
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
1117
from tests.functional.utils import load_event
@@ -106,14 +112,14 @@ def test_include_rule_matching():
106112
@app.get("/<name>/<my_id>")
107113
def get_lambda(my_id: str, name: str) -> Tuple[int, str, str]:
108114
assert name == "my"
109-
return 200, "plain/html", my_id
115+
return 200, TEXT_HTML, my_id
110116

111117
# WHEN calling the event handler
112118
result = app(LOAD_GW_EVENT, {})
113119

114120
# THEN
115121
assert result["statusCode"] == 200
116-
assert result["headers"]["Content-Type"] == "plain/html"
122+
assert result["headers"]["Content-Type"] == TEXT_HTML
117123
assert result["body"] == "path"
118124

119125

@@ -389,14 +395,16 @@ def another_one():
389395
def test_no_content_response():
390396
# GIVEN a response with no content-type or body
391397
response = Response(status_code=204, content_type=None, body=None, headers=None)
398+
response_builder = ResponseBuilder(response)
392399

393400
# WHEN calling to_dict
394-
result = response.to_dict()
401+
result = response_builder.build(APIGatewayProxyEvent(LOAD_GW_EVENT))
395402

396403
# THEN return an None body and no Content-Type header
404+
assert result["statusCode"] == response.status_code
397405
assert result["body"] is None
398-
assert result["statusCode"] == 204
399-
assert "Content-Type" not in result["headers"]
406+
headers = result["headers"]
407+
assert "Content-Type" not in headers
400408

401409

402410
def test_no_matches_with_cors():
@@ -413,7 +421,7 @@ def test_no_matches_with_cors():
413421
assert "Access-Control-Allow-Origin" in result["headers"]
414422

415423

416-
def test_preflight():
424+
def test_cors_preflight():
417425
# GIVEN an event for an OPTIONS call that does not match any of the given routes
418426
# AND cors is enabled
419427
app = ApiGatewayResolver(cors=CORSConfig())

0 commit comments

Comments
 (0)