Skip to content

Commit 1822456

Browse files
committed
feat(batch): use event source data classes by default
1 parent 139f52b commit 1822456

File tree

4 files changed

+62
-41
lines changed

4 files changed

+62
-41
lines changed

aws_lambda_powertools/utilities/batch/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
Batch processing utility
55
"""
66

7-
from .base import BasePartialProcessor, batch_processor
8-
from .sqs import PartialSQSProcessor, sqs_batch_processor
7+
from aws_lambda_powertools.utilities.batch.base import BasePartialProcessor, BatchProcessor, EventType, batch_processor
8+
from aws_lambda_powertools.utilities.batch.sqs import PartialSQSProcessor, sqs_batch_processor
99

10-
__all__ = ("BasePartialProcessor", "PartialSQSProcessor", "batch_processor", "sqs_batch_processor")
10+
__all__ = (
11+
"BatchProcessor",
12+
"BasePartialProcessor",
13+
"EventType",
14+
"PartialSQSProcessor",
15+
"batch_processor",
16+
"sqs_batch_processor",
17+
)

aws_lambda_powertools/utilities/batch/base.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from typing import Any, Callable, Dict, List, Optional, Tuple
1111

1212
from aws_lambda_powertools.middleware_factory import lambda_handler_decorator
13+
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord
14+
from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord
15+
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
1316

1417
logger = logging.getLogger(__name__)
1518

@@ -174,6 +177,12 @@ def __init__(self, event_type: EventType):
174177
EventType.KinesisDataStreams: self._collect_kinesis_failures,
175178
EventType.DynamoDBStreams: self._collect_dynamodb_failures,
176179
}
180+
self._DATA_CLASS_MAPPING = {
181+
EventType.SQS: SQSRecord,
182+
EventType.KinesisDataStreams: KinesisStreamRecord,
183+
EventType.DynamoDBStreams: DynamoDBRecord,
184+
}
185+
177186
super().__init__()
178187

179188
def response(self):
@@ -198,7 +207,8 @@ def _process_record(self, record) -> Tuple:
198207
An object to be processed.
199208
"""
200209
try:
201-
result = self.handler(record=record)
210+
data = self._DATA_CLASS_MAPPING[self.event_type](record)
211+
result = self.handler(record=data)
202212
return self.success_handler(record=record, result=result)
203213
except Exception:
204214
return self.failure_handler(record=record, exception=sys.exc_info())
@@ -225,7 +235,6 @@ def _get_messages_to_report(self) -> Dict[str, str]:
225235
"""
226236
Format messages to use in batch deletion
227237
"""
228-
# Refactor: get message per event type
229238
return self._COLLECTOR_MAPPING[self.event_type]()
230239

231240
def _collect_sqs_failures(self):

tests/functional/test_utilities_batch.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from aws_lambda_powertools.utilities.batch import PartialSQSProcessor, batch_processor, sqs_batch_processor
1010
from aws_lambda_powertools.utilities.batch.base import BatchProcessor, EventType
1111
from aws_lambda_powertools.utilities.batch.exceptions import SQSBatchProcessingError
12-
from tests.functional.utils import decode_kinesis_data, str_to_b64
12+
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord
13+
from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord
14+
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
15+
from tests.functional.utils import b64_to_str, str_to_b64
1316

1417

1518
@pytest.fixture(scope="module")
@@ -90,8 +93,8 @@ def handler(record):
9093

9194
@pytest.fixture(scope="module")
9295
def kinesis_record_handler() -> Callable:
93-
def handler(record):
94-
body = decode_kinesis_data(record)
96+
def handler(record: KinesisStreamRecord):
97+
body = b64_to_str(record.kinesis.data)
9598
if "fail" in body:
9699
raise Exception("Failed to process record.")
97100
return body
@@ -101,8 +104,8 @@ def handler(record):
101104

