Skip to content

Commit 7771f51

Browse files
authored
Merge pull request #104 from mattsb42-aws/dev-103
Remove requirement for tell() on decrypt
2 parents 5de76e5 + 8f7f215 commit 7771f51

File tree

8 files changed

+223
-80
lines changed

8 files changed

+223
-80
lines changed

src/aws_encryption_sdk/internal/formatting/deserialize.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -282,26 +282,38 @@ def deserialize_header_auth(stream, algorithm, verifier=None):
282282

283283

284284
def deserialize_non_framed_values(stream, header, verifier=None):
285-
"""Deserializes the IV and Tag from a non-framed stream.
285+
"""Deserializes the IV and body length from a non-framed stream.
286286
287287
:param stream: Source data stream
288288
:type stream: io.BytesIO
289289
:param header: Deserialized header
290290
:type header: aws_encryption_sdk.structures.MessageHeader
291291
:param verifier: Signature verifier object (optional)
292292
:type verifier: aws_encryption_sdk.internal.crypto.Verifier
293-
:returns: IV, Tag, and Data Length values for body
294-
:rtype: tuple of bytes, bytes, and int
293+
:returns: IV and Data Length values for body
294+
:rtype: tuple of bytes and int
295295
"""
296296
_LOGGER.debug("Starting non-framed body iv/tag deserialization")
297297
(data_iv, data_length) = unpack_values(">{}sQ".format(header.algorithm.iv_len), stream, verifier)
298-
body_start = stream.tell()
299-
stream.seek(data_length, 1)
298+
return data_iv, data_length
299+
300+
301+
def deserialize_tag(stream, header, verifier=None):
302+
"""Deserialize the Tag value from a non-framed stream.
303+
304+
:param stream: Source data stream
305+
:type stream: io.BytesIO
306+
:param header: Deserialized header
307+
:type header: aws_encryption_sdk.structures.MessageHeader
308+
:param verifier: Signature verifier object (optional)
309+
:type verifier: aws_encryption_sdk.internal.crypto.Verifier
310+
:returns: Tag value for body
311+
:rtype: bytes
312+
"""
300313
(data_tag,) = unpack_values(
301-
format_string=">{auth_len}s".format(auth_len=header.algorithm.auth_len), stream=stream, verifier=None
314+
format_string=">{auth_len}s".format(auth_len=header.algorithm.auth_len), stream=stream, verifier=verifier
302315
)
303-
stream.seek(body_start, 0)
304-
return data_iv, data_tag, data_length
316+
return data_tag
305317

306318

307319
def update_verifier_with_tag(stream, header, verifier):

src/aws_encryption_sdk/streaming_client.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,7 @@ class StreamDecryptor(_EncryptionStream): # pylint: disable=too-many-instance-a
696696
def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-called
697697
"""Prepares necessary initial values."""
698698
self.last_sequence_number = 0
699+
self.__unframed_bytes_read = 0
699700

700701
def _prep_message(self):
701702
"""Performs initial message setup."""
@@ -713,6 +714,7 @@ def _read_header(self):
713714
:raises CustomMaximumValueExceeded: if frame length is greater than the custom max value
714715
"""
715716
header, raw_header = aws_encryption_sdk.internal.formatting.deserialize.deserialize_header(self.source_stream)
717+
self.__unframed_bytes_read += len(raw_header)
716718

717719
if (
718720
self.config.max_body_length is not None
@@ -751,9 +753,21 @@ def _read_header(self):
751753
)
752754
return header, header_auth
753755

756+
@property
757+
def body_start(self):
758+
"""Log deprecation warning when body_start is accessed."""
759+
_LOGGER.warning("StreamDecryptor.body_start is deprecated and will be removed in 1.4.0")
760+
return self._body_start
761+
762+
@property
763+
def body_end(self):
764+
"""Log deprecation warning when body_end is accessed."""
765+
_LOGGER.warning("StreamDecryptor.body_end is deprecated and will be removed in 1.4.0")
766+
return self._body_end
767+
754768
def _prep_non_framed(self):
755769
"""Prepare the opening data for a non-framed message."""
756-
iv, tag, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values(
770+
self._unframed_body_iv, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( # noqa # pylint: disable=line-too-long
757771
stream=self.source_stream, header=self._header, verifier=self.verifier
758772
)
759773

@@ -764,24 +778,10 @@ def _prep_non_framed(self):
764778
)
765779
)
766780

