Skip to content

Commit 4652237

Browse files
committed
feat(batch): add Kinesis Data streams support
1 parent 5ab2ec7 commit 4652237

File tree

4 files changed

+120
-2
lines changed

4 files changed

+120
-2
lines changed

aws_lambda_powertools/utilities/batch/base.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
class EventType(Enum):
1818
SQS = "SQS"
19-
KinesisDataStream = "KinesisDataStream"
19+
KinesisDataStreams = "KinesisDataStreams"
2020
DynamoDB = "DynamoDB"
2121

2222

@@ -169,7 +169,11 @@ def __init__(self, event_type: EventType):
169169
# refactor: Bring boto3 etc. for deleting permanent exceptions
170170
self.event_type = event_type
171171
self.batch_response = self.DEFAULT_RESPONSE
172-
172+
self._COLLECTOR_MAPPING = {
173+
EventType.SQS: self._collect_sqs_failures,
174+
EventType.KinesisDataStreams: self._collect_kinesis_failures,
175+
EventType.DynamoDB: self._collect_dynamodb_failures,
176+
}
173177
super().__init__()
174178

175179
def response(self):
@@ -222,4 +226,13 @@ def _get_messages_to_report(self) -> Dict[str, str]:
222226
Format messages to use in batch deletion
223227
"""
224228
# Refactor: get message per event type
229+
return self._COLLECTOR_MAPPING[self.event_type]()
230+
231+
def _collect_sqs_failures(self):
    """Build the failure report for SQS records.

    Returns a dict mapping each failed message's ``receiptHandle`` to its
    ``messageId``, read from ``self.fail_messages``.
    """
    report = {}
    for failed in self.fail_messages:
        handle = failed["receiptHandle"]
        report[handle] = failed["messageId"]
    return report
233+
234+
def _collect_kinesis_failures(self):
    """Build the failure report for Kinesis Data Streams records.

    Returns a dict mapping each failed record's ``eventID`` to the
    ``sequenceNumber`` found under its ``kinesis`` key, read from
    ``self.fail_messages``.
    """
    report = {}
    for failed in self.fail_messages:
        event_id = failed["eventID"]
        report[event_id] = failed["kinesis"]["sequenceNumber"]
    return report
236+
237+
def _collect_dynamodb_failures(self):
    """Placeholder collector for DynamoDB records.

    No collection logic exists yet, so this returns ``None``.
    """
    # TODO: implement DynamoDB failure collection.
    return None

tests/functional/test_utilities_batch.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from random import randint
12
from typing import Callable
23
from unittest.mock import patch
34

@@ -8,6 +9,7 @@
89
from aws_lambda_powertools.utilities.batch import PartialSQSProcessor, batch_processor, sqs_batch_processor
910
from aws_lambda_powertools.utilities.batch.base import BatchProcessor, EventType
1011
from aws_lambda_powertools.utilities.batch.exceptions import SQSBatchProcessingError
12+
from tests.functional.utils import decode_kinesis_data, str_to_b64
1113

1214

1315
@pytest.fixture(scope="module")
@@ -28,6 +30,31 @@ def factory(body: str):
2830
return factory
2931

3032

33+
@pytest.fixture(scope="module")
def kinesis_event_factory() -> Callable:
    """Provide a builder for synthetic Kinesis Data Streams records.

    The returned callable takes the record payload as a string and produces a
    record dict whose ``kinesis.data`` field is that payload passed through
    ``str_to_b64``. Every call draws a fresh 52-digit sequence number and a
    single-digit (1-9) partition key; the ``eventID`` embeds the sequence
    number.
    """

    def factory(body: str):
        # Sequence digits are drawn before the partition key, so the order of
        # random draws is fixed from call to call.
        digits = [str(randint(0, 9)) for _ in range(52)]
        sequence_number = "".join(digits)
        shard_key = str(randint(1, 9))

        kinesis_section = {
            "kinesisSchemaVersion": "1.0",
            "partitionKey": shard_key,
            "sequenceNumber": sequence_number,
            "data": str_to_b64(body),
            "approximateArrivalTimestamp": 1545084650.987,
        }
        return {
            "kinesis": kinesis_section,
            "eventSource": "aws:kinesis",
            "eventVersion": "1.0",
            "eventID": f"shardId-000000000006:{sequence_number}",
            "eventName": "aws:kinesis:record",
            "invokeIdentityArn": "arn:aws:iam::123456789012:role/lambda-role",
            "awsRegion": "us-east-2",
            "eventSourceARN": "arn:aws:kinesis:us-east-2:123456789012:stream/lambda-stream",
        }

    return factory
56+
57+
3158
@pytest.fixture(scope="module")
3259
def record_handler() -> Callable:
3360
def handler(record):
@@ -39,6 +66,17 @@ def handler(record):
3966
return handler
4067

4168

69+
@pytest.fixture(scope="module")
def kinesis_record_handler() -> Callable:
    """Provide a record handler for Kinesis records used by the batch tests.

    The handler decodes the record with ``decode_kinesis_data`` and returns
    the decoded text, unless that text contains ``"fail"``, in which case it
    raises ``Exception``.
    """

    def handler(record):
        payload = decode_kinesis_data(record)
        if "fail" not in payload:
            return payload
        raise Exception("Failed to process record.")

    return handler
78+
79+
4280
@pytest.fixture(scope="module")
def config() -> Config:
    """Provide a ``Config`` built with ``region_name="us-east-1"``."""
    settings = {"region_name": "us-east-1"}
    return Config(**settings)
@@ -366,3 +404,61 @@ def test_batch_processor_context_with_failure(sqs_event_factory, record_handler)
366404
assert processed_messages[1] == ("success", second_record["body"], second_record)
367405
assert len(batch.fail_messages) == 1
368406
assert batch.response() == {"batchItemFailures": [{first_record["receiptHandle"]: first_record["messageId"]}]}
407+
408+
409+
def test_batch_processor_kinesis_context_success_only(kinesis_event_factory, kinesis_record_handler):
    """Two records that both process cleanly must yield no batch item failures."""
    # GIVEN two Kinesis records carrying a "success" payload
    records = [kinesis_event_factory("success") for _ in range(2)]
    processor = BatchProcessor(event_type=EventType.KinesisDataStreams)

    # WHEN the batch is processed through the context manager
    with processor(records, kinesis_record_handler) as batch:
        processed_messages = batch.process()

    # THEN every record is reported as a success, in input order
    expected = [("success", decode_kinesis_data(record), record) for record in records]
    assert processed_messages == expected

    assert batch.response() == {"batchItemFailures": []}
427+
428+
429+
def test_batch_processor_kinesis_context_with_failure(kinesis_event_factory, kinesis_record_handler):
    """A failing record must be reported by eventID and sequence number."""
    # GIVEN one record the handler rejects followed by one it accepts
    failing_record = kinesis_event_factory("failure")
    passing_record = kinesis_event_factory("success")
    processor = BatchProcessor(event_type=EventType.KinesisDataStreams)

    # WHEN the batch is processed through the context manager
    with processor([failing_record, passing_record], kinesis_record_handler) as batch:
        processed_messages = batch.process()

    # THEN the second record succeeds and only the first one is reported
    expected_success = ("success", decode_kinesis_data(passing_record), passing_record)
    assert processed_messages[1] == expected_success
    assert len(batch.fail_messages) == 1

    expected_failure = {failing_record["eventID"]: failing_record["kinesis"]["sequenceNumber"]}
    assert batch.response() == {"batchItemFailures": [expected_failure]}
446+
447+
448+
def test_batch_processor_kinesis_middleware_with_failure(kinesis_event_factory, kinesis_record_handler):
    """The decorator form must surface exactly one failure for a mixed batch."""
    # GIVEN an event holding one failing and one succeeding record, in that order
    payloads = ("failure", "success")
    event = {"Records": [kinesis_event_factory(body) for body in payloads]}

    processor = BatchProcessor(event_type=EventType.KinesisDataStreams)

    @batch_processor(record_handler=kinesis_record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN the decorated handler is invoked
    result = lambda_handler(event, {})

    # THEN a single item failure is returned
    failures = result["batchItemFailures"]
    assert len(failures) == 1

tests/functional/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import json
23
from pathlib import Path
34
from typing import Any
@@ -6,3 +7,11 @@
67
def load_event(file_name: str) -> Any:
    """Load and parse a JSON event fixture.

    The file is looked up inside the ``events`` directory located two levels
    above this module (``Path(__file__).parent.parent / "events"``).

    Args:
        file_name: Name of the fixture file, relative to the ``events``
            directory.

    Returns:
        The decoded JSON document.
    """
    events_dir = Path(__file__).parent.parent / "events"
    # Read as UTF-8 explicitly: JSON fixtures should not be decoded with
    # whatever the platform's locale encoding happens to be.
    return json.loads((events_dir / file_name).read_text(encoding="utf-8"))
10+
11+
12+
def str_to_b64(data: str) -> str:
    """Return the base64 text representation of ``data``.

    The input string is encoded to bytes, base64-encoded, and the resulting
    bytes are decoded back to ``str`` as UTF-8.
    """
    raw_bytes = data.encode()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
14+
15+
16+
def decode_kinesis_data(data: dict) -> str:
    """Return the text stored base64-encoded under ``data["kinesis"]["data"]``.

    The base64 string is decoded to bytes and those bytes are decoded as
    UTF-8.
    """
    encoded_text = data["kinesis"]["data"]
    raw_bytes = base64.b64decode(encoded_text.encode())
    return raw_bytes.decode("utf-8")

tests/utils.py

Whitespace-only changes.

0 commit comments

Comments
 (0)