102105
@pytest.fixture(scope="module")
103106
def dynamodb_record_handler() -> Callable:
104-
def handler(record):
105-
body = record["dynamodb"]["NewImage"]["message"]["S"]
107+
def handler(record: DynamoDBRecord):
108+
body = record.dynamodb.new_image.get("message").get_value
106109
if "fail" in body:
107110
raise Exception("Failed to process record.")
108111
return body
@@ -366,9 +369,9 @@ def test_partial_sqs_processor_context_only_failure(sqs_event_factory, record_ha
366369

367370
def test_batch_processor_middleware_success_only(sqs_event_factory, record_handler):
368371
# GIVEN
369-
first_record = sqs_event_factory("success")
370-
second_record = sqs_event_factory("success")
371-
event = {"Records": [first_record, second_record]}
372+
first_record = SQSRecord(sqs_event_factory("success"))
373+
second_record = SQSRecord(sqs_event_factory("success"))
374+
event = {"Records": [first_record.raw_event, second_record.raw_event]}
372375

373376
processor = BatchProcessor(event_type=EventType.SQS)
374377

@@ -385,9 +388,9 @@ def lambda_handler(event, context):
385388

386389
def test_batch_processor_middleware_with_failure(sqs_event_factory, record_handler):
387390
# GIVEN
388-
first_record = sqs_event_factory("fail")
389-
second_record = sqs_event_factory("success")
390-
event = {"Records": [first_record, second_record]}
391+
first_record = SQSRecord(sqs_event_factory("fail"))
392+
second_record = SQSRecord(sqs_event_factory("success"))
393+
event = {"Records": [first_record.raw_event, second_record.raw_event]}
391394

392395
processor = BatchProcessor(event_type=EventType.SQS)
393396

@@ -404,9 +407,9 @@ def lambda_handler(event, context):
404407

405408
def test_batch_processor_context_success_only(sqs_event_factory, record_handler):
406409
# GIVEN
407-
first_record = sqs_event_factory("success")
408-
second_record = sqs_event_factory("success")
409-
records = [first_record, second_record]
410+
first_record = SQSRecord(sqs_event_factory("success"))
411+
second_record = SQSRecord(sqs_event_factory("success"))
412+
records = [first_record.raw_event, second_record.raw_event]
410413
processor = BatchProcessor(event_type=EventType.SQS)
411414

412415
# WHEN
@@ -415,35 +418,36 @@ def test_batch_processor_context_success_only(sqs_event_factory, record_handler)
415418

416419
# THEN
417420
assert processed_messages == [
418-
("success", first_record["body"], first_record),
419-
("success", second_record["body"], second_record),
421+
("success", first_record.body, first_record.raw_event),
422+
("success", second_record.body, second_record.raw_event),
420423
]
421424

422425
assert batch.response() == {"batchItemFailures": []}
423426

424427

425428
def test_batch_processor_context_with_failure(sqs_event_factory, record_handler):
426429
# GIVEN
427-
first_record = sqs_event_factory("failure")
428-
second_record = sqs_event_factory("success")
429-
records = [first_record, second_record]
430+
first_record = SQSRecord(sqs_event_factory("failure"))
431+
second_record = SQSRecord(sqs_event_factory("success"))
432+
records = [first_record.raw_event, second_record.raw_event]
430433
processor = BatchProcessor(event_type=EventType.SQS)
431434

432435
# WHEN
433436
with processor(records, record_handler) as batch:
434437
processed_messages = batch.process()
435438

436439
# THEN
437-
assert processed_messages[1] == ("success", second_record["body"], second_record)
440+
assert processed_messages[1] == ("success", second_record.body, second_record.raw_event)
438441
assert len(batch.fail_messages) == 1
439-
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["messageId"]}]}
442+
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record.message_id}]}
440443

441444

442445
def test_batch_processor_kinesis_context_success_only(kinesis_event_factory, kinesis_record_handler):
443446
# GIVEN
444-
first_record = kinesis_event_factory("success")
445-
second_record = kinesis_event_factory("success")
446-
records = [first_record, second_record]
447+
first_record = KinesisStreamRecord(kinesis_event_factory("success"))
448+
second_record = KinesisStreamRecord(kinesis_event_factory("success"))
449+
450+
records = [first_record.raw_event, second_record.raw_event]
447451
processor = BatchProcessor(event_type=EventType.KinesisDataStreams)
448452

449453
# WHEN
@@ -452,35 +456,36 @@ def test_batch_processor_kinesis_context_success_only(kinesis_event_factory, kin
452456

453457
# THEN
454458
assert processed_messages == [
455-
("success", decode_kinesis_data(first_record), first_record),
456-
("success", decode_kinesis_data(second_record), second_record),
459+
("success", b64_to_str(first_record.kinesis.data), first_record.raw_event),
460+
("success", b64_to_str(second_record.kinesis.data), second_record.raw_event),
457461
]
458462

459463
assert batch.response() == {"batchItemFailures": []}
460464

461465

462466
def test_batch_processor_kinesis_context_with_failure(kinesis_event_factory, kinesis_record_handler):
463467
# GIVEN
464-
first_record = kinesis_event_factory("failure")
465-
second_record = kinesis_event_factory("success")
466-
records = [first_record, second_record]
468+
first_record = KinesisStreamRecord(kinesis_event_factory("failure"))
469+
second_record = KinesisStreamRecord(kinesis_event_factory("success"))
470+
471+
records = [first_record.raw_event, second_record.raw_event]
467472
processor = BatchProcessor(event_type=EventType.KinesisDataStreams)
468473

469474
# WHEN
470475
with processor(records, kinesis_record_handler) as batch:
471476
processed_messages = batch.process()
472477

473478
# THEN
474-
assert processed_messages[1] == ("success", decode_kinesis_data(second_record), second_record)
479+
assert processed_messages[1] == ("success", b64_to_str(second_record.kinesis.data), second_record.raw_event)
475480
assert len(batch.fail_messages) == 1
476-
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["kinesis"]["sequenceNumber"]}]}
481+
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record.kinesis.sequence_number}]}
477482

478483

479484
def test_batch_processor_kinesis_middleware_with_failure(kinesis_event_factory, kinesis_record_handler):
480485
# GIVEN
481-
first_record = kinesis_event_factory("failure")
482-
second_record = kinesis_event_factory("success")
483-
event = {"Records": [first_record, second_record]}
486+
first_record = KinesisStreamRecord(kinesis_event_factory("failure"))
487+
second_record = KinesisStreamRecord(kinesis_event_factory("success"))
488+
event = {"Records": [first_record.raw_event, second_record.raw_event]}
484489

485490
processor = BatchProcessor(event_type=EventType.KinesisDataStreams)
486491

tests/functional/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@ def str_to_b64(data: str) -> str:
1313
return base64.b64encode(data.encode()).decode("utf-8")
1414

1515

16-
def decode_kinesis_data(data: dict) -> str:
17-
return base64.b64decode(data["kinesis"]["data"].encode()).decode("utf-8")
16+
def b64_to_str(data: str) -> str:
17+
return base64.b64decode(data.encode()).decode("utf-8")

0 commit comments

Comments (0)