767-
aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
768-
content_type=self._header.content_type, is_final_frame=True
769-
)
770-
associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad(
771-
message_id=self._header.message_id,
772-
aad_content_string=aad_content_string,
773-
seq_num=1,
774-
length=self.body_length,
775-
)
776-
self.decryptor = Decryptor(
777-
algorithm=self._header.algorithm,
778-
key=self._derived_data_key,
779-
associated_data=associated_data,
780-
iv=iv,
781-
tag=tag,
782-
)
783-
self.body_start = self.source_stream.tell()
784-
self.body_end = self.body_start + self.body_length
781+
self.__unframed_bytes_read += self._header.algorithm.iv_len
782+
self.__unframed_bytes_read += 8 # encrypted content length field
783+
self._body_start = self.__unframed_bytes_read
784+
self._body_end = self._body_start + self.body_length
785785

786786
def _read_bytes_from_non_framed_body(self, b):
787787
"""Reads the requested number of bytes from a streaming non-framed message body.
@@ -792,7 +792,8 @@ def _read_bytes_from_non_framed_body(self, b):
792792
"""
793793
_LOGGER.debug("starting non-framed body read")
794794
# Always read the entire message for non-framed message bodies.
795-
bytes_to_read = self.body_end - self.source_stream.tell()
795+
bytes_to_read = self.body_length
796+
796797
_LOGGER.debug("%d bytes requested; reading %d bytes", b, bytes_to_read)
797798
ciphertext = self.source_stream.read(bytes_to_read)
798799

@@ -802,11 +803,30 @@ def _read_bytes_from_non_framed_body(self, b):
802803
if self.verifier is not None:
803804
self.verifier.update(ciphertext)
804805

805-
plaintext = self.decryptor.update(ciphertext)
806-
plaintext += self.decryptor.finalize()
807-
aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag(
806+
tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag(
808807
stream=self.source_stream, header=self._header, verifier=self.verifier
809808
)
809+
810+
aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
811+
content_type=self._header.content_type, is_final_frame=True
812+
)
813+
associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad(
814+
message_id=self._header.message_id,
815+
aad_content_string=aad_content_string,
816+
seq_num=1,
817+
length=self.body_length,
818+
)
819+
self.decryptor = Decryptor(
820+
algorithm=self._header.algorithm,
821+
key=self._derived_data_key,
822+
associated_data=associated_data,
823+
iv=self._unframed_body_iv,
824+
tag=tag,
825+
)
826+
827+
plaintext = self.decryptor.update(ciphertext)
828+
plaintext += self.decryptor.finalize()
829+
810830
self.footer = aws_encryption_sdk.internal.formatting.deserialize.deserialize_footer(
811831
stream=self.source_stream, verifier=self.verifier
812832
)

test/functional/test_f_aws_encryption_sdk_client.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,3 +745,69 @@ def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size):
745745

746746
_look_in_logs(caplog, plaintext)
747747
_error_check(capsys)
748+
749+
750+
class NothingButRead(object):
751+
def __init__(self, data):
752+
self._data = io.BytesIO(data)
753+
754+
def read(self, size=-1):
755+
return self._data.read(size)
756+
757+
758+
@pytest.mark.xfail
759+
@pytest.mark.parametrize("frame_length", (0, 1024))
760+
def test_cycle_nothing_but_read(frame_length):
761+
raw_plaintext = exact_length_plaintext(100)
762+
plaintext = NothingButRead(raw_plaintext)
763+
key_provider = fake_kms_key_provider()
764+
raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
765+
source=plaintext, key_provider=key_provider, frame_length=frame_length
766+
)
767+
ciphertext = NothingButRead(raw_ciphertext)
768+
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
769+
assert raw_plaintext == decrypted
770+
771+
772+
@pytest.mark.xfail
773+
@pytest.mark.parametrize("frame_length", (0, 1024))
774+
def test_encrypt_nothing_but_read(frame_length):
775+
raw_plaintext = exact_length_plaintext(100)
776+
plaintext = NothingButRead(raw_plaintext)
777+
key_provider = fake_kms_key_provider()
778+
ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
779+
source=plaintext, key_provider=key_provider, frame_length=frame_length
780+
)
781+
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
782+
assert raw_plaintext == decrypted
783+
784+
785+
@pytest.mark.xfail
786+
@pytest.mark.parametrize("frame_length", (0, 1024))
787+
def test_decrypt_nothing_but_read(frame_length):
788+
plaintext = exact_length_plaintext(100)
789+
key_provider = fake_kms_key_provider()
790+
raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
791+
source=plaintext, key_provider=key_provider, frame_length=frame_length
792+
)
793+
ciphertext = NothingButRead(raw_ciphertext)
794+
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
795+
assert plaintext == decrypted
796+
797+
798+
@pytest.mark.parametrize("attribute, no_later_than", (("body_start", "1.4.0"), ("body_end", "1.4.0")))
799+
def test_decryptor_deprecated_attributes(caplog, attribute, no_later_than):
800+
caplog.set_level(logging.WARNING)
801+
plaintext = exact_length_plaintext(100)
802+
key_provider = fake_kms_key_provider()
803+
ciphertext, _header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=0)
804+
with aws_encryption_sdk.stream(mode="decrypt", source=ciphertext, key_provider=key_provider) as decryptor:
805+
decrypted = decryptor.read()
806+
807+
assert decrypted == plaintext
808+
assert hasattr(decryptor, attribute)
809+
watch_string = "StreamDecryptor.{name} is deprecated and will be removed in {version}".format(
810+
name=attribute, version=no_later_than
811+
)
812+
assert watch_string in caplog.text
813+
assert aws_encryption_sdk.__version__ < no_later_than

test/unit/test_deserialize.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""Unit test suite for aws_encryption_sdk.deserialize"""
1414
import io
15+
import struct
1516
import unittest
1617

1718
import pytest
@@ -29,6 +30,32 @@
2930
pytestmark = [pytest.mark.unit, pytest.mark.local]
3031

3132

33+
def test_deserialize_non_framed_values():
34+
iv = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11"
35+
length = 42
36+
packed = struct.pack(">12sQ", iv, length)
37+
mock_header = MagicMock(algorithm=MagicMock(iv_len=12))
38+
39+
parsed_iv, parsed_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values(
40+
stream=io.BytesIO(packed), header=mock_header
41+
)
42+
43+
assert parsed_iv == iv
44+
assert parsed_length == length
45+
46+
47+
def test_deserialize_tag():
48+
tag = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
49+
packed = struct.pack(">16s", tag)
50+
mock_header = MagicMock(algorithm=MagicMock(auth_len=16))
51+
52+
parsed_tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag(
53+
stream=io.BytesIO(packed), header=mock_header
54+
)
55+
56+
assert parsed_tag == tag
57+
58+
3259
class TestDeserialize(unittest.TestCase):
3360
def setUp(self):
3461
self.mock_wrapping_algorithm = MagicMock()

test/unit/test_streaming_client_encryption_stream.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020

2121
import aws_encryption_sdk.exceptions
2222
from aws_encryption_sdk.internal.defaults import LINE_LENGTH
23-
from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO
2423
from aws_encryption_sdk.key_providers.base import MasterKeyProvider
2524
from aws_encryption_sdk.streaming_client import _ClientConfig, _EncryptionStream
2625

2726
from .test_values import VALUES
27+
from .unit_test_utils import assert_prepped_stream_identity
2828

2929
pytestmark = [pytest.mark.unit, pytest.mark.local]
3030

@@ -110,7 +110,7 @@ def test_new_with_params(self):
110110
)
111111

112112
assert mock_stream.config.source == self.mock_source_stream
113-
assert isinstance(mock_stream.config.source, InsistentReaderBytesIO)
113+
assert_prepped_stream_identity(mock_stream.config.source, object)
114114
assert mock_stream.config.key_provider is self.mock_key_provider
115115
assert mock_stream.config.mock_read_bytes is sentinel.read_bytes
116116
assert mock_stream.config.line_length == io.DEFAULT_BUFFER_SIZE
@@ -120,7 +120,7 @@ def test_new_with_params(self):
120120
assert mock_stream.output_buffer == b""
121121
assert not mock_stream._message_prepped
122122
assert mock_stream.source_stream == self.mock_source_stream
123-
assert isinstance(mock_stream.source_stream, InsistentReaderBytesIO)
123+
assert_prepped_stream_identity(mock_stream.source_stream, object)
124124
assert mock_stream._stream_length is mock_int_sentinel
125125
assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE
126126

0 commit comments

Comments
 (0)