@@ -91,28 +91,46 @@ def __init__(
91
91
if content_type :
92
92
self .headers .setdefault ("Content-Type" , content_type )
93
93
94
- def add_cors (self , cors : CORSConfig ):
95
- self .headers .update (cors .to_dict ())
96
94
97
- def add_cache_control (self , cache_control : str ):
98
- self .headers ["Cache-Control" ] = cache_control if self .status_code == 200 else "no-cache"
95
+ class ResponseBuilder :
96
+ def __init__ (self , response : Response , route : Route = None ):
97
+ self .response = response
98
+ self .route = route
99
99
100
- def compress (self ):
101
- self .headers ["Content-Encoding" ] = "gzip"
102
- if isinstance (self .body , str ):
103
- self .body = bytes (self .body , "utf-8" )
104
- gzip = zlib .compressobj (9 , zlib .DEFLATED , zlib .MAX_WBITS | 16 )
105
- self .body = gzip .compress (self .body ) + gzip .flush ()
100
+ def _add_cors (self , cors : CORSConfig ):
101
+ self .response .headers .update (cors .to_dict ())
102
+
103
+ def _add_cache_control (self , cache_control : str ):
104
+ self .response .headers ["Cache-Control" ] = cache_control if self .response .status_code == 200 else "no-cache"
106
105
107
- def to_dict (self ) -> Dict [str , Any ]:
108
- if isinstance (self .body , bytes ):
109
- self .base64_encoded = True
110
- self .body = base64 .b64encode (self .body ).decode ()
106
+ def _compress (self ):
107
+ self .response .headers ["Content-Encoding" ] = "gzip"
108
+ if isinstance (self .response .body , str ):
109
+ self .response .body = bytes (self .response .body , "utf-8" )
110
+ gzip = zlib .compressobj (9 , zlib .DEFLATED , zlib .MAX_WBITS | 16 )
111
+ self .response .body = gzip .compress (self .response .body ) + gzip .flush ()
112
+
113
+ def _route (self , event : BaseProxyEvent , cors : CORSConfig = None ):
114
+ if self .route is None :
115
+ return
116
+ if self .route .cors :
117
+ self ._add_cors (cors or CORSConfig ())
118
+ if self .route .cache_control :
119
+ self ._add_cache_control (self .route .cache_control )
120
+ if self .route .compress and "gzip" in (event .get_header_value ("accept-encoding" , "" ) or "" ):
121
+ self ._compress ()
122
+
123
+ def build (self , event : BaseProxyEvent , cors : CORSConfig = None ) -> Dict [str , Any ]:
124
+ self ._route (event , cors )
125
+
126
+ if isinstance (self .response .body , bytes ):
127
+ self .response .base64_encoded = True
128
+ self .response .body = base64 .b64encode (self .response .body ).decode ()
111
129
return {
112
- "statusCode" : self .status_code ,
113
- "headers" : self .headers ,
114
- "body" : self .body ,
115
- "isBase64Encoded" : self .base64_encoded ,
130
+ "statusCode" : self .response . status_code ,
131
+ "headers" : self .response . headers ,
132
+ "body" : self .response . body ,
133
+ "isBase64Encoded" : self .response . base64_encoded ,
116
134
}
117
135
118
136
@@ -153,18 +171,7 @@ def register_resolver(func: Callable):
153
171
def resolve (self , event , context ) -> Dict [str , Any ]:
154
172
self .current_event = self ._to_data_class (event )
155
173
self .lambda_context = context
156
- route , response = self ._find_route (self .current_event .http_method .upper (), self .current_event .path )
157
- if route is None : # No matching route was found
158
- return response .to_dict ()
159
-
160
- if route .cors :
161
- response .add_cors (self ._cors or CORSConfig ())
162
- if route .cache_control :
163
- response .add_cache_control (route .cache_control )
164
- if route .compress and "gzip" in (self .current_event .get_header_value ("accept-encoding" ) or "" ):
165
- response .compress ()
166
-
167
- return response .to_dict ()
174
+ return self ._resolve_response ().build (self .current_event , self ._cors )
168
175
169
176
@staticmethod
170
177
def _compile_regex (rule : str ):
@@ -178,30 +185,36 @@ def _to_data_class(self, event: Dict) -> BaseProxyEvent:
178
185
return APIGatewayProxyEventV2 (event )
179
186
return ALBEvent (event )
180
187
181
- def _find_route (self , method : str , path : str ) -> Tuple [Optional [Route ], Response ]:
188
+ def _resolve_response (self ) -> ResponseBuilder :
189
+ method = self .current_event .http_method .upper ()
190
+ path = self .current_event .path
182
191
for route in self ._routes :
183
192
if method != route .method :
184
193
continue
185
194
match : Optional [re .Match ] = route .rule .match (path )
186
195
if match :
187
196
return self ._call_route (route , match .groupdict ())
188
197
198
+ return self .not_found (method , path )
199
+
200
+ def not_found (self , method : str , path : str ) -> ResponseBuilder :
189
201
headers = {}
190
202
if self ._cors :
191
203
headers .update (self ._cors .to_dict ())
192
204
if method == "OPTIONS" : # Preflight
193
205
headers ["Access-Control-Allow-Methods" ] = "," .join (sorted (self ._cors_methods ))
194
- return None , Response (status_code = 204 , content_type = None , body = None , headers = headers )
195
-
196
- return None , Response (
197
- status_code = 404 ,
198
- content_type = "application/json" ,
199
- body = json .dumps ({"message" : f"No route found for '{ method } .{ path } '" }),
200
- headers = headers ,
206
+ return ResponseBuilder (Response (status_code = 204 , content_type = None , body = None , headers = headers ))
207
+ return ResponseBuilder (
208
+ Response (
209
+ status_code = 404 ,
210
+ content_type = "application/json" ,
211
+ body = json .dumps ({"message" : f"No route found for '{ method } .{ path } '" }),
212
+ headers = headers ,
213
+ )
201
214
)
202
215
203
- def _call_route (self , route : Route , args : Dict [str , str ]) -> Tuple [ Route , Response ] :
204
- return route , self ._to_response (route .func (** args ))
216
+ def _call_route (self , route : Route , args : Dict [str , str ]) -> ResponseBuilder :
217
+ return ResponseBuilder ( self ._to_response (route .func (** args )), route )
205
218
206
219
@staticmethod
207
220
def _to_response (result : Union [Tuple [int , str , Union [bytes , str ]], Dict , Response ]) -> Response :
0 commit comments