diff --git a/src/aws_encryption_sdk/internal/formatting/deserialize.py b/src/aws_encryption_sdk/internal/formatting/deserialize.py index 024ccca28..86fa4c06d 100644 --- a/src/aws_encryption_sdk/internal/formatting/deserialize.py +++ b/src/aws_encryption_sdk/internal/formatting/deserialize.py @@ -282,7 +282,7 @@ def deserialize_header_auth(stream, algorithm, verifier=None): def deserialize_non_framed_values(stream, header, verifier=None): - """Deserializes the IV and Tag from a non-framed stream. + """Deserializes the IV and body length from a non-framed stream. :param stream: Source data stream :type stream: io.BytesIO @@ -290,18 +290,30 @@ def deserialize_non_framed_values(stream, header, verifier=None): :type header: aws_encryption_sdk.structures.MessageHeader :param verifier: Signature verifier object (optional) :type verifier: aws_encryption_sdk.internal.crypto.Verifier - :returns: IV, Tag, and Data Length values for body - :rtype: tuple of bytes, bytes, and int + :returns: IV and Data Length values for body + :rtype: tuple of bytes and int """ _LOGGER.debug("Starting non-framed body iv/tag deserialization") (data_iv, data_length) = unpack_values(">{}sQ".format(header.algorithm.iv_len), stream, verifier) - body_start = stream.tell() - stream.seek(data_length, 1) + return data_iv, data_length + + +def deserialize_tag(stream, header, verifier=None): + """Deserialize the Tag value from a non-framed stream. + + :param stream: Source data stream + :type stream: io.BytesIO + :param header: Deserialized header + :type header: aws_encryption_sdk.structures.MessageHeader + :param verifier: Signature verifier object (optional) + :type verifier: aws_encryption_sdk.internal.crypto.Verifier + :returns: Tag value for body + :rtype: bytes + """ (data_tag,) = unpack_values( - format_string=">{auth_len}s".format(auth_len=header.algorithm.auth_len), stream=stream, verifier=None + format_string=">{auth_len}s".format(auth_len=header.algorithm.auth_len), stream=stream, verifier=verifier ) - stream.seek(body_start, 0) - return data_iv, data_tag, data_length + return data_tag def update_verifier_with_tag(stream, header, verifier): diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 539bdf86d..faadc6515 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -696,6 +696,7 @@ class StreamDecryptor(_EncryptionStream): # pylint: disable=too-many-instance-a def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-called """Prepares necessary initial values.""" self.last_sequence_number = 0 + self.__unframed_bytes_read = 0 def _prep_message(self): """Performs initial message setup.""" @@ -713,6 +714,7 @@ def _read_header(self): :raises CustomMaximumValueExceeded: if frame length is greater than the custom max value """ header, raw_header = aws_encryption_sdk.internal.formatting.deserialize.deserialize_header(self.source_stream) + self.__unframed_bytes_read += len(raw_header) if ( self.config.max_body_length is not None @@ -751,9 +753,21 @@ def _read_header(self): ) return header, header_auth + @property + def body_start(self): + """Log deprecation warning when body_start is accessed.""" + _LOGGER.warning("StreamDecryptor.body_start is deprecated and will be removed in 1.4.0") + return self._body_start + + @property + def body_end(self): + """Log deprecation warning when body_end is accessed.""" + _LOGGER.warning("StreamDecryptor.body_end is deprecated and will be removed in 1.4.0") + return self._body_end + def _prep_non_framed(self): """Prepare the opening data for a non-framed message.""" - iv, tag, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( + self._unframed_body_iv, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( # noqa # pylint: disable=line-too-long stream=self.source_stream, header=self._header, verifier=self.verifier ) @@ -764,24 +778,10 @@ def _prep_non_framed(self): ) ) - aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string( - content_type=self._header.content_type, is_final_frame=True - ) - associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad( - message_id=self._header.message_id, - aad_content_string=aad_content_string, - seq_num=1, - length=self.body_length, - ) - self.decryptor = Decryptor( - algorithm=self._header.algorithm, - key=self._derived_data_key, - associated_data=associated_data, - iv=iv, - tag=tag, - ) - self.body_start = self.source_stream.tell() - self.body_end = self.body_start + self.body_length + self.__unframed_bytes_read += self._header.algorithm.iv_len + self.__unframed_bytes_read += 8 # encrypted content length field + self._body_start = self.__unframed_bytes_read + self._body_end = self._body_start + self.body_length def _read_bytes_from_non_framed_body(self, b): """Reads the requested number of bytes from a streaming non-framed message body. @@ -792,7 +792,8 @@ def _read_bytes_from_non_framed_body(self, b): """ _LOGGER.debug("starting non-framed body read") # Always read the entire message for non-framed message bodies. - bytes_to_read = self.body_end - self.source_stream.tell() + bytes_to_read = self.body_length + _LOGGER.debug("%d bytes requested; reading %d bytes", b, bytes_to_read) ciphertext = self.source_stream.read(bytes_to_read) @@ -802,11 +803,30 @@ def _read_bytes_from_non_framed_body(self, b): if self.verifier is not None: self.verifier.update(ciphertext) - plaintext = self.decryptor.update(ciphertext) - plaintext += self.decryptor.finalize() - aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag( + tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag( stream=self.source_stream, header=self._header, verifier=self.verifier ) + + aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string( + content_type=self._header.content_type, is_final_frame=True + ) + associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad( + message_id=self._header.message_id, + aad_content_string=aad_content_string, + seq_num=1, + length=self.body_length, + ) + self.decryptor = Decryptor( + algorithm=self._header.algorithm, + key=self._derived_data_key, + associated_data=associated_data, + iv=self._unframed_body_iv, + tag=tag, + ) + + plaintext = self.decryptor.update(ciphertext) + plaintext += self.decryptor.finalize() + self.footer = aws_encryption_sdk.internal.formatting.deserialize.deserialize_footer( stream=self.source_stream, verifier=self.verifier ) diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index 08cec1c5e..a0bd32675 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -745,3 +745,69 @@ def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size): _look_in_logs(caplog, plaintext) _error_check(capsys) + + +class NothingButRead(object): + def __init__(self, data): + self._data = io.BytesIO(data) + + def read(self, size=-1): + return self._data.read(size) + + +@pytest.mark.xfail +@pytest.mark.parametrize("frame_length", (0, 1024)) +def test_cycle_nothing_but_read(frame_length): + raw_plaintext = exact_length_plaintext(100) + plaintext = NothingButRead(raw_plaintext) + key_provider = fake_kms_key_provider() + raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( + source=plaintext, key_provider=key_provider, frame_length=frame_length + ) + ciphertext = NothingButRead(raw_ciphertext) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + assert raw_plaintext == decrypted + + +@pytest.mark.xfail +@pytest.mark.parametrize("frame_length", (0, 1024)) +def test_encrypt_nothing_but_read(frame_length): + raw_plaintext = exact_length_plaintext(100) + plaintext = NothingButRead(raw_plaintext) + key_provider = fake_kms_key_provider() + ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( + source=plaintext, key_provider=key_provider, frame_length=frame_length + ) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + assert raw_plaintext == decrypted + + +@pytest.mark.xfail +@pytest.mark.parametrize("frame_length", (0, 1024)) +def test_decrypt_nothing_but_read(frame_length): + plaintext = exact_length_plaintext(100) + key_provider = fake_kms_key_provider() + raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( + source=plaintext, key_provider=key_provider, frame_length=frame_length + ) + ciphertext = NothingButRead(raw_ciphertext) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + assert plaintext == decrypted + + +@pytest.mark.parametrize("attribute, no_later_than", (("body_start", "1.4.0"), ("body_end", "1.4.0"))) +def test_decryptor_deprecated_attributes(caplog, attribute, no_later_than): + caplog.set_level(logging.WARNING) + plaintext = exact_length_plaintext(100) + key_provider = fake_kms_key_provider() + ciphertext, _header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=0) + with aws_encryption_sdk.stream(mode="decrypt", source=ciphertext, key_provider=key_provider) as decryptor: + decrypted = decryptor.read() + + assert decrypted == plaintext + assert hasattr(decryptor, attribute) + watch_string = "StreamDecryptor.{name} is deprecated and will be removed in {version}".format( + name=attribute, version=no_later_than + ) + assert watch_string in caplog.text + assert aws_encryption_sdk.__version__ < no_later_than diff --git a/test/unit/test_deserialize.py b/test/unit/test_deserialize.py index ac3be1bf3..b591159e0 100644 --- a/test/unit/test_deserialize.py +++ b/test/unit/test_deserialize.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """Unit test suite for aws_encryption_sdk.deserialize""" import io +import struct import unittest import pytest @@ -29,6 +30,32 @@ pytestmark = [pytest.mark.unit, pytest.mark.local] +def test_deserialize_non_framed_values(): + iv = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11" + length = 42 + packed = struct.pack(">12sQ", iv, length) + mock_header = MagicMock(algorithm=MagicMock(iv_len=12)) + + parsed_iv, parsed_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( + stream=io.BytesIO(packed), header=mock_header + ) + + assert parsed_iv == iv + assert parsed_length == length + + +def test_deserialize_tag(): + tag = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15" + packed = struct.pack(">16s", tag) + mock_header = MagicMock(algorithm=MagicMock(auth_len=16)) + + parsed_tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag( + stream=io.BytesIO(packed), header=mock_header + ) + + assert parsed_tag == tag + + class TestDeserialize(unittest.TestCase): def setUp(self): self.mock_wrapping_algorithm = MagicMock() diff --git a/test/unit/test_streaming_client_encryption_stream.py b/test/unit/test_streaming_client_encryption_stream.py index 22cefffb0..e3a06347a 100644 --- a/test/unit/test_streaming_client_encryption_stream.py +++ b/test/unit/test_streaming_client_encryption_stream.py @@ -20,11 +20,11 @@ import aws_encryption_sdk.exceptions from aws_encryption_sdk.internal.defaults import LINE_LENGTH -from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO from aws_encryption_sdk.key_providers.base import MasterKeyProvider from aws_encryption_sdk.streaming_client import _ClientConfig, _EncryptionStream from .test_values import VALUES +from .unit_test_utils import assert_prepped_stream_identity pytestmark = [pytest.mark.unit, pytest.mark.local] @@ -110,7 +110,7 @@ def test_new_with_params(self): ) assert mock_stream.config.source == self.mock_source_stream - assert isinstance(mock_stream.config.source, InsistentReaderBytesIO) + assert_prepped_stream_identity(mock_stream.config.source, object) assert mock_stream.config.key_provider is self.mock_key_provider assert mock_stream.config.mock_read_bytes is sentinel.read_bytes assert mock_stream.config.line_length == io.DEFAULT_BUFFER_SIZE @@ -120,7 +120,7 @@ def test_new_with_params(self): assert mock_stream.output_buffer == b"" assert not mock_stream._message_prepped assert mock_stream.source_stream == self.mock_source_stream - assert isinstance(mock_stream.source_stream, InsistentReaderBytesIO) + assert_prepped_stream_identity(mock_stream.source_stream, object) assert mock_stream._stream_length is mock_int_sentinel assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE diff --git a/test/unit/test_streaming_client_stream_decryptor.py b/test/unit/test_streaming_client_stream_decryptor.py index 50e981b16..0c9f53ee7 100644 --- a/test/unit/test_streaming_client_stream_decryptor.py +++ b/test/unit/test_streaming_client_stream_decryptor.py @@ -37,10 +37,12 @@ def setUp(self): data_key=VALUES["data_key_obj"], verification_key=sentinel.verification_key ) self.mock_header = MagicMock() - self.mock_header.algorithm = MagicMock(__class__=Algorithm) + self.mock_header.algorithm = MagicMock(__class__=Algorithm, iv_len=12) self.mock_header.encrypted_data_keys = sentinel.encrypted_data_keys self.mock_header.encryption_context = sentinel.encryption_context + self.mock_raw_header = b"some bytes" + self.mock_input_stream = MagicMock() self.mock_input_stream.__class__ = io.IOBase self.mock_input_stream.tell.side_effect = (0, 500) @@ -50,7 +52,7 @@ def setUp(self): "aws_encryption_sdk.streaming_client.aws_encryption_sdk.internal.formatting.deserialize.deserialize_header" ) self.mock_deserialize_header = self.mock_deserialize_header_patcher.start() - self.mock_deserialize_header.return_value = self.mock_header, sentinel.raw_header + self.mock_deserialize_header.return_value = self.mock_header, self.mock_raw_header # Set up deserialize_header_auth patch self.mock_deserialize_header_auth_patcher = patch( "aws_encryption_sdk.streaming_client" @@ -69,7 +71,13 @@ def setUp(self): ".aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values" ) self.mock_deserialize_non_framed_values = self.mock_deserialize_non_framed_values_patcher.start() - self.mock_deserialize_non_framed_values.return_value = (sentinel.iv, sentinel.tag, len(VALUES["data_128"])) + self.mock_deserialize_non_framed_values.return_value = (sentinel.iv, len(VALUES["data_128"])) + # Set up deserialize_tag_value patch + self.mock_deserialize_tag_patcher = patch( + "aws_encryption_sdk.streaming_client" ".aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag" + ) + self.mock_deserialize_tag = self.mock_deserialize_tag_patcher.start() + self.mock_deserialize_tag.return_value = sentinel.tag # Set up get_aad_content_string patch self.mock_get_aad_content_string_patcher = patch( "aws_encryption_sdk.streaming_client.aws_encryption_sdk.internal.utils.get_aad_content_string" @@ -113,6 +121,7 @@ def tearDown(self): self.mock_deserialize_header_auth_patcher.stop() self.mock_validate_header_patcher.stop() self.mock_deserialize_non_framed_values_patcher.stop() + self.mock_deserialize_tag_patcher.stop() self.mock_get_aad_content_string_patcher.stop() self.mock_assemble_content_aad_patcher.stop() self.mock_decryptor_patcher.stop() @@ -151,11 +160,9 @@ def test_prep_message_non_framed_message(self, mock_read_header, mock_prep_non_f @patch("aws_encryption_sdk.streaming_client.Verifier") @patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest") @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") - @patch("aws_encryption_sdk.streaming_client.StreamDecryptor.__init__") - def test_read_header(self, mock_init, mock_derive_datakey, mock_decrypt_materials_request, mock_verifier): + def test_read_header(self, mock_derive_datakey, mock_decrypt_materials_request, mock_verifier): mock_verifier_instance = MagicMock() mock_verifier.from_key_bytes.return_value = mock_verifier_instance - mock_init.return_value = None ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(materials_manager=self.mock_materials_manager, source=ct_stream) test_decryptor.source_stream = ct_stream @@ -175,7 +182,7 @@ def test_read_header(self, mock_init, mock_derive_datakey, mock_decrypt_material self.mock_materials_manager.decrypt_materials.assert_called_once_with( request=mock_decrypt_materials_request.return_value ) - mock_verifier_instance.update.assert_called_once_with(sentinel.raw_header) + mock_verifier_instance.update.assert_called_once_with(self.mock_raw_header) self.mock_deserialize_header_auth.assert_called_once_with( stream=ct_stream, algorithm=self.mock_header.algorithm, verifier=mock_verifier_instance ) @@ -188,18 +195,16 @@ def test_read_header(self, mock_init, mock_derive_datakey, mock_decrypt_material self.mock_validate_header.assert_called_once_with( header=self.mock_header, header_auth=sentinel.header_auth, - raw_header=sentinel.raw_header, + raw_header=self.mock_raw_header, data_key=mock_derive_datakey.return_value, ) assert test_header is self.mock_header assert test_header_auth is sentinel.header_auth @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") - @patch("aws_encryption_sdk.streaming_client.StreamDecryptor.__init__") - def test_read_header_frame_too_large(self, mock_init, mock_derive_datakey): + def test_read_header_frame_too_large(self, mock_derive_datakey): self.mock_header.content_type = ContentType.FRAMED_DATA self.mock_header.frame_length = 1024 - mock_init.return_value = None ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream, max_body_length=10) test_decryptor.key_provider = self.mock_key_provider @@ -215,14 +220,10 @@ def test_read_header_frame_too_large(self, mock_init, mock_derive_datakey): @patch("aws_encryption_sdk.streaming_client.Verifier") @patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest") @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") - @patch("aws_encryption_sdk.streaming_client.StreamDecryptor.__init__") - def test_read_header_no_verifier( - self, mock_init, mock_derive_datakey, mock_decrypt_materials_request, mock_verifier - ): + def test_read_header_no_verifier(self, mock_derive_datakey, mock_decrypt_materials_request, mock_verifier): self.mock_materials_manager.decrypt_materials.return_value = MagicMock( data_key=VALUES["data_key_obj"], verification_key=None ) - mock_init.return_value = None test_decryptor = StreamDecryptor(materials_manager=self.mock_materials_manager, source=self.mock_input_stream) test_decryptor.key_provider = self.mock_key_provider test_decryptor.source_stream = self.mock_input_stream @@ -264,6 +265,26 @@ def test_prep_non_framed(self): stream=test_decryptor.source_stream, header=self.mock_header, verifier=sentinel.verifier ) assert test_decryptor.body_length == len(VALUES["data_128"]) + assert test_decryptor.body_start == self.mock_header.algorithm.iv_len + 8 + assert test_decryptor.body_end == self.mock_header.algorithm.iv_len + 8 + len(VALUES["data_128"]) + + def test_read_bytes_from_non_framed(self): + ct_stream = io.BytesIO(VALUES["data_128"]) + test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) + test_decryptor.body_length = len(VALUES["data_128"]) + test_decryptor.decryptor = self.mock_decryptor_instance + test_decryptor._header = self.mock_header + test_decryptor.verifier = MagicMock() + test_decryptor._derived_data_key = sentinel.derived_data_key + test_decryptor._unframed_body_iv = sentinel.unframed_body_iv + self.mock_decryptor_instance.update.return_value = b"1234" + self.mock_decryptor_instance.finalize.return_value = b"5678" + + test = test_decryptor._read_bytes_from_non_framed_body(5) + + self.mock_deserialize_tag.assert_called_once_with( + stream=test_decryptor.source_stream, header=test_decryptor._header, verifier=test_decryptor.verifier + ) self.mock_get_aad_content_string.assert_called_once_with( content_type=self.mock_header.content_type, is_final_frame=True ) @@ -277,24 +298,10 @@ def test_prep_non_framed(self): algorithm=self.mock_header.algorithm, key=sentinel.derived_data_key, associated_data=sentinel.associated_data, - iv=sentinel.iv, + iv=sentinel.unframed_body_iv, tag=sentinel.tag, ) assert test_decryptor.decryptor is self.mock_decryptor_instance - assert test_decryptor.body_start == 0 - assert test_decryptor.body_end == len(VALUES["data_128"]) - - def test_read_bytes_from_non_framed(self): - ct_stream = io.BytesIO(VALUES["data_128"]) - test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) - test_decryptor.body_start = 0 - test_decryptor.body_length = test_decryptor.body_end = len(VALUES["data_128"]) - test_decryptor.decryptor = self.mock_decryptor_instance - test_decryptor._header = self.mock_header - test_decryptor.verifier = MagicMock() - self.mock_decryptor_instance.update.return_value = b"1234" - self.mock_decryptor_instance.finalize.return_value = b"5678" - test = test_decryptor._read_bytes_from_non_framed_body(5) test_decryptor.verifier.update.assert_called_once_with(VALUES["data_128"]) self.mock_decryptor_instance.update.assert_called_once_with(VALUES["data_128"]) assert test_decryptor.source_stream.closed @@ -303,8 +310,7 @@ def test_read_bytes_from_non_framed(self): def test_read_bytes_from_non_framed_message_body_too_small(self): ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) - test_decryptor.body_start = 0 - test_decryptor.body_length = test_decryptor.body_end = len(VALUES["data_128"] * 2) + test_decryptor.body_length = len(VALUES["data_128"] * 2) test_decryptor._header = self.mock_header with six.assertRaisesRegex( self, SerializationError, "Total message body contents less than specified in body description" @@ -314,10 +320,11 @@ def test_read_bytes_from_non_framed_message_body_too_small(self): def test_read_bytes_from_non_framed_no_verifier(self): ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) - test_decryptor.body_start = 0 - test_decryptor.body_length = test_decryptor.body_end = len(VALUES["data_128"]) + test_decryptor.body_length = len(VALUES["data_128"]) test_decryptor.decryptor = self.mock_decryptor_instance test_decryptor._header = self.mock_header + test_decryptor._derived_data_key = sentinel.derived_data_key + test_decryptor._unframed_body_iv = sentinel.unframed_body_iv test_decryptor.verifier = None self.mock_decryptor_instance.update.return_value = b"1234" test_decryptor._read_bytes_from_non_framed_body(5) @@ -325,19 +332,19 @@ def test_read_bytes_from_non_framed_no_verifier(self): def test_read_bytes_from_non_framed_finalize(self): ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) - test_decryptor.body_start = 0 - test_decryptor.body_length = test_decryptor.body_end = len(VALUES["data_128"]) + test_decryptor.body_length = len(VALUES["data_128"]) test_decryptor.decryptor = self.mock_decryptor_instance test_decryptor.verifier = MagicMock() test_decryptor._header = self.mock_header + test_decryptor._derived_data_key = sentinel.derived_data_key + test_decryptor._unframed_body_iv = sentinel.unframed_body_iv self.mock_decryptor_instance.update.return_value = b"1234" self.mock_decryptor_instance.finalize.return_value = b"5678" + test = test_decryptor._read_bytes_from_non_framed_body(len(VALUES["data_128"]) + 1) + test_decryptor.verifier.update.assert_called_once_with(VALUES["data_128"]) self.mock_decryptor_instance.update.assert_called_once_with(VALUES["data_128"]) - self.mock_update_verifier_with_tag.assert_called_once_with( - stream=test_decryptor.source_stream, header=test_decryptor._header, verifier=test_decryptor.verifier - ) self.mock_deserialize_footer.assert_called_once_with( stream=test_decryptor.source_stream, verifier=test_decryptor.verifier ) diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index c30247522..a94519b58 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -23,10 +23,10 @@ import aws_encryption_sdk.internal.utils from aws_encryption_sdk.exceptions import InvalidDataKeyError, SerializationError, UnknownIdentityError from aws_encryption_sdk.internal.defaults import MAX_FRAME_SIZE, MESSAGE_ID_LENGTH -from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO from aws_encryption_sdk.structures import DataKey, EncryptedDataKey, MasterKeyInfo, RawDataKey from .test_values import VALUES +from .unit_test_utils import assert_prepped_stream_identity pytestmark = [pytest.mark.unit, pytest.mark.local] @@ -34,17 +34,14 @@ def test_prep_stream_data_passthrough(): test = aws_encryption_sdk.internal.utils.prep_stream_data(io.BytesIO(b"some data")) - assert isinstance(test, InsistentReaderBytesIO) + assert_prepped_stream_identity(test, io.BytesIO) @pytest.mark.parametrize("source", (u"some unicode data ловие", b"\x00\x01\x02")) def test_prep_stream_data_wrap(source): test = aws_encryption_sdk.internal.utils.prep_stream_data(source) - # Check the wrapped stream - assert isinstance(test, io.BytesIO) - # Check the wrapping stream - assert isinstance(test, InsistentReaderBytesIO) + assert_prepped_stream_identity(test, io.BytesIO) class TestUtils(unittest.TestCase): diff --git a/test/unit/unit_test_utils.py b/test/unit/unit_test_utils.py index 7873456c1..6b0a84bdc 100644 --- a/test/unit/unit_test_utils.py +++ b/test/unit/unit_test_utils.py @@ -15,6 +15,8 @@ import io import itertools +from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO + def all_valid_kwargs(valid_kwargs): valid = [] @@ -79,3 +81,15 @@ def read(self, size=-1): if self._read_counter >= 2: self.close() return super(ExactlyTwoReads, self).read(size) + + +class FailingTeller(object): + def tell(self): + raise IOError("Tell not allowed!") + + +def assert_prepped_stream_identity(prepped_stream, wrapped_type): + # Check the wrapped stream + assert isinstance(prepped_stream, wrapped_type) + # Check the wrapping streams + assert isinstance(prepped_stream, InsistentReaderBytesIO)