Skip to content

Commit 4c95d39

Browse files
committed
chore: improve mypy support on success/failure
1 parent 251541c commit 4c95d39

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@ repos:
1111
- id: trailing-whitespace
1212
- id: end-of-file-fixer
1313
- id: check-toml
14-
- repo: https://github.com/pre-commit/pygrep-hooks
15-
rev: v1.5.1
16-
hooks:
17-
- id: python-use-type-annotations
1814
- repo: local
1915
hooks:
2016
- id: black

aws_lambda_powertools/utilities/batch/base.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,23 @@
1616
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
1717

1818
logger = logging.getLogger(__name__)
19-
has_pydantic = "pydantic" in sys.modules
2019

21-
SuccessCallback = Tuple[str, Any, dict]
22-
FailureCallback = Tuple[str, str, dict]
20+
21+
class EventType(Enum):
22+
SQS = "SQS"
23+
KinesisDataStreams = "KinesisDataStreams"
24+
DynamoDBStreams = "DynamoDBStreams"
25+
26+
27+
#
28+
# type specifics
29+
#
30+
has_pydantic = "pydantic" in sys.modules
2331
_ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
2432
_OptExcInfo = Union[_ExcInfo, Tuple[None, None, None]]
2533

34+
# For IntelliSense and Mypy to work, we need to account for possible SQS, Kinesis and DynamoDB subclasses
35+
# We need them as subclasses as we must access their message ID or sequence number metadata via dot notation
2636
if has_pydantic:
2737
from aws_lambda_powertools.utilities.parser.models import DynamoDBStreamRecordModel
2838
from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecord as KinesisDataStreamRecordModel
@@ -32,11 +42,13 @@
3242
Union[Type[SqsRecordModel], Type[DynamoDBStreamRecordModel], Type[KinesisDataStreamRecordModel]]
3343
]
3444

35-
36-
class EventType(Enum):
37-
SQS = "SQS"
38-
KinesisDataStreams = "KinesisDataStreams"
39-
DynamoDBStreams = "DynamoDBStreams"
45+
# When using processor with default arguments, records will carry EventSourceDataClassTypes
46+
# and depending on what EventType it's passed it'll correctly map to the right record
47+
# When using Pydantic Models, it'll accept any
48+
EventSourceDataClassTypes = Union[SQSRecord, KinesisStreamRecord, DynamoDBRecord]
49+
BatchEventTypes = Union[EventSourceDataClassTypes, "BatchTypeModels"]
50+
SuccessCallback = Tuple[str, Any, BatchEventTypes]
51+
FailureCallback = Tuple[str, str, BatchEventTypes]
4052

4153

4254
class BasePartialProcessor(ABC):
@@ -45,8 +57,8 @@ class BasePartialProcessor(ABC):
4557
"""
4658

4759
def __init__(self):
48-
self.success_messages: List = []
49-
self.fail_messages: List = []
60+
self.success_messages: List[BatchEventTypes] = []
61+
self.fail_messages: List[BatchEventTypes] = []
5062
self.exceptions: List = []
5163

5264
@abstractmethod
@@ -98,7 +110,7 @@ def __call__(self, records: List[dict], handler: Callable):
98110
self.handler = handler
99111
return self
100112

101-
def success_handler(self, record: dict, result: Any) -> SuccessCallback:
113+
def success_handler(self, record, result: Any) -> SuccessCallback:
102114
"""
103115
Success callback
104116
@@ -111,7 +123,7 @@ def success_handler(self, record: dict, result: Any) -> SuccessCallback:
111123
self.success_messages.append(record)
112124
return entry
113125

114-
def failure_handler(self, record: dict, exception: _OptExcInfo) -> FailureCallback:
126+
def failure_handler(self, record, exception: _OptExcInfo) -> FailureCallback:
115127
"""
116128
Failure callback
117129
@@ -256,22 +268,20 @@ def _get_messages_to_report(self) -> Dict[str, str]:
256268
return self._COLLECTOR_MAPPING[self.event_type]()
257269

258270
def _collect_sqs_failures(self):
259-
return {"itemIdentifier": msg["messageId"] for msg in self.fail_messages}
271+
return {"itemIdentifier": msg.messageId for msg in self.fail_messages}
260272

261273
def _collect_kinesis_failures(self):
262-
return {"itemIdentifier": msg["kinesis"]["sequenceNumber"] for msg in self.fail_messages}
274+
return {"itemIdentifier": msg.kinesis.sequence_number for msg in self.fail_messages}
263275

264276
def _collect_dynamodb_failures(self):
265-
return {"itemIdentifier": msg["dynamodb"]["SequenceNumber"] for msg in self.fail_messages}
277+
return {"itemIdentifier": msg.dynamodb.sequence_number for msg in self.fail_messages}
266278

267279
@overload
268280
def _to_batch_type(self, record: dict, event_type: EventType, model: "BatchTypeModels") -> "BatchTypeModels":
269281
...
270282

271283
@overload
272-
def _to_batch_type(
273-
self, record: dict, event_type: EventType
274-
) -> Union[SQSRecord, KinesisStreamRecord, DynamoDBRecord]:
284+
def _to_batch_type(self, record: dict, event_type: EventType) -> EventSourceDataClassTypes:
275285
...
276286

277287
def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None):

0 commit comments

Comments
 (0)