|
| 1 | +import requests |
| 2 | + |
| 3 | +from aws_lambda_powertools import Logger |
| 4 | +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response |
| 5 | +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware |
| 6 | + |
| 7 | +app = APIGatewayRestResolver() |
| 8 | +logger = Logger() |
| 9 | + |
| 10 | + |
| 11 | +class CorrelationIdMiddleware(BaseMiddlewareHandler): |
| 12 | + def __init__(self, header: str): # (1)! |
| 13 | + """Extract and inject correlation ID in response |
| 14 | +
|
| 15 | + Parameters |
| 16 | + ---------- |
| 17 | + header : str |
| 18 | + HTTP Header to extract correlation ID |
| 19 | + """ |
| 20 | + super().__init__() |
| 21 | + self.header = header |
| 22 | + |
| 23 | + def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: # (2)! |
| 24 | + request_id = app.current_event.request_context.request_id |
| 25 | + correlation_id = app.current_event.get_header_value( |
| 26 | + name=self.header, |
| 27 | + default_value=request_id, |
| 28 | + ) |
| 29 | + |
| 30 | + response = next_middleware(app) # (3)! |
| 31 | + response.headers[self.header] = correlation_id |
| 32 | + |
| 33 | + return response |
| 34 | + |
| 35 | + |
| 36 | +@app.get("/todos", middlewares=[CorrelationIdMiddleware(header="x-correlation-id")]) # (4)! |
| 37 | +def get_todos(): |
| 38 | + todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos") |
| 39 | + todos.raise_for_status() |
| 40 | + |
| 41 | + # for brevity, we'll limit to the first 10 only |
| 42 | + return {"todos": todos.json()[:10]} |
| 43 | + |
| 44 | + |
| 45 | +@logger.inject_lambda_context |
| 46 | +def lambda_handler(event, context): |
| 47 | + return app.resolve(event, context) |
0 commit comments