Skip to content

Commit f6a5b2a

Browse files
author
Michael Brewer
authored
feat(logger): add get_correlation_id method (#516)
1 parent 89e8151 commit f6a5b2a

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

aws_lambda_powertools/logging/logger.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,16 +387,26 @@ def structure_logs(self, append: bool = False, **keys):
387387
formatter = self.logger_formatter or LambdaPowertoolsFormatter(**log_keys) # type: ignore
388388
self.registered_handler.setFormatter(formatter)
389389

390-
def set_correlation_id(self, value: str):
390+
def set_correlation_id(self, value: Optional[str]):
391391
"""Sets the correlation_id in the logging json
392392
393393
Parameters
394394
----------
395-
value : str
396-
Value for the correlation id
395+
value : str, optional
396+
Value for the correlation id. None will remove the correlation_id
397397
"""
398398
self.append_keys(correlation_id=value)
399399

400+
def get_correlation_id(self) -> Optional[str]:
401+
"""Gets the correlation_id in the logging json
402+
403+
Returns
404+
-------
405+
str, optional
406+
Value for the correlation id
407+
"""
408+
return self.registered_formatter.log_format.get("correlation_id")
409+
400410
@staticmethod
401411
def _get_log_level(level: Union[str, int, None]) -> Union[str, int]:
402412
"""Returns preferred log level set by the customer in upper case"""

tests/functional/test_logger.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,18 @@ def handler(event, _):
460460
assert request_id == log["correlation_id"]
461461

462462

463+
def test_logger_get_correlation_id(lambda_context, stdout, service_name):
464+
# GIVEN a logger with a correlation_id set
465+
logger = Logger(service=service_name, stream=stdout)
466+
logger.set_correlation_id("foo")
467+
468+
# WHEN calling get_correlation_id
469+
correlation_id = logger.get_correlation_id()
470+
471+
# THEN it should return the correlation_id
472+
assert "foo" == correlation_id
473+
474+
463475
def test_logger_set_correlation_id_path(lambda_context, stdout, service_name):
464476
# GIVEN
465477
logger = Logger(service=service_name, stream=stdout)

0 commit comments

Comments
 (0)