Skip to content

refactor(batch): add from __future__ import annotations #4993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 32 additions & 25 deletions aws_lambda_powertools/utilities/batch/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-

"""
Batch processing utilities
"""
from __future__ import annotations

import asyncio
import copy
import inspect
Expand All @@ -11,22 +12,28 @@
import sys
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union, overload
from typing import TYPE_CHECKING, Any, Callable, Tuple, Union, overload

from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.utilities.batch.exceptions import (
BatchProcessingError,
ExceptionInfo,
)
from aws_lambda_powertools.utilities.batch.types import BatchTypeModels, PartialItemFailureResponse, PartialItemFailures
from aws_lambda_powertools.utilities.batch.types import BatchTypeModels
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import (
DynamoDBRecord,
)
from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import (
KinesisStreamRecord,
)
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext

if TYPE_CHECKING:
from aws_lambda_powertools.utilities.batch.types import (
PartialItemFailureResponse,
PartialItemFailures,
)
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)

Expand All @@ -41,7 +48,7 @@ class EventType(Enum):
# and depending on what EventType it's passed it'll correctly map to the right record
# When using Pydantic Models, it'll accept any subclass from SQS, DynamoDB and Kinesis
EventSourceDataClassTypes = Union[SQSRecord, KinesisStreamRecord, DynamoDBRecord]
BatchEventTypes = Union[EventSourceDataClassTypes, "BatchTypeModels"]
BatchEventTypes = Union[EventSourceDataClassTypes, BatchTypeModels]
SuccessResponse = Tuple[str, Any, BatchEventTypes]
FailureResponse = Tuple[str, str, BatchEventTypes]

Expand All @@ -54,9 +61,9 @@ class BasePartialProcessor(ABC):
lambda_context: LambdaContext

def __init__(self):
self.success_messages: List[BatchEventTypes] = []
self.fail_messages: List[BatchEventTypes] = []
self.exceptions: List[ExceptionInfo] = []
self.success_messages: list[BatchEventTypes] = []
self.fail_messages: list[BatchEventTypes] = []
self.exceptions: list[ExceptionInfo] = []

@abstractmethod
def _prepare(self):
Expand All @@ -79,7 +86,7 @@ def _process_record(self, record: dict):
"""
raise NotImplementedError()

def process(self) -> List[Tuple]:
def process(self) -> list[tuple]:
"""
Call instance's handler for each record.
"""
Expand All @@ -92,7 +99,7 @@ async def _async_process_record(self, record: dict):
"""
raise NotImplementedError()

def async_process(self) -> List[Tuple]:
def async_process(self) -> list[tuple]:
"""
Async call instance's handler for each record.

Expand Down Expand Up @@ -135,13 +142,13 @@ def __enter__(self):
def __exit__(self, exception_type, exception_value, traceback):
self._clean()

def __call__(self, records: List[dict], handler: Callable, lambda_context: Optional[LambdaContext] = None):
def __call__(self, records: list[dict], handler: Callable, lambda_context: LambdaContext | None = None):
"""
Set instance attributes before execution

Parameters
----------
records: List[dict]
records: list[dict]
List with objects to be processed.
handler: Callable
Callable to process "records" entries.
Expand Down Expand Up @@ -222,14 +229,14 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
class BasePartialBatchProcessor(BasePartialProcessor): # noqa
DEFAULT_RESPONSE: PartialItemFailureResponse = {"batchItemFailures": []}

def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = None):
def __init__(self, event_type: EventType, model: BatchTypeModels | None = None):
"""Process batch and partially report failed items

Parameters
----------
event_type: EventType
Whether this is a SQS, DynamoDB Streams, or Kinesis Data Stream event
model: Optional["BatchTypeModels"]
model: BatchTypeModels | None
Parser's data model using either SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecord

Exceptions
Expand Down Expand Up @@ -294,7 +301,7 @@ def _has_messages_to_report(self) -> bool:
def _entire_batch_failed(self) -> bool:
return len(self.exceptions) == len(self.records)

def _get_messages_to_report(self) -> List[PartialItemFailures]:
def _get_messages_to_report(self) -> list[PartialItemFailures]:
"""
Format messages to use in batch deletion
"""
Expand Down Expand Up @@ -343,13 +350,13 @@ def _to_batch_type(
self,
record: dict,
event_type: EventType,
model: "BatchTypeModels",
) -> "BatchTypeModels": ... # pragma: no cover
model: BatchTypeModels,
) -> BatchTypeModels: ... # pragma: no cover

@overload
def _to_batch_type(self, record: dict, event_type: EventType) -> EventSourceDataClassTypes: ... # pragma: no cover

def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None):
def _to_batch_type(self, record: dict, event_type: EventType, model: BatchTypeModels | None = None):
if model is not None:
# If a model is provided, we assume Pydantic is installed and we need to disable v2 warnings
return model.model_validate(record)
Expand All @@ -363,7 +370,7 @@ def _register_model_validation_error_record(self, record: dict):
# and downstream we can correctly collect the correct message id identifier and make the failed record available
# see https://github.com/aws-powertools/powertools-lambda-python/issues/2091
logger.debug("Record cannot be converted to customer's model; converting without model")
failed_record: "EventSourceDataClassTypes" = self._to_batch_type(record=record, event_type=self.event_type)
failed_record: EventSourceDataClassTypes = self._to_batch_type(record=record, event_type=self.event_type)
return self.failure_handler(record=failed_record, exception=sys.exc_info())


Expand Down Expand Up @@ -453,7 +460,7 @@ def record_handler(record: DynamoDBRecord):
logger.info(record.dynamodb.new_image)
payload: dict = json.loads(record.dynamodb.new_image.get("item"))
# alternatively:
# changes: Dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
# changes: dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
# payload = change.get("Message") -> "<payload>"
...

Expand Down Expand Up @@ -481,7 +488,7 @@ def lambda_handler(event, context: LambdaContext):
async def _async_process_record(self, record: dict):
raise NotImplementedError()

def _process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]:
def _process_record(self, record: dict) -> SuccessResponse | FailureResponse:
"""
Process a record with instance's handler

Expand All @@ -490,7 +497,7 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons
record: dict
A batch record to be processed.
"""
data: Optional["BatchTypeModels"] = None
data: BatchTypeModels | None = None
try:
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
if self._handler_accepts_lambda_context:
Expand Down Expand Up @@ -602,7 +609,7 @@ async def record_handler(record: DynamoDBRecord):
logger.info(record.dynamodb.new_image)
payload: dict = json.loads(record.dynamodb.new_image.get("item"))
# alternatively:
# changes: Dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
# changes: dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
# payload = change.get("Message") -> "<payload>"
...

Expand Down Expand Up @@ -630,7 +637,7 @@ def lambda_handler(event, context: LambdaContext):
def _process_record(self, record: dict):
raise NotImplementedError()

async def _async_process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]:
async def _async_process_record(self, record: dict) -> SuccessResponse | FailureResponse:
"""
Process a record with instance's handler

Expand All @@ -639,7 +646,7 @@ async def _async_process_record(self, record: dict) -> Union[SuccessResponse, Fa
record: dict
A batch record to be processed.
"""
data: Optional["BatchTypeModels"] = None
data: BatchTypeModels | None = None
try:
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
if self._handler_accepts_lambda_context:
Expand Down
28 changes: 15 additions & 13 deletions aws_lambda_powertools/utilities/batch/decorators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import Any, Awaitable, Callable, Dict, List
from typing import TYPE_CHECKING, Any, Awaitable, Callable

from typing_extensions import deprecated

Expand All @@ -12,10 +12,12 @@
BatchProcessor,
EventType,
)
from aws_lambda_powertools.utilities.batch.types import PartialItemFailureResponse
from aws_lambda_powertools.utilities.typing import LambdaContext
from aws_lambda_powertools.warnings import PowertoolsDeprecationWarning

if TYPE_CHECKING:
from aws_lambda_powertools.utilities.batch.types import PartialItemFailureResponse
from aws_lambda_powertools.utilities.typing import LambdaContext


@lambda_handler_decorator
@deprecated(
Expand All @@ -24,7 +26,7 @@
)
def async_batch_processor(
handler: Callable,
event: Dict,
event: dict,
context: LambdaContext,
record_handler: Callable[..., Awaitable[Any]],
processor: AsyncBatchProcessor,
Expand All @@ -40,7 +42,7 @@ def async_batch_processor(
----------
handler: Callable
Lambda's handler
event: Dict
event: dict
Lambda's Event
context: LambdaContext
Lambda's Context
Expand Down Expand Up @@ -92,7 +94,7 @@ def async_batch_processor(
)
def batch_processor(
handler: Callable,
event: Dict,
event: dict,
context: LambdaContext,
record_handler: Callable,
processor: BatchProcessor,
Expand All @@ -108,7 +110,7 @@ def batch_processor(
----------
handler: Callable
Lambda's handler
event: Dict
event: dict
Lambda's Event
context: LambdaContext
Lambda's Context
Expand Down Expand Up @@ -154,7 +156,7 @@ def batch_processor(


def process_partial_response(
event: Dict,
event: dict,
record_handler: Callable,
processor: BasePartialBatchProcessor,
context: LambdaContext | None = None,
Expand All @@ -164,7 +166,7 @@ def process_partial_response(

Parameters
----------
event: Dict
event: dict
Lambda's original event
record_handler: Callable
Callable to process each record from the batch
Expand Down Expand Up @@ -202,7 +204,7 @@ def handler(event, context):
* Async batch processors. Use `async_process_partial_response` instead.
"""
try:
records: List[Dict] = event.get("Records", [])
records: list[dict] = event.get("Records", [])
except AttributeError:
event_types = ", ".join(list(EventType.__members__))
docs = "https://docs.powertools.aws.dev/lambda/python/latest/utilities/batch/#processing-messages-from-sqs" # noqa: E501 # long-line
Expand All @@ -218,7 +220,7 @@ def handler(event, context):


def async_process_partial_response(
event: Dict,
event: dict,
record_handler: Callable,
processor: AsyncBatchProcessor,
context: LambdaContext | None = None,
Expand All @@ -228,7 +230,7 @@ def async_process_partial_response(

Parameters
----------
event: Dict
event: dict
Lambda's original event
record_handler: Callable
Callable to process each record from the batch
Expand Down Expand Up @@ -266,7 +268,7 @@ def handler(event, context):
* Sync batch processors. Use `process_partial_response` instead.
"""
try:
records: List[Dict] = event.get("Records", [])
records: list[dict] = event.get("Records", [])
except AttributeError:
event_types = ", ".join(list(EventType.__members__))
docs = "https://docs.powertools.aws.dev/lambda/python/latest/utilities/batch/#processing-messages-from-sqs" # noqa: E501 # long-line
Expand Down
6 changes: 3 additions & 3 deletions aws_lambda_powertools/utilities/batch/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

import traceback
from types import TracebackType
from typing import List, Optional, Tuple, Type
from typing import Optional, Tuple, Type

ExceptionInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]


class BaseBatchProcessingError(Exception):
def __init__(self, msg="", child_exceptions: List[ExceptionInfo] | None = None):
def __init__(self, msg="", child_exceptions: list[ExceptionInfo] | None = None):
super().__init__(msg)
self.msg = msg
self.child_exceptions = child_exceptions or []
Expand All @@ -30,7 +30,7 @@ def format_exceptions(self, parent_exception_str):
class BatchProcessingError(BaseBatchProcessingError):
"""When all batch records failed to be processed"""

def __init__(self, msg="", child_exceptions: List[ExceptionInfo] | None = None):
def __init__(self, msg="", child_exceptions: list[ExceptionInfo] | None = None):
super().__init__(msg, child_exceptions)

def __str__(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import logging
from typing import Optional, Set
from typing import TYPE_CHECKING

from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, ExceptionInfo, FailureResponse
from aws_lambda_powertools.utilities.batch.exceptions import (
SQSFifoCircuitBreakerError,
SQSFifoMessageGroupCircuitBreakerError,
)
from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel

if TYPE_CHECKING:
from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,13 +66,13 @@ def lambda_handler(event, context: LambdaContext):
None,
)

def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_error: bool = False):
def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error: bool = False):
"""
Initialize the SqsFifoProcessor.

Parameters
----------
model: Optional["BatchSqsTypeModel"]
model: BatchSqsTypeModel | None
An optional model for batch processing.
skip_group_on_error: bool
Determines whether to exclusively skip messages from the MessageGroupID that encountered processing failures
Expand All @@ -77,7 +81,7 @@ def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_er
"""
self._skip_group_on_error: bool = skip_group_on_error
self._current_group_id = None
self._failed_group_ids: Set[str] = set()
self._failed_group_ids: set[str] = set()
super().__init__(EventType.SQS, model)

def _process_record(self, record):
Expand Down
Loading
Loading