diff --git a/aws_lambda_powertools/utilities/parser/models/kinesis.py b/aws_lambda_powertools/utilities/parser/models/kinesis.py
index ffc89bcbdaa..6fb9a7076b5 100644
--- a/aws_lambda_powertools/utilities/parser/models/kinesis.py
+++ b/aws_lambda_powertools/utilities/parser/models/kinesis.py
@@ -1,8 +1,13 @@
-from typing import List, Type, Union
+import json
+import zlib
+from typing import Dict, List, Type, Union
 
 from pydantic import BaseModel, validator
 
 from aws_lambda_powertools.shared.functions import base64_decode
+from aws_lambda_powertools.utilities.parser.models.cloudwatch import (
+    CloudWatchLogsDecode,
+)
 from aws_lambda_powertools.utilities.parser.types import Literal
 
 
@@ -28,6 +33,21 @@ class KinesisDataStreamRecord(BaseModel):
     eventSourceARN: str
     kinesis: KinesisDataStreamRecordPayload
 
+    def decompress_zlib_record_data_as_json(self) -> Dict:
+        """Decompress Kinesis Record bytes data zlib compressed to JSON"""
+        if not isinstance(self.kinesis.data, bytes):
+            raise ValueError("We can only decompress bytes data, not custom models.")
+
+        return json.loads(zlib.decompress(self.kinesis.data, zlib.MAX_WBITS | 32))
+
 
 class KinesisDataStreamModel(BaseModel):
     Records: List[KinesisDataStreamRecord]
+
+
+def extract_cloudwatch_logs_from_event(event: KinesisDataStreamModel) -> List[CloudWatchLogsDecode]:
+    return [CloudWatchLogsDecode(**record.decompress_zlib_record_data_as_json()) for record in event.Records]
+
+
+def extract_cloudwatch_logs_from_record(record: KinesisDataStreamRecord) -> CloudWatchLogsDecode:
+    return CloudWatchLogsDecode(**record.decompress_zlib_record_data_as_json())
diff --git a/tests/functional/parser/test_kinesis.py b/tests/functional/parser/test_kinesis.py
index 13f1e55b479..6b23bd214a6 100644
--- a/tests/functional/parser/test_kinesis.py
+++ b/tests/functional/parser/test_kinesis.py
@@ -3,6 +3,7 @@
 import pytest
 
 from aws_lambda_powertools.utilities.parser import (
+    BaseModel,
     ValidationError,
     envelopes,
     event_parser,
@@ -11,6 +12,13 @@
     KinesisDataStreamModel,
     KinesisDataStreamRecordPayload,
 )
+from aws_lambda_powertools.utilities.parser.models.cloudwatch import (
+    CloudWatchLogsDecode,
+)
+from aws_lambda_powertools.utilities.parser.models.kinesis import (
+    extract_cloudwatch_logs_from_event,
+    extract_cloudwatch_logs_from_record,
+)
 from aws_lambda_powertools.utilities.typing import LambdaContext
 from tests.functional.parser.schemas import MyKinesisBusiness
 from tests.functional.utils import load_event
@@ -111,3 +119,35 @@ def test_validate_event_does_not_conform_with_model():
     event_dict: Any = {"hello": "s"}
     with pytest.raises(ValidationError):
         handle_kinesis(event_dict, LambdaContext())
+
+
+def test_kinesis_stream_event_cloudwatch_logs_data_extraction():
+    # GIVEN a KinesisDataStreamModel is instantiated with CloudWatch Logs compressed data
+    event_dict = load_event("kinesisStreamCloudWatchLogsEvent.json")
+    stream_data = KinesisDataStreamModel(**event_dict)
+    single_record = stream_data.Records[0]
+
+    # WHEN we try to extract CloudWatch Logs from KinesisDataStreamRecordPayload model
+    extracted_logs = extract_cloudwatch_logs_from_event(stream_data)
+    individual_logs = [extract_cloudwatch_logs_from_record(record) for record in stream_data.Records]
+    single_log = extract_cloudwatch_logs_from_record(single_record)
+
+    # THEN we should have extracted any potential logs as CloudWatchLogsDecode models
+    assert len(extracted_logs) == len(individual_logs)
+    assert isinstance(single_log, CloudWatchLogsDecode)
+
+
+def test_kinesis_stream_event_cloudwatch_logs_data_extraction_fails_with_custom_model():
+    # GIVEN a custom model replaces Kinesis Record Data bytes
+    class DummyModel(BaseModel):
+        ...
+
+    event_dict = load_event("kinesisStreamCloudWatchLogsEvent.json")
+    stream_data = KinesisDataStreamModel(**event_dict)
+
+    # WHEN decompress_zlib_record_data_as_json is used
+    # THEN ValueError should be raised
+    with pytest.raises(ValueError, match="We can only decompress bytes data"):
+        for record in stream_data.Records:
+            record.kinesis.data = DummyModel()
+            record.decompress_zlib_record_data_as_json()