Skip to content

Commit e7e36a9

Browse files
committed
feat(sqs): improve validation for queue_url
1 parent 9c7324b commit e7e36a9

File tree

3 files changed

+288
-4
lines changed

3 files changed

+288
-4
lines changed

aws_lambda_powertools/utilities/batch/sqs.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,26 @@
1515

1616
class PartialSQSProcessor(BasePartialProcessor):
1717
def __init__(self):
18-
self._client = boto3.client("sqs")
18+
self.client = boto3.client("sqs")
1919
self.success_messages: List = []
2020
self.fail_messages: List = []
2121

2222
super().__init__()
2323

2424
def get_queue_url(self):
25+
"""
26+
Format QueueUrl from first records entry
27+
"""
28+
if not getattr(self, "records", None):
29+
return
30+
2531
*_, account_id, queue_name = self.records[0]["eventSourceARN"].split(":")
26-
return f"{self._client._endpoint.host}/{account_id}/{queue_name}"
32+
return f"{self.client._endpoint.host}/{account_id}/{queue_name}"
2733

2834
def get_entries_to_clean(self):
35+
"""
36+
Format messages to use in batch deletion
37+
"""
2938
return [{"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]} for msg in self.success_messages]
3039

3140
def _process_record(self, record):
@@ -37,21 +46,22 @@ def _process_record(self, record):
3746

3847
def _prepare(self):
3948
"""
49+
Remove results from previous executions.
4050
"""
4151
self.success_messages.clear()
4252
self.fail_messages.clear()
4353

4454
def _clean(self):
4555
"""
56+
Delete messages from Queue in case of partial failure.
4657
"""
47-
# skip only failures or only successes
4858
if not (self.fail_messages and self.success_messages):
4959
return
5060

5161
queue_url = self.get_queue_url()
5262
entries_to_remove = self.get_entries_to_clean()
5363

54-
return self._client.delete_message_batch(QueueUrl=queue_url, Entries=entries_to_remove)
64+
return self.client.delete_message_batch(QueueUrl=queue_url, Entries=entries_to_remove)
5565

5666

5767
@lambda_handler_decorator
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
from typing import Callable
2+
import pytest
3+
4+
from botocore.stub import Stubber
5+
6+
from aws_lambda_powertools.utilities.batch import PartialSQSProcessor, partial_sqs_processor
7+
8+
9+
@pytest.fixture
10+
def sqs_event_factory() -> Callable:
11+
def factory(body: str):
12+
return {
13+
"messageId": "059f36b4-87a3-44ab-83d2-661975830a7d",
14+
"receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
15+
"body": body,
16+
"attributes": {},
17+
"messageAttributes": {},
18+
"md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3",
19+
"eventSource": "aws:sqs",
20+
"eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue",
21+
"awsRegion": "us-east-1",
22+
}
23+
24+
return factory
25+
26+
27+
@pytest.fixture
28+
def record_handler() -> Callable:
29+
def handler(record):
30+
body = record["body"]
31+
if "fail" in body:
32+
raise Exception("Failed to process record.")
33+
return body
34+
35+
return handler
36+
37+
38+
def test_partial_sqs_processor_context_with_failure(sqs_event_factory, record_handler):
39+
""" Test processor with one record failing """
40+
processor = PartialSQSProcessor()
41+
42+
fail_record = sqs_event_factory("fail")
43+
success_record = sqs_event_factory("success")
44+
45+
records = [fail_record, success_record]
46+
47+
response = {"Successful": [{"Id": fail_record["messageId"]},], "Failed": []}
48+
49+
with Stubber(processor.client) as stubber:
50+
stubber.add_response("delete_message_batch", response)
51+
52+
with processor(records, record_handler) as ctx:
53+
result = ctx.process()
54+
55+
stubber.assert_no_pending_responses()
56+
57+
assert result == [
58+
("fail", ("Failed to process record.",), fail_record),
59+
("success", success_record["body"], success_record),
60+
]
61+
62+
63+
def test_partial_sqs_processor_context_only_success(sqs_event_factory, record_handler):
64+
""" Test processor without failure """
65+
processor = PartialSQSProcessor()
66+
67+
first_record = sqs_event_factory("success")
68+
second_record = sqs_event_factory("success")
69+
70+
records = [first_record, second_record]
71+
72+
with processor(records, record_handler) as ctx:
73+
result = ctx.process()
74+
75+
assert result == [
76+
("success", first_record["body"], first_record),
77+
("success", second_record["body"], second_record),
78+
]
79+
80+
81+
def test_partial_sqs_processor_context_multiple_calls(sqs_event_factory, record_handler):
82+
""" Test processor without failure """
83+
processor = PartialSQSProcessor()
84+
85+
first_record = sqs_event_factory("success")
86+
second_record = sqs_event_factory("success")
87+
88+
records = [first_record, second_record]
89+
90+
with processor(records, record_handler) as ctx:
91+
ctx.process()
92+
93+
with processor([first_record], record_handler) as ctx:
94+
ctx.process()
95+
96+
assert processor.success_messages == [first_record]
97+
98+
99+
def test_partial_sqs_processor_middleware_with_default(sqs_event_factory, record_handler):
100+
""" Test middleware with default partial processor """
101+
processor = PartialSQSProcessor()
102+
103+
@partial_sqs_processor(record_handler=record_handler, processor=processor)
104+
def lambda_handler(event, context):
105+
return True
106+
107+
fail_record = sqs_event_factory("fail")
108+
109+
event = {"Records": [sqs_event_factory("fail"), sqs_event_factory("success")]}
110+
response = {"Successful": [{"Id": fail_record["messageId"]},], "Failed": []}
111+
112+
with Stubber(processor.client) as stubber:
113+
stubber.add_response("delete_message_batch", response)
114+
115+
result = lambda_handler(event, {})
116+
117+
stubber.assert_no_pending_responses()
118+
119+
assert result is True
120+
121+
122+
def test_partial_sqs_processor_middleware_with_custom(capsys, sqs_event_factory, record_handler):
123+
""" Test middle with custom partial processor """
124+
class CustomProcessor(PartialSQSProcessor):
125+
def failure_handler(self, record, exception):
126+
print("Oh no ! It's a failure.")
127+
return super().failure_handler(record, exception)
128+
129+
processor = CustomProcessor()
130+
131+
@partial_sqs_processor(record_handler=record_handler, processor=processor)
132+
def lambda_handler(event, context):
133+
return True
134+
135+
fail_record = sqs_event_factory("fail")
136+
137+
event = {"Records": [sqs_event_factory("fail"), sqs_event_factory("success")]}
138+
response = {"Successful": [{"Id": fail_record["messageId"]},], "Failed": []}
139+
140+
with Stubber(processor.client) as stubber:
141+
stubber.add_response("delete_message_batch", response)
142+
143+
result = lambda_handler(event, {})
144+
145+
stubber.assert_no_pending_responses()
146+
147+
assert result is True
148+
assert capsys.readouterr().out == "Oh no ! It's a failure.\n"

tests/unit/test_utilities_batch.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import pytest
2+
3+
from aws_lambda_powertools.utilities.batch import PartialSQSProcessor
4+
5+
6+
@pytest.fixture(scope="module")
7+
def sqs_event():
8+
return {
9+
"messageId": "059f36b4-87a3-44ab-83d2-661975830a7d",
10+
"receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
11+
"body": "",
12+
"attributes": {},
13+
"messageAttributes": {},
14+
"md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3",
15+
"eventSource": "aws:sqs",
16+
"eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue",
17+
"awsRegion": "us-east-1",
18+
}
19+
20+
21+
def test_partial_sqs_get_queue_url_with_records(mocker, sqs_event):
22+
expected_url = "https://queue.amazonaws.com/123456789012/my-queue"
23+
24+
records_mock = mocker.patch.object(PartialSQSProcessor, "records", create=True, new_callable=mocker.PropertyMock)
25+
records_mock.return_value = [sqs_event]
26+
27+
result = PartialSQSProcessor().get_queue_url()
28+
assert result == expected_url
29+
30+
31+
def test_partial_sqs_get_queue_url_without_records():
32+
assert PartialSQSProcessor().get_queue_url() == None
33+
34+
35+
def test_partial_sqs_get_entries_to_clean_with_success(mocker, sqs_event):
36+
expected_entries = [{"Id": sqs_event["messageId"], "ReceiptHandle": sqs_event["receiptHandle"]}]
37+
38+
success_messages_mock = mocker.patch.object(
39+
PartialSQSProcessor, "success_messages", new_callable=mocker.PropertyMock
40+
)
41+
success_messages_mock.return_value = [sqs_event]
42+
43+
result = PartialSQSProcessor().get_entries_to_clean()
44+
45+
assert result == expected_entries
46+
47+
48+
def test_partial_sqs_get_entries_to_clean_without_success(mocker):
49+
expected_entries = []
50+
51+
success_messages_mock = mocker.patch.object(
52+
PartialSQSProcessor, "success_messages", new_callable=mocker.PropertyMock
53+
)
54+
success_messages_mock.return_value = []
55+
56+
result = PartialSQSProcessor().get_entries_to_clean()
57+
58+
assert result == expected_entries
59+
60+
61+
def test_partial_sqs_process_record_success(mocker):
62+
expected_value = mocker.sentinel.expected_value
63+
64+
success_result = mocker.sentinel.success_result
65+
record = mocker.sentinel.record
66+
67+
handler_mock = mocker.patch.object(PartialSQSProcessor, "handler", create=True, return_value=success_result)
68+
success_handler_mock = mocker.patch.object(PartialSQSProcessor, "success_handler", return_value=expected_value)
69+
70+
result = PartialSQSProcessor()._process_record(record)
71+
72+
handler_mock.assert_called_once_with(record)
73+
success_handler_mock.assert_called_once_with(record, success_result)
74+
75+
assert result == expected_value
76+
77+
78+
def test_partial_sqs_process_record_failure(mocker):
79+
expected_value = mocker.sentinel.expected_value
80+
81+
failure_result = Exception()
82+
record = mocker.sentinel.record
83+
84+
handler_mock = mocker.patch.object(PartialSQSProcessor, "handler", create=True, side_effect=failure_result)
85+
failure_handler_mock = mocker.patch.object(PartialSQSProcessor, "failure_handler", return_value=expected_value)
86+
87+
result = PartialSQSProcessor()._process_record(record)
88+
89+
handler_mock.assert_called_once_with(record)
90+
failure_handler_mock.assert_called_once_with(record, failure_result)
91+
92+
assert result == expected_value
93+
94+
95+
def test_partial_sqs_prepare(mocker):
96+
processor = PartialSQSProcessor()
97+
98+
success_messages_mock = mocker.patch.object(processor, "success_messages", spec=list)
99+
failed_messages_mock = mocker.patch.object(processor, "fail_messages", spec=list)
100+
101+
processor._prepare()
102+
103+
success_messages_mock.clear.assert_called_once()
104+
failed_messages_mock.clear.assert_called_once()
105+
106+
107+
def test_partial_sqs_clean(monkeypatch, mocker):
108+
processor = PartialSQSProcessor()
109+
records = [mocker.sentinel.record]
110+
111+
monkeypatch.setattr(processor, "fail_messages", records)
112+
monkeypatch.setattr(processor, "success_messages", records)
113+
114+
queue_url_mock = mocker.patch.object(PartialSQSProcessor, "get_queue_url")
115+
entries_to_clean_mock = mocker.patch.object(PartialSQSProcessor, "get_entries_to_clean")
116+
117+
queue_url_mock.return_value = mocker.sentinel.queue_url
118+
entries_to_clean_mock.return_value = mocker.sentinel.entries_to_clean
119+
120+
client_mock = mocker.patch.object(processor, "client", autospec=True)
121+
122+
processor._clean()
123+
124+
client_mock.delete_message_batch.assert_called_once_with(
125+
QueueUrl=mocker.sentinel.queue_url, Entries=mocker.sentinel.entries_to_clean
126+
)

0 commit comments

Comments
 (0)