Skip to content

Commit 3217097

Browse files
committed
feat: mypy support
1 parent c757316 commit 3217097

File tree

1 file changed

+21
-10
lines changed
  • aws_lambda_powertools/utilities/batch

1 file changed

+21
-10
lines changed

aws_lambda_powertools/utilities/batch/base.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,20 @@
77
import sys
88
from abc import ABC, abstractmethod
99
from enum import Enum
10-
from typing import Any, Callable, Dict, List, Optional, Tuple
10+
from types import TracebackType
11+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
1112

1213
from aws_lambda_powertools.middleware_factory import lambda_handler_decorator
1314
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord
1415
from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord
1516
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
1617

1718
logger = logging.getLogger(__name__)
19+
SuccessCallback = Tuple[str, Any, dict]
20+
FailureCallback = Tuple[str, str, dict]
21+
22+
_ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
23+
_OptExcInfo = Union[_ExcInfo, Tuple[None, None, None]]
1824

1925

2026
class EventType(Enum):
@@ -48,7 +54,7 @@ def _clean(self):
4854
raise NotImplementedError()
4955

5056
@abstractmethod
51-
def _process_record(self, record: Any):
57+
def _process_record(self, record: dict):
5258
"""
5359
Process record with handler.
5460
"""
@@ -67,13 +73,13 @@ def __enter__(self):
6773
def __exit__(self, exception_type, exception_value, traceback):
6874
self._clean()
6975

70-
def __call__(self, records: List[Any], handler: Callable):
76+
def __call__(self, records: List[dict], handler: Callable):
7177
"""
7278
Set instance attributes before execution
7379
7480
Parameters
7581
----------
76-
records: List[Any]
82+
records: List[dict]
7783
List with objects to be processed.
7884
handler: Callable
7985
Callable to process "records" entries.
@@ -82,7 +88,7 @@ def __call__(self, records: List[Any], handler: Callable):
8288
self.handler = handler
8389
return self
8490

85-
def success_handler(self, record: Any, result: Any):
91+
def success_handler(self, record: dict, result: Any) -> SuccessCallback:
8692
"""
8793
Success callback
8894
@@ -95,7 +101,7 @@ def success_handler(self, record: Any, result: Any):
95101
self.success_messages.append(record)
96102
return entry
97103

98-
def failure_handler(self, record: Any, exception: Tuple):
104+
def failure_handler(self, record: dict, exception: _OptExcInfo) -> FailureCallback:
99105
"""
100106
Failure callback
101107
@@ -196,17 +202,17 @@ def _prepare(self):
196202
self.fail_messages.clear()
197203
self.batch_response = self.DEFAULT_RESPONSE
198204

199-
def _process_record(self, record) -> Tuple:
205+
def _process_record(self, record: dict) -> Union[SuccessCallback, FailureCallback]:
200206
"""
201207
Process a record with instance's handler
202208
203209
Parameters
204210
----------
205-
record: Any
206-
An object to be processed.
211+
record: dict
212+
A batch record to be processed.
207213
"""
208214
try:
209-
data = self._DATA_CLASS_MAPPING[self.event_type](record)
215+
data = self._to_batch_type(record, event_type=self.event_type)
210216
result = self.handler(record=data)
211217
return self.success_handler(record=record, result=result)
212218
except Exception:
@@ -244,3 +250,8 @@ def _collect_kinesis_failures(self):
244250

245251
def _collect_dynamodb_failures(self):
246252
return {"itemIdentifier": msg["dynamodb"]["SequenceNumber"] for msg in self.fail_messages}
253+
254+
def _to_batch_type(
255+
self, record: dict, event_type: EventType
256+
) -> Union[SQSRecord, KinesisStreamRecord, DynamoDBRecord]:
257+
return self._DATA_CLASS_MAPPING[event_type](record) # type: ignore # since DictWrapper inference is incorrect

0 commit comments

Comments
 (0)