|
8 | 8 | from abc import ABC, abstractmethod
|
9 | 9 | from enum import Enum
|
10 | 10 | from types import TracebackType
|
11 |
| -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union |
| 11 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, overload |
12 | 12 |
|
13 | 13 | from aws_lambda_powertools.middleware_factory import lambda_handler_decorator
|
14 | 14 | from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord
|
15 | 15 | from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord
|
16 | 16 | from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
|
17 | 17 |
|
logger = logging.getLogger(__name__)

# Pydantic is an optional dependency: this checks sys.modules instead of
# importing, so model support is enabled only when the caller's environment
# has already imported pydantic (avoids a hard dependency / import cost).
has_pydantic = "pydantic" in sys.modules

# Result tuples produced by the processor's success/failure handlers.
# NOTE(review): exact element semantics aren't visible in this chunk —
# presumably (status string, handler result / exception info, raw record);
# confirm against success_handler/failure_handler.
SuccessCallback = Tuple[str, Any, dict]
FailureCallback = Tuple[str, str, dict]

# sys.exc_info()-style tuples: populated on error, all-None otherwise.
_ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
_OptExcInfo = Union[_ExcInfo, Tuple[None, None, None]]

if has_pydantic:
    from aws_lambda_powertools.utilities.parser.models import DynamoDBStreamRecordModel
    from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecord as KinesisDataStreamRecordModel
    from aws_lambda_powertools.utilities.parser.models import SqsRecordModel

    # Union of parser models accepted by BatchProcessor(model=...); only
    # defined when pydantic is available.
    BatchTypeModels = Union[SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecordModel]
| 32 | + |
25 | 33 |
|
26 | 34 | class EventType(Enum):
|
27 | 35 | SQS = "SQS"
|
@@ -167,15 +175,18 @@ def batch_processor(
|
167 | 175 | class BatchProcessor(BasePartialProcessor):
|
168 | 176 | DEFAULT_RESPONSE: Dict[str, List[Optional[dict]]] = {"batchItemFailures": []}
|
169 | 177 |
|
170 |
| - def __init__(self, event_type: EventType): |
| 178 | + def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = None): |
171 | 179 | """Process batch and partially report failed items
|
172 | 180 |
|
173 | 181 | Parameters
|
174 | 182 | ----------
|
175 | 183 | event_type: EventType
|
176 | 184 | Whether this is a SQS, DynamoDB Streams, or Kinesis Data Stream event
|
| 185 | + model: Optional["BatchTypeModels"] |
| 186 | + Parser's data model using either SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecord |
177 | 187 | """
|
178 | 188 | self.event_type = event_type
|
| 189 | + self.model = model |
179 | 190 | self.batch_response = self.DEFAULT_RESPONSE
|
180 | 191 | self._COLLECTOR_MAPPING = {
|
181 | 192 | EventType.SQS: self._collect_sqs_failures,
|
@@ -212,7 +223,7 @@ def _process_record(self, record: dict) -> Union[SuccessCallback, FailureCallbac
|
212 | 223 | A batch record to be processed.
|
213 | 224 | """
|
214 | 225 | try:
|
215 |
| - data = self._to_batch_type(record, event_type=self.event_type) |
| 226 | + data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) |
216 | 227 | result = self.handler(record=data)
|
217 | 228 | return self.success_handler(record=record, result=result)
|
218 | 229 | except Exception:
|
@@ -251,7 +262,18 @@ def _collect_kinesis_failures(self):
|
251 | 262 | def _collect_dynamodb_failures(self):
|
252 | 263 | return {"itemIdentifier": msg["dynamodb"]["SequenceNumber"] for msg in self.fail_messages}
|
253 | 264 |
|
| 265 | + @overload |
| 266 | + def _to_batch_type(self, record: dict, event_type: EventType, model: "BatchTypeModels") -> "BatchTypeModels": |
| 267 | + ... |
| 268 | + |
| 269 | + @overload |
254 | 270 | def _to_batch_type(
|
255 | 271 | self, record: dict, event_type: EventType
|
256 | 272 | ) -> Union[SQSRecord, KinesisStreamRecord, DynamoDBRecord]:
|
257 |
| - return self._DATA_CLASS_MAPPING[event_type](record) # type: ignore # since DictWrapper inference is incorrect |
| 273 | + ... |
| 274 | + |
| 275 | + def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None): |
| 276 | + if model: |
| 277 | + return model.parse_obj(record) |
| 278 | + else: |
| 279 | + return self._DATA_CLASS_MAPPING[event_type](record) |
0 commit comments