From 5b6c04b615b298518a820f824de90347afe289d6 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Wed, 14 Nov 2018 12:40:18 -0800 Subject: [PATCH 1/2] expand minimal source stream API testing --- .../test_f_aws_encryption_sdk_client.py | 56 +++++++++++++++---- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index a0bd32675..aa63377b4 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -24,6 +24,7 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding from mock import MagicMock +from wrapt import ObjectProxy import aws_encryption_sdk from aws_encryption_sdk import KMSMasterKeyProvider @@ -749,31 +750,57 @@ def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size): class NothingButRead(object): def __init__(self, data): - self._data = io.BytesIO(data) + self._data = data def read(self, size=-1): return self._data.read(size) -@pytest.mark.xfail +class NoTell(ObjectProxy): + def tell(self): + raise NotImplementedError("NoTell does not tell().") + + +class NoClose(ObjectProxy): + closed = NotImplemented + + def close(self): + raise NotImplementedError("NoClose does not close().") + + +@pytest.mark.parametrize( + "wrapping_class", + ( + pytest.param(NoTell, marks=pytest.mark.xfail), + pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), + pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), + ), +) @pytest.mark.parametrize("frame_length", (0, 1024)) -def test_cycle_nothing_but_read(frame_length): +def test_cycle_minimal_source_stream_api(frame_length, wrapping_class): raw_plaintext = exact_length_plaintext(100) - plaintext = NothingButRead(raw_plaintext) + plaintext = wrapping_class(io.BytesIO(raw_plaintext)) key_provider = fake_kms_key_provider() raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( source=plaintext, key_provider=key_provider, frame_length=frame_length ) - ciphertext = NothingButRead(raw_ciphertext) + ciphertext = wrapping_class(io.BytesIO(raw_ciphertext)) decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert raw_plaintext == decrypted -@pytest.mark.xfail +@pytest.mark.parametrize( + "wrapping_class", + ( + pytest.param(NoTell, marks=pytest.mark.xfail), + pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), + pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), + ), +) @pytest.mark.parametrize("frame_length", (0, 1024)) -def test_encrypt_nothing_but_read(frame_length): +def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class): raw_plaintext = exact_length_plaintext(100) - plaintext = NothingButRead(raw_plaintext) + plaintext = wrapping_class(io.BytesIO(raw_plaintext)) key_provider = fake_kms_key_provider() ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( source=plaintext, key_provider=key_provider, frame_length=frame_length @@ -782,15 +809,22 @@ def test_encrypt_nothing_but_read(frame_length): assert raw_plaintext == decrypted -@pytest.mark.xfail +@pytest.mark.parametrize( + "wrapping_class", + ( + NoTell, + pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), + pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), + ), +) @pytest.mark.parametrize("frame_length", (0, 1024)) -def test_decrypt_nothing_but_read(frame_length): +def test_decrypt_minimal_source_stream_api(frame_length, wrapping_class): plaintext = exact_length_plaintext(100) key_provider = fake_kms_key_provider() raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( source=plaintext, key_provider=key_provider, frame_length=frame_length ) - ciphertext = NothingButRead(raw_ciphertext) + ciphertext = wrapping_class(io.BytesIO(raw_ciphertext)) decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert plaintext == decrypted From e3768c51a21b59ee5194e268e96a0ddc6465b0e6 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Wed, 14 Nov 2018 13:42:40 -0800 Subject: [PATCH 2/2] remove requirement for source_stream.tell() on encrypt --- src/aws_encryption_sdk/streaming_client.py | 20 ++++++++++++++++--- .../test_f_aws_encryption_sdk_client.py | 4 ++-- .../test_streaming_client_stream_encryptor.py | 5 +++++ 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 139c36b60..516a21913 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -399,6 +399,8 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not- ): raise SerializationError("Source too large for non-framed message") + self.__unframed_plaintext_cache = io.BytesIO() + def ciphertext_length(self): """Returns the length of the resulting ciphertext message in bytes. @@ -486,6 +488,17 @@ def _write_header(self): def _prep_non_framed(self): """Prepare the opening data for a non-framed message.""" + try: + plaintext_length = self.stream_length + self.__unframed_plaintext_cache = self.source_stream + except NotSupportedError: + # We need to know the plaintext length before we can start processing the data. + # If we cannot seek on the source then we need to read the entire source into memory. + self.__unframed_plaintext_cache = io.BytesIO() + self.__unframed_plaintext_cache.write(self.source_stream.read()) + plaintext_length = self.__unframed_plaintext_cache.tell() + self.__unframed_plaintext_cache.seek(0) + aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string( content_type=self.content_type, is_final_frame=True ) @@ -493,7 +506,7 @@ def _prep_non_framed(self): message_id=self._header.message_id, aad_content_string=aad_content_string, seq_num=1, - length=self.stream_length, + length=plaintext_length, ) self.encryptor = Encryptor( algorithm=self._encryption_materials.algorithm, @@ -504,7 +517,7 @@ def _prep_non_framed(self): self.output_buffer += serialize_non_framed_open( algorithm=self._encryption_materials.algorithm, iv=self.encryptor.iv, - plaintext_length=self.stream_length, + plaintext_length=plaintext_length, signer=self.signer, ) @@ -516,7 +529,7 @@ def _read_bytes_to_non_framed_body(self, b): :rtype: bytes """ _LOGGER.debug("Reading %d bytes", b) - plaintext = self.source_stream.read(b) + plaintext = self.__unframed_plaintext_cache.read(b) plaintext_length = len(plaintext) if self.tell() + len(plaintext) > MAX_NON_FRAMED_SIZE: raise SerializationError("Source too large for non-framed message") @@ -529,6 +542,7 @@ def _read_bytes_to_non_framed_body(self, b): if len(plaintext) < b: _LOGGER.debug("Closing encryptor after receiving only %d bytes of %d bytes requested", plaintext_length, b) self.source_stream.close() + self.__unframed_plaintext_cache.close() closing = self.encryptor.finalize() if self.signer is not None: diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index aa63377b4..74e78b382 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -771,7 +771,7 @@ def close(self): @pytest.mark.parametrize( "wrapping_class", ( - pytest.param(NoTell, marks=pytest.mark.xfail), + NoTell, pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), ), @@ -792,7 +792,7 @@ def test_cycle_minimal_source_stream_api(frame_length, wrapping_class): @pytest.mark.parametrize( "wrapping_class", ( - pytest.param(NoTell, marks=pytest.mark.xfail), + NoTell, pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), ), diff --git a/test/unit/test_streaming_client_stream_encryptor.py b/test/unit/test_streaming_client_stream_encryptor.py index 52e5f1215..8435d2bb3 100644 --- a/test/unit/test_streaming_client_stream_encryptor.py +++ b/test/unit/test_streaming_client_stream_encryptor.py @@ -382,7 +382,10 @@ def test_read_bytes_to_non_framed_body(self): test_encryptor.encryptor = MagicMock() test_encryptor._encryption_materials = self.mock_encryption_materials test_encryptor.encryptor.update.return_value = sentinel.ciphertext + test_encryptor._StreamEncryptor__unframed_plaintext_cache = pt_stream + test = test_encryptor._read_bytes_to_non_framed_body(5) + test_encryptor.encryptor.update.assert_called_once_with(self.plaintext[:5]) test_encryptor.signer.update.assert_called_once_with(sentinel.ciphertext) assert not test_encryptor.source_stream.closed @@ -392,6 +395,8 @@ def test_read_bytes_to_non_framed_body_too_large(self): pt_stream = io.BytesIO(self.plaintext) test_encryptor = StreamEncryptor(source=pt_stream, key_provider=self.mock_key_provider) test_encryptor.bytes_read = aws_encryption_sdk.internal.defaults.MAX_NON_FRAMED_SIZE + test_encryptor._StreamEncryptor__unframed_plaintext_cache = pt_stream + with six.assertRaisesRegex(self, SerializationError, "Source too large for non-framed message"): test_encryptor._read_bytes_to_non_framed_body(5)