
Commit c3e25d6

fix(batch): handle early validation errors for pydantic models (poison pill) aws-powertools#2091 (aws-powertools#2099)
Co-authored-by: heitorlessa <lessa@amazon.co.uk>
1 parent c58f474 commit c3e25d6

File tree: 8 files changed, +312 -88 lines

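For context, this is the scenario the fix targets: a BatchProcessor configured with a pydantic model now reports a record that fails model validation early (a poison pill) as a partial batch failure instead of erroring out before any record is processed. A minimal sketch, assuming an SQS-triggered function and a hypothetical Order payload:

from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor
from aws_lambda_powertools.utilities.parser import BaseModel
from aws_lambda_powertools.utilities.parser.models import SqsRecordModel
from aws_lambda_powertools.utilities.parser.types import Json


class Order(BaseModel):
    item: dict


class OrderSqs(SqsRecordModel):
    body: Json[Order]  # JSON string body parsed into Order; malformed bodies raise ValidationError


processor = BatchProcessor(event_type=EventType.SQS, model=OrderSqs)


def record_handler(record: OrderSqs):
    # only records that passed model validation reach the handler
    return record.body.item


@batch_processor(record_handler=record_handler, processor=processor)
def lambda_handler(event, context):
    # poison pills end up in {"batchItemFailures": [...]} instead of failing the whole batch
    return processor.response()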

aws_lambda_powertools/utilities/batch/base.py

Lines changed: 38 additions & 5 deletions
@@ -37,6 +37,7 @@
     KinesisStreamRecord,
 )
 from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
+from aws_lambda_powertools.utilities.parser import ValidationError
 from aws_lambda_powertools.utilities.typing import LambdaContext
 
 logger = logging.getLogger(__name__)
@@ -316,21 +317,36 @@ def _get_messages_to_report(self) -> List[Dict[str, str]]:
     def _collect_sqs_failures(self):
         failures = []
         for msg in self.fail_messages:
-            msg_id = msg.messageId if self.model else msg.message_id
+            # If a message failed due to model validation (e.g., poison pill)
+            # we convert to an event source data class...but self.model is still true
+            # therefore, we do an additional check on whether the failed message is still a model
+            # see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
+            if self.model and getattr(msg, "parse_obj", None):
+                msg_id = msg.messageId
+            else:
+                msg_id = msg.message_id
             failures.append({"itemIdentifier": msg_id})
         return failures
 
     def _collect_kinesis_failures(self):
         failures = []
         for msg in self.fail_messages:
-            msg_id = msg.kinesis.sequenceNumber if self.model else msg.kinesis.sequence_number
+            # see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
+            if self.model and getattr(msg, "parse_obj", None):
+                msg_id = msg.kinesis.sequenceNumber
+            else:
+                msg_id = msg.kinesis.sequence_number
             failures.append({"itemIdentifier": msg_id})
         return failures
 
     def _collect_dynamodb_failures(self):
         failures = []
         for msg in self.fail_messages:
-            msg_id = msg.dynamodb.SequenceNumber if self.model else msg.dynamodb.sequence_number
+            # see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
+            if self.model and getattr(msg, "parse_obj", None):
+                msg_id = msg.dynamodb.SequenceNumber
+            else:
+                msg_id = msg.dynamodb.sequence_number
             failures.append({"itemIdentifier": msg_id})
         return failures

@@ -347,6 +363,17 @@ def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None):
             return model.parse_obj(record)
         return self._DATA_CLASS_MAPPING[event_type](record)
 
+    def _register_model_validation_error_record(self, record: dict):
+        """Convert and register failure due to poison pills where model failed validation early"""
+        # Parser will fail validation if record is a poison pill (malformed input)
+        # this means we can't collect the message id if we try transforming again
+        # so we convert it to the equivalent batch type model (e.g., SQS, Kinesis, DynamoDB Stream)
+        # and downstream we can collect the correct message id and make the failed record available
+        # see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
+        logger.debug("Record cannot be converted to customer's model; converting without model")
+        failed_record: "EventSourceDataClassTypes" = self._to_batch_type(record=record, event_type=self.event_type)
+        return self.failure_handler(record=failed_record, exception=sys.exc_info())
+
 
 class BatchProcessor(BasePartialBatchProcessor):  # Keep old name for compatibility
     """Process native partial responses from SQS, Kinesis Data Streams, and DynamoDB.
@@ -471,14 +498,17 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]:
         record: dict
             A batch record to be processed.
         """
-        data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
+        data: Optional["BatchTypeModels"] = None
         try:
+            data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
             if self._handler_accepts_lambda_context:
                 result = self.handler(record=data, lambda_context=self.lambda_context)
             else:
                 result = self.handler(record=data)
 
             return self.success_handler(record=record, result=result)
+        except ValidationError:
+            return self._register_model_validation_error_record(record)
         except Exception:
             return self.failure_handler(record=data, exception=sys.exc_info())

@@ -651,14 +681,17 @@ async def _async_process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]:
         record: dict
             A batch record to be processed.
         """
-        data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
+        data: Optional["BatchTypeModels"] = None
         try:
+            data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
             if self._handler_accepts_lambda_context:
                 result = await self.handler(record=data, lambda_context=self.lambda_context)
             else:
                 result = await self.handler(record=data)
 
             return self.success_handler(record=record, result=result)
+        except ValidationError:
+            return self._register_model_validation_error_record(record)
         except Exception:
             return self.failure_handler(record=data, exception=sys.exc_info())
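The `if self.model and getattr(msg, "parse_obj", None)` guard above is a duck-typing check: pydantic models expose the `parse_obj` classmethod, while the event source data classes that `_register_model_validation_error_record` falls back to do not. That is how the failure collectors pick camelCase attributes (messageId, sequenceNumber, SequenceNumber) for records that are still models, and snake_case attributes for poison pills converted to data classes. A rough illustration of the distinction, using a hypothetical, truncated SQS record dict:

from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.parser.models import SqsRecordModel

raw = {"messageId": "test-id"}  # hypothetical, truncated record for illustration only

# the event source data class has no parse_obj -> collectors read msg.message_id (snake_case)
assert getattr(SQSRecord(raw), "parse_obj", None) is None

# any pydantic model (class or instance) has parse_obj -> collectors read msg.messageId (camelCase)
assert getattr(SqsRecordModel, "parse_obj", None) is not None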

aws_lambda_powertools/utilities/parser/models/dynamodb.py

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@
 class DynamoDBStreamChangedRecordModel(BaseModel):
     ApproximateCreationDateTime: Optional[date]
     Keys: Dict[str, Dict[str, Any]]
-    NewImage: Optional[Union[Dict[str, Any], Type[BaseModel]]]
-    OldImage: Optional[Union[Dict[str, Any], Type[BaseModel]]]
+    NewImage: Optional[Union[Dict[str, Any], Type[BaseModel], BaseModel]]
+    OldImage: Optional[Union[Dict[str, Any], Type[BaseModel], BaseModel]]
     SequenceNumber: str
     SizeBytes: int
     StreamViewType: Literal["NEW_AND_OLD_IMAGES", "KEYS_ONLY", "NEW_IMAGE", "OLD_IMAGE"]

aws_lambda_powertools/utilities/parser/models/kinesis.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ class KinesisDataStreamRecordPayload(BaseModel):
     kinesisSchemaVersion: str
     partitionKey: str
     sequenceNumber: str
-    data: Union[bytes, Type[BaseModel]]  # base64 encoded str is parsed into bytes
+    data: Union[bytes, Type[BaseModel], BaseModel]  # base64 encoded str is parsed into bytes
     approximateArrivalTimestamp: float
 
     @validator("data", pre=True, allow_reuse=True)

aws_lambda_powertools/utilities/parser/models/sqs.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ class SqsMsgAttributeModel(BaseModel):
 class SqsRecordModel(BaseModel):
     messageId: str
     receiptHandle: str
-    body: Union[str, Type[BaseModel]]
+    body: Union[str, Type[BaseModel], BaseModel]
     attributes: SqsAttributesModel
     messageAttributes: Dict[str, SqsMsgAttributeModel]
     md5OfBody: str

aws_lambda_powertools/utilities/parser/types.py

Lines changed: 5 additions & 3 deletions
@@ -3,16 +3,18 @@
 import sys
 from typing import Any, Dict, Type, TypeVar, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Json
 
 # We only need typing_extensions for python versions <3.8
 if sys.version_info >= (3, 8):
-    from typing import Literal  # noqa: F401
+    from typing import Literal
 else:
-    from typing_extensions import Literal  # noqa: F401
+    from typing_extensions import Literal
 
 Model = TypeVar("Model", bound=BaseModel)
 EnvelopeModel = TypeVar("EnvelopeModel")
 EventParserReturnType = TypeVar("EventParserReturnType")
 AnyInheritedModel = Union[Type[BaseModel], BaseModel]
 RawDictOrModel = Union[Dict[str, Any], AnyInheritedModel]
+
+__all__ = ["Json", "Literal"]
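Re-exporting Json (and declaring __all__, which replaces the noqa: F401 markers) lets test models annotate fields as Json[Model]: pydantic parses a JSON-encoded string into the nested model and raises ValidationError when the payload is malformed, which is exactly the poison-pill failure the batch changes above now handle. A small sketch, assuming a hypothetical Order model:

from aws_lambda_powertools.utilities.parser import BaseModel, ValidationError
from aws_lambda_powertools.utilities.parser.types import Json


class Order(BaseModel):
    item: dict


class Payload(BaseModel):
    body: Json[Order]


Payload(body='{"item": {"id": 1}}')  # body becomes Order(item={"id": 1})

try:
    Payload(body="{malformed")  # invalid JSON -> ValidationError (the poison pill case)
except ValidationError as exc:
    print(exc.errors())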

tests/functional/batch/__init__.py

Whitespace-only changes.
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+import json
+from typing import Dict, Optional
+
+from aws_lambda_powertools.utilities.parser import BaseModel, validator
+from aws_lambda_powertools.utilities.parser.models import (
+    DynamoDBStreamChangedRecordModel,
+    DynamoDBStreamRecordModel,
+    KinesisDataStreamRecord,
+    KinesisDataStreamRecordPayload,
+    SqsRecordModel,
+)
+from aws_lambda_powertools.utilities.parser.types import Json, Literal
+
+
+class Order(BaseModel):
+    item: dict
+
+
+class OrderSqs(SqsRecordModel):
+    body: Json[Order]
+
+
+class OrderKinesisPayloadRecord(KinesisDataStreamRecordPayload):
+    data: Json[Order]
+
+
+class OrderKinesisRecord(KinesisDataStreamRecord):
+    kinesis: OrderKinesisPayloadRecord
+
+
+class OrderDynamoDB(BaseModel):
+    Message: Order
+
+    # auto transform json string
+    # so Pydantic can auto-initialize nested Order model
+    @validator("Message", pre=True)
+    def transform_message_to_dict(cls, value: Dict[Literal["S"], str]):
+        return json.loads(value["S"])
+
+
+class OrderDynamoDBChangeRecord(DynamoDBStreamChangedRecordModel):
+    NewImage: Optional[OrderDynamoDB]
+    OldImage: Optional[OrderDynamoDB]
+
+
+class OrderDynamoDBRecord(DynamoDBStreamRecordModel):
+    dynamodb: OrderDynamoDBChangeRecord