Skip to content

Commit 4314e37

Browse files
committed
remove need for source_stream.tell() on decrypt path
1 parent f7f25a8 commit 4314e37

File tree

8 files changed

+183
-78
lines changed

8 files changed

+183
-78
lines changed

src/aws_encryption_sdk/internal/formatting/deserialize.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -282,26 +282,38 @@ def deserialize_header_auth(stream, algorithm, verifier=None):
282282

283283

284284
def deserialize_non_framed_values(stream, header, verifier=None):
285-
"""Deserializes the IV and Tag from a non-framed stream.
285+
"""Deserializes the IV and body length from a non-framed stream.
286286
287287
:param stream: Source data stream
288288
:type stream: io.BytesIO
289289
:param header: Deserialized header
290290
:type header: aws_encryption_sdk.structures.MessageHeader
291291
:param verifier: Signature verifier object (optional)
292292
:type verifier: aws_encryption_sdk.internal.crypto.Verifier
293-
:returns: IV, Tag, and Data Length values for body
294-
:rtype: tuple of bytes, bytes, and int
293+
:returns: IV and Data Length values for body
294+
:rtype: tuple of bytes and int
295295
"""
296296
_LOGGER.debug("Starting non-framed body iv/tag deserialization")
297297
(data_iv, data_length) = unpack_values(">{}sQ".format(header.algorithm.iv_len), stream, verifier)
298-
body_start = stream.tell()
299-
stream.seek(data_length, 1)
298+
return data_iv, data_length
299+
300+
301+
def deserialize_tag(stream, header, verifier=None):
302+
"""Deserialize the Tag value from a non-framed stream.
303+
304+
:param stream: Source data stream
305+
:type stream: io.BytesIO
306+
:param header: Deserialized header
307+
:type header: aws_encryption_sdk.structures.MessageHeader
308+
:param verifier: Signature verifier object (optional)
309+
:type verifier: aws_encryption_sdk.internal.crypto.Verifier
310+
:returns: Tag value for body
311+
:rtype: bytes
312+
"""
300313
(data_tag,) = unpack_values(
301-
format_string=">{auth_len}s".format(auth_len=header.algorithm.auth_len), stream=stream, verifier=None
314+
format_string=">{auth_len}s".format(auth_len=header.algorithm.auth_len), stream=stream, verifier=verifier
302315
)
303-
stream.seek(body_start, 0)
304-
return data_iv, data_tag, data_length
316+
return data_tag
305317

306318

307319
def update_verifier_with_tag(stream, header, verifier):

src/aws_encryption_sdk/streaming_client.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,7 @@ class StreamDecryptor(_EncryptionStream): # pylint: disable=too-many-instance-a
696696
def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-called
697697
"""Prepares necessary initial values."""
698698
self.last_sequence_number = 0
699+
self.__unframed_bytes_read = 0
699700

700701
def _prep_message(self):
701702
"""Performs initial message setup."""
@@ -713,6 +714,7 @@ def _read_header(self):
713714
:raises CustomMaximumValueExceeded: if frame length is greater than the custom max value
714715
"""
715716
header, raw_header = aws_encryption_sdk.internal.formatting.deserialize.deserialize_header(self.source_stream)
717+
self.__unframed_bytes_read += len(raw_header)
716718

717719
if (
718720
self.config.max_body_length is not None
@@ -751,9 +753,19 @@ def _read_header(self):
751753
)
752754
return header, header_auth
753755

756+
@property
757+
def body_start(self):
758+
_LOGGER.warning("StreamDecryptor.body_start is deprecated and will be removed in 1.4.0")
759+
return self._body_start
760+
761+
@property
762+
def body_end(self):
763+
_LOGGER.warning("StreamDecryptor.body_end is deprecated and will be removed in 1.4.0")
764+
return self._body_end
765+
754766
def _prep_non_framed(self):
755767
"""Prepare the opening data for a non-framed message."""
756-
iv, tag, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values(
768+
self._unframed_body_iv, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values(
757769
stream=self.source_stream, header=self._header, verifier=self.verifier
758770
)
759771

@@ -764,24 +776,10 @@ def _prep_non_framed(self):
764776
)
765777
)
766778

767-
aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
768-
content_type=self._header.content_type, is_final_frame=True
769-
)
770-
associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad(
771-
message_id=self._header.message_id,
772-
aad_content_string=aad_content_string,
773-
seq_num=1,
774-
length=self.body_length,
775-
)
776-
self.decryptor = Decryptor(
777-
algorithm=self._header.algorithm,
778-
key=self._derived_data_key,
779-
associated_data=associated_data,
780-
iv=iv,
781-
tag=tag,
782-
)
783-
self.body_start = self.source_stream.tell()
784-
self.body_end = self.body_start + self.body_length
779+
self.__unframed_bytes_read += self._header.algorithm.iv_len
780+
self.__unframed_bytes_read += 8 # encrypted content length field
781+
self._body_start = self.__unframed_bytes_read
782+
self._body_end = self._body_start + self.body_length
785783

786784
def _read_bytes_from_non_framed_body(self, b):
787785
"""Reads the requested number of bytes from a streaming non-framed message body.
@@ -792,7 +790,8 @@ def _read_bytes_from_non_framed_body(self, b):
792790
"""
793791
_LOGGER.debug("starting non-framed body read")
794792
# Always read the entire message for non-framed message bodies.
795-
bytes_to_read = self.body_end - self.source_stream.tell()
793+
bytes_to_read = self.body_length
794+
796795
_LOGGER.debug("%d bytes requested; reading %d bytes", b, bytes_to_read)
797796
ciphertext = self.source_stream.read(bytes_to_read)
798797

@@ -802,11 +801,32 @@ def _read_bytes_from_non_framed_body(self, b):
802801
if self.verifier is not None:
803802
self.verifier.update(ciphertext)
804803

804+
tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag(
805+
stream=self.source_stream,
806+
header=self._header,
807+
verifier=self.verifier,
808+
)
809+
810+
aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
811+
content_type=self._header.content_type, is_final_frame=True
812+
)
813+
associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad(
814+
message_id=self._header.message_id,
815+
aad_content_string=aad_content_string,
816+
seq_num=1,
817+
length=self.body_length,
818+
)
819+
self.decryptor = Decryptor(
820+
algorithm=self._header.algorithm,
821+
key=self._derived_data_key,
822+
associated_data=associated_data,
823+
iv=self._unframed_body_iv,
824+
tag=tag,
825+
)
826+
805827
plaintext = self.decryptor.update(ciphertext)
806828
plaintext += self.decryptor.finalize()
807-
aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag(
808-
stream=self.source_stream, header=self._header, verifier=self.verifier
809-
)
829+
810830
self.footer = aws_encryption_sdk.internal.formatting.deserialize.deserialize_footer(
811831
stream=self.source_stream, verifier=self.verifier
812832
)

test/functional/test_f_aws_encryption_sdk_client.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,7 @@ def read(self, size=-1):
755755
return self._data.read(size)
756756

757757

758+
@pytest.mark.xfail
758759
@pytest.mark.parametrize("frame_length", (0, 1024))
759760
def test_cycle_nothing_but_read(frame_length):
760761
raw_plaintext = exact_length_plaintext(100)
@@ -766,6 +767,7 @@ def test_cycle_nothing_but_read(frame_length):
766767
assert raw_plaintext == decrypted
767768

768769

770+
@pytest.mark.xfail
769771
@pytest.mark.parametrize("frame_length", (0, 1024))
770772
def test_encrypt_nothing_but_read(frame_length):
771773
raw_plaintext = exact_length_plaintext(100)
@@ -776,6 +778,7 @@ def test_encrypt_nothing_but_read(frame_length):
776778
assert raw_plaintext == decrypted
777779

778780

781+
@pytest.mark.xfail
779782
@pytest.mark.parametrize("frame_length", (0, 1024))
780783
def test_decrypt_nothing_but_read(frame_length):
781784
plaintext = exact_length_plaintext(100)
@@ -784,3 +787,22 @@ def test_decrypt_nothing_but_read(frame_length):
784787
ciphertext = NothingButRead(raw_ciphertext)
785788
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
786789
assert plaintext == decrypted
790+
791+
792+
@pytest.mark.parametrize("attribute, no_later_than", (("body_start", "1.4.0"), ("body_end", "1.4.0")))
793+
def test_decryptor_deprecated_attributes(caplog, attribute, no_later_than):
794+
caplog.set_level(logging.WARNING)
795+
plaintext = exact_length_plaintext(100)
796+
key_provider = fake_kms_key_provider()
797+
ciphertext, _header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=0)
798+
with aws_encryption_sdk.stream(mode="decrypt", source=ciphertext, key_provider=key_provider) as decryptor:
799+
decrypted = decryptor.read()
800+
801+
assert decrypted == plaintext
802+
assert hasattr(decryptor, attribute)
803+
watch_string = "StreamDecryptor.{name} is deprecated and will be removed in {version}".format(
804+
name=attribute,
805+
version=no_later_than
806+
)
807+
assert watch_string in caplog.text
808+
assert aws_encryption_sdk.__version__ < no_later_than

test/unit/test_deserialize.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""Unit test suite for aws_encryption_sdk.deserialize"""
1414
import io
15+
import struct
1516
import unittest
1617

1718
import pytest
@@ -29,6 +30,34 @@
2930
pytestmark = [pytest.mark.unit, pytest.mark.local]
3031

3132

33+
def test_deserialize_non_framed_values():
34+
iv = b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11'
35+
length = 42
36+
packed = struct.pack(">12sQ", iv, length)
37+
mock_header = MagicMock(algorithm=MagicMock(iv_len=12))
38+
39+
parsed_iv, parsed_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values(
40+
stream=io.BytesIO(packed),
41+
header=mock_header
42+
)
43+
44+
assert parsed_iv == iv
45+
assert parsed_length == length
46+
47+
48+
def test_deserialize_tag():
49+
tag = b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15'
50+
packed = struct.pack(">16s", tag)
51+
mock_header = MagicMock(algorithm=MagicMock(auth_len=16))
52+
53+
parsed_tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag(
54+
stream=io.BytesIO(packed),
55+
header=mock_header
56+
)
57+
58+
assert parsed_tag == tag
59+
60+
3261
class TestDeserialize(unittest.TestCase):
3362
def setUp(self):
3463
self.mock_wrapping_algorithm = MagicMock()

test/unit/test_streaming_client_encryption_stream.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020

2121
import aws_encryption_sdk.exceptions
2222
from aws_encryption_sdk.internal.defaults import LINE_LENGTH
23-
from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO
2423
from aws_encryption_sdk.key_providers.base import MasterKeyProvider
2524
from aws_encryption_sdk.streaming_client import _ClientConfig, _EncryptionStream
2625

2726
from .test_values import VALUES
27+
from .unit_test_utils import assert_prepped_stream_identity
2828

2929
pytestmark = [pytest.mark.unit, pytest.mark.local]
3030

@@ -110,7 +110,7 @@ def test_new_with_params(self):
110110
)
111111

112112
assert mock_stream.config.source == self.mock_source_stream
113-
assert isinstance(mock_stream.config.source, InsistentReaderBytesIO)
113+
assert_prepped_stream_identity(mock_stream.config.source, object)
114114
assert mock_stream.config.key_provider is self.mock_key_provider
115115
assert mock_stream.config.mock_read_bytes is sentinel.read_bytes
116116
assert mock_stream.config.line_length == io.DEFAULT_BUFFER_SIZE
@@ -120,7 +120,7 @@ def test_new_with_params(self):
120120
assert mock_stream.output_buffer == b""
121121
assert not mock_stream._message_prepped
122122
assert mock_stream.source_stream == self.mock_source_stream
123-
assert isinstance(mock_stream.source_stream, InsistentReaderBytesIO)
123+
assert_prepped_stream_identity(mock_stream.source_stream, object)
124124
assert mock_stream._stream_length is mock_int_sentinel
125125
assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE
126126

0 commit comments

Comments
 (0)