From 504d08be9cef66947e16255326523ac8ddac3baf Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 12:51:00 -0800 Subject: [PATCH 01/13] normalize read() parameter value to correctly always be an int --- src/aws_encryption_sdk/streaming_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 05f3a8a01..5ca17310e 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -202,16 +202,16 @@ def readable(self): # Open streams are currently always readable. return not self.closed - def read(self, b=None): + def read(self, b=-1): """Returns either the requested number of bytes or the entire stream. :param int b: Number of bytes to read :returns: Processed (encrypted or decrypted) bytes from source stream :rtype: bytes """ - # Any negative value for b is interpreted as a full read - if b is not None and b < 0: - b = None + # NoneType value for b is interpretted as a full read for legacy compatibility + if b is None: + b = -1 _LOGGER.debug("Stream read called, requesting %s bytes", b) output = io.BytesIO() From cbea075ebe24c7be27d3060755bceab7ce2186b9 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Thu, 8 Nov 2018 19:42:26 -0800 Subject: [PATCH 02/13] fix merge --- src/aws_encryption_sdk/streaming_client.py | 18 ++++- .../test_f_aws_encryption_sdk_client.py | 72 +++++++++++++++++++ 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 5ca17310e..43c6113e2 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -209,16 +209,20 @@ def read(self, b=-1): :returns: Processed (encrypted or decrypted) bytes from source stream :rtype: bytes """ - # NoneType value for b is interpretted as a full read for legacy compatibility - if b is None: + # Any negative value for b is interpreted as a full read + # None is also accepted for legacy compatibility + if b is None or b < 0: b = -1 _LOGGER.debug("Stream read called, requesting %s bytes", b) output = io.BytesIO() + if not self._message_prepped: self._prep_message() + if self.closed: raise ValueError("I/O operation on closed file") + if b: self._read_bytes(b) output.write(self.output_buffer[:b]) @@ -228,6 +232,7 @@ def read(self, b=-1): self._read_bytes(LINE_LENGTH) output.write(self.output_buffer) self.output_buffer = b"" + self.bytes_read += output.tell() _LOGGER.debug("Returning %s bytes of %s bytes requested", output.tell(), b) return output.getvalue() @@ -511,14 +516,18 @@ def _read_bytes_to_non_framed_body(self, b): _LOGGER.debug("Closing encryptor after receiving only %s bytes of %s bytes requested", plaintext, b) self.source_stream.close() closing = self.encryptor.finalize() + if self.signer is not None: self.signer.update(closing) + closing += aws_encryption_sdk.internal.formatting.serialize.serialize_non_framed_close( tag=self.encryptor.tag, signer=self.signer ) + if self.signer is not None: closing += aws_encryption_sdk.internal.formatting.serialize.serialize_footer(self.signer) return ciphertext + closing + return ciphertext def _read_bytes_to_framed_body(self, b): @@ -535,9 +544,11 @@ def _read_bytes_to_framed_body(self, b): plaintext = self.source_stream.read(b) _LOGGER.debug("%s bytes read from source", len(plaintext)) finalize = False + if len(plaintext) < b: _LOGGER.debug("Final plaintext read from source") finalize = True + output = b"" final_frame_written = False @@ -776,10 +787,13 @@ def _read_bytes_from_non_framed_body(self, b): bytes_to_read = self.body_end - self.source_stream.tell() _LOGGER.debug("%s bytes requested; reading %s bytes", b, bytes_to_read) ciphertext = self.source_stream.read(bytes_to_read) + if len(self.output_buffer) + len(ciphertext) < self.body_length: raise SerializationError("Total message body contents less than specified in body description") + if self.verifier is not None: self.verifier.update(ciphertext) + plaintext = self.decryptor.update(ciphertext) plaintext += self.decryptor.finalize() aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag( diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index c76f8d3f0..838ec7864 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -598,3 +598,75 @@ def test_stream_decryptor_readable(): assert handler.readable() handler.read() assert not handler.readable() + + +def exact_length_plaintext(length): + plaintext = b"" + while len(plaintext) < length: + plaintext += VALUES["plaintext_128"] + return plaintext[:length] + + +class SometimesIncompleteReaderIO(io.BytesIO): + def __init__(self, *args, **kwargs): + self.__read_counter = 0 + super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs) + + def read(self, size=-1): + """Every other read request, return fewer than the requested number of bytes if more than one byte requested.""" + self.__read_counter += 1 + if size > 1 and self.__read_counter % 2 == 0: + size //= 2 + return super(SometimesIncompleteReaderIO, self).read(size) + + +@pytest.mark.parametrize( + "frame_length", + ( + 0, # 0: unframed + 128, # 128: framed with exact final frame size match + 256, # 256: framed with inexact final frame size match + ), +) +def test_incomplete_read_stream_cycle(caplog, frame_length): + caplog.set_level(logging.DEBUG) + chunk_size = 21 # Will never be an exact match for the frame size + key_provider = fake_kms_key_provider() + + plaintext = exact_length_plaintext(384) + ciphertext = b"" + cycle_count = 0 + with aws_encryption_sdk.stream( + mode="encrypt", + source=SometimesIncompleteReaderIO(plaintext), + key_provider=key_provider, + frame_length=frame_length, + ) as encryptor: + while True: + cycle_count += 1 + chunk = encryptor.read(chunk_size) + if not chunk: + break + ciphertext += chunk + if cycle_count > len(VALUES["plaintext_128"]): + raise aws_encryption_sdk.exceptions.AWSEncryptionSDKClientError( + "Unexpected error encrypting message: infinite loop detected." + ) + + decrypted = b"" + cycle_count = 0 + with aws_encryption_sdk.stream( + mode="decrypt", source=SometimesIncompleteReaderIO(ciphertext), key_provider=key_provider + ) as decryptor: + while True: + cycle_count += 1 + chunk = decryptor.read(chunk_size) + if not chunk: + break + decrypted += chunk + if cycle_count > len(VALUES["plaintext_128"]): + raise aws_encryption_sdk.exceptions.AWSEncryptionSDKClientError( + "Unexpected error encrypting message: infinite loop detected." + ) + + assert ciphertext != decrypted == VALUES["plaintext_128"] From 4b0fb9fa5a26abb32ea7cce3460087b9371f15cb Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Thu, 8 Nov 2018 20:06:15 -0800 Subject: [PATCH 03/13] handle partial reads from source streams * Add InsistentReaderBytesIO wrapper that will keep reading from the wrapped stream until empty or enough bytes have been collected * update prep_stream_data to make sure that all input data is wrapped in InsistentReaderBytesIO --- .../internal/utils/__init__.py | 10 +++-- .../internal/utils/streams.py | 36 ++++++++++++++++++ .../test_f_aws_encryption_sdk_client.py | 2 +- test/unit/test_util_streams.py | 37 ++++++++++++++++++- test/unit/unit_test_utils.py | 22 +++++++++++ 5 files changed, 101 insertions(+), 6 deletions(-) diff --git a/src/aws_encryption_sdk/internal/utils/__init__.py b/src/aws_encryption_sdk/internal/utils/__init__.py index 065722d0d..1e7400c3a 100644 --- a/src/aws_encryption_sdk/internal/utils/__init__.py +++ b/src/aws_encryption_sdk/internal/utils/__init__.py @@ -23,6 +23,8 @@ from aws_encryption_sdk.internal.str_ops import to_bytes from aws_encryption_sdk.structures import EncryptedDataKey +from .streams import InsistentReaderBytesIO + _LOGGER = logging.getLogger(__name__) @@ -132,12 +134,14 @@ def prep_stream_data(data): :param data: Input data :returns: Prepared stream - :rtype: io.BytesIO + :rtype: InsistentReaderBytesIO """ if isinstance(data, (six.string_types, six.binary_type)): - return io.BytesIO(to_bytes(data)) + stream = io.BytesIO(to_bytes(data)) + else: + stream = data - return data + return InsistentReaderBytesIO(stream) def source_data_key_length_check(source_data_key, algorithm): diff --git a/src/aws_encryption_sdk/internal/utils/streams.py b/src/aws_encryption_sdk/internal/utils/streams.py index b1bf5953c..8c50d4449 100644 --- a/src/aws_encryption_sdk/internal/utils/streams.py +++ b/src/aws_encryption_sdk/internal/utils/streams.py @@ -11,9 +11,12 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Helper stream utility objects for AWS Encryption SDK.""" +import io + from wrapt import ObjectProxy from aws_encryption_sdk.exceptions import ActionNotAllowedError +from aws_encryption_sdk.internal.str_ops import to_bytes class ROStream(ObjectProxy): @@ -56,3 +59,36 @@ def read(self, b=None): data = self.__wrapped__.read(b) self.__tee.write(data) return data + + +class InsistentReaderBytesIO(ObjectProxy): + """Wrapper around a readable stream that insists on reading exactly the requested + number of bytes. It will keep trying to read bytes from the wrapped stream until + either the requested number of bytes are available or the wrapped stream has + nothing more to return. + + :param wrapped: File-like object + """ + + def read(self, b=-1): + """Keep reading from source stream until either the source stream is done + or the requested number of bytes have been obtained. + + :param int b: number of bytes to read + :return: All bytes read from wrapped stream + :rtype: bytes + """ + remaining_bytes = b + data = io.BytesIO() + while True: + chunk = to_bytes(self.__wrapped__.read(remaining_bytes)) + + if not chunk: + break + + data.write(chunk) + remaining_bytes -= len(chunk) + + if remaining_bytes <= 0: + break + return data.getvalue() diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index 838ec7864..b22a8eb8a 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -669,4 +669,4 @@ def test_incomplete_read_stream_cycle(caplog, frame_length): "Unexpected error encrypting message: infinite loop detected." ) - assert ciphertext != decrypted == VALUES["plaintext_128"] + assert ciphertext != decrypted == plaintext diff --git a/test/unit/test_util_streams.py b/test/unit/test_util_streams.py index 15a79ad10..5a1d74ec7 100644 --- a/test/unit/test_util_streams.py +++ b/test/unit/test_util_streams.py @@ -16,13 +16,25 @@ import pytest from aws_encryption_sdk.exceptions import ActionNotAllowedError +from aws_encryption_sdk.internal.str_ops import to_bytes, to_str from aws_encryption_sdk.internal.utils.streams import ROStream, TeeStream +from .unit_test_utils import NothingButRead, SometimesIncompleteReaderIO + pytestmark = [pytest.mark.unit, pytest.mark.local] -def data(): - return io.BytesIO(b"asdijfhoaisjdfoiasjdfoijawef") +def data(length=None, stream_type=io.BytesIO, converter=to_bytes): + source = b"asdijfhoaisjdfoiasjdfoijawef" + chunk_length = 100 + + if length is None: + length = len(source) + + while len(source) < length: + source += source[:chunk_length] + + return stream_type(converter(source[:length])) def test_rostream(): @@ -41,3 +53,24 @@ def test_teestream_full(): raw_read = test_tee.read() assert data().getvalue() == raw_read == new_tee.getvalue() + + +@pytest.mark.parametrize( + "stream_type, converter", + ( + (io.BytesIO, to_bytes), + (SometimesIncompleteReaderIO, to_bytes), + (io.StringIO, to_str), + (NothingButRead, to_bytes), + ), +) +@pytest.mark.parametrize("bytes_to_read", (1, 4, 7, 99, 500)) +@pytest.mark.parametrize("source_length", (1, 11, 100)) +def test_insistent_stream(source_length, bytes_to_read, stream_type, converter): + source = data(length=source_length, stream_type=stream_type, converter=converter) + + test = source.read(bytes_to_read) + + assert (source_length >= bytes_to_read and len(test) == bytes_to_read) or ( + source_length < bytes_to_read and len(test) == source_length + ) diff --git a/test/unit/unit_test_utils.py b/test/unit/unit_test_utils.py index 5c7c1d49d..b3ac25cee 100644 --- a/test/unit/unit_test_utils.py +++ b/test/unit/unit_test_utils.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """Utility functions to handle common test framework functions.""" import copy +import io import itertools @@ -50,3 +51,24 @@ def build_valid_kwargs_list(base, optional_kwargs): _kwargs.update(dict(valid_options)) valid_kwargs.append(_kwargs) return valid_kwargs + + +class SometimesIncompleteReaderIO(io.BytesIO): + def __init__(self, *args, **kwargs): + self.__read_counter = 0 + super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs) + + def read(self, size=-1): + """Every other read request, return fewer than the requested number of bytes if more than one byte requested.""" + self.__read_counter += 1 + if size > 1 and self.__read_counter % 2 == 0: + size //= 2 + return super(SometimesIncompleteReaderIO, self).read(size) + + +class NothingButRead(object): + def __init__(self, data): + self._data = io.BytesIO(data) + + def read(self, size=-1): + return self._data.read(size) From af9567d7b3efeabfe76dbddb3fbb6894c732b8e3 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 17:24:06 -0800 Subject: [PATCH 04/13] add necessary handling to correctly read when b value is corrected to -1 --- src/aws_encryption_sdk/streaming_client.py | 23 ++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 43c6113e2..6f88b55b2 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -223,7 +223,7 @@ def read(self, b=-1): if self.closed: raise ValueError("I/O operation on closed file") - if b: + if b >= 0: self._read_bytes(b) output.write(self.output_buffer[:b]) self.output_buffer = self.output_buffer[b:] @@ -539,13 +539,19 @@ def _read_bytes_to_framed_body(self, b): """ _LOGGER.debug("collecting %s bytes", b) _b = b - b = int(math.ceil(b / float(self.config.frame_length)) * self.config.frame_length) - _LOGGER.debug("%s bytes requested; reading %s bytes after normalizing to frame length", _b, b) + + if b > 0: + _frames_to_read = math.ceil(b / float(self.config.frame_length)) + b = int(_frames_to_read * self.config.frame_length) + _LOGGER.debug("%d bytes requested; reading %d bytes after normalizing to frame length", _b, b) + plaintext = self.source_stream.read(b) - _LOGGER.debug("%s bytes read from source", len(plaintext)) + plaintext_length = len(plaintext) + _LOGGER.debug("%d bytes read from source", plaintext_length) + finalize = False - if len(plaintext) < b: + if b < 0 or plaintext_length < b: _LOGGER.debug("Final plaintext read from source") finalize = True @@ -594,8 +600,8 @@ def _read_bytes(self, b): :param int b: Number of bytes to read :raises NotSupportedError: if content type is not supported """ - _LOGGER.debug("%s bytes requested from stream with content type: %s", b, self.content_type) - if b <= len(self.output_buffer) or self.source_stream.closed: + _LOGGER.debug("%d bytes requested from stream with content type: %s", b, self.content_type) + if 0 <= b <= len(self.output_buffer) or self.source_stream.closed: _LOGGER.debug("No need to read from source stream or source stream closed") return @@ -858,7 +864,8 @@ def _read_bytes(self, b): _LOGGER.debug("Source stream closed") return - if b <= len(self.output_buffer): + buffer_length = len(self.output_buffer) + if 0 <= b <= buffer_length: _LOGGER.debug( "%s bytes requested less than or equal to current output buffer size %s", b, len(self.output_buffer) ) From 8bf76116a05f5825bcd50b995b5d04e2848b517f Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 17:27:14 -0800 Subject: [PATCH 05/13] expand range of read sizes --- test/unit/test_util_streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/test_util_streams.py b/test/unit/test_util_streams.py index 5a1d74ec7..b896a3a3f 100644 --- a/test/unit/test_util_streams.py +++ b/test/unit/test_util_streams.py @@ -64,7 +64,7 @@ def test_teestream_full(): (NothingButRead, to_bytes), ), ) -@pytest.mark.parametrize("bytes_to_read", (1, 4, 7, 99, 500)) +@pytest.mark.parametrize("bytes_to_read", range(1, 102)) @pytest.mark.parametrize("source_length", (1, 11, 100)) def test_insistent_stream(source_length, bytes_to_read, stream_type, converter): source = data(length=source_length, stream_type=stream_type, converter=converter) From e22d171fc67ea94c2f9b524b069178a1c1f5e818 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 17:27:36 -0800 Subject: [PATCH 06/13] always log at debug in pytest --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 366846d99..038fc5924 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,6 +11,7 @@ branch = True show_missing = True [tool:pytest] +log_level = DEBUG markers = local: superset of unit and functional (does not require network access) unit: mark test as a unit test (does not require network access) From 725bfe38fb59d2e3d5571b43c5c23a03374e4be8 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 20:39:23 -0800 Subject: [PATCH 07/13] update prep_stream_data unit tests --- test/unit/test_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index d74499787..c30247522 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -23,6 +23,7 @@ import aws_encryption_sdk.internal.utils from aws_encryption_sdk.exceptions import InvalidDataKeyError, SerializationError, UnknownIdentityError from aws_encryption_sdk.internal.defaults import MAX_FRAME_SIZE, MESSAGE_ID_LENGTH +from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO from aws_encryption_sdk.structures import DataKey, EncryptedDataKey, MasterKeyInfo, RawDataKey from .test_values import VALUES @@ -31,16 +32,19 @@ def test_prep_stream_data_passthrough(): - test = aws_encryption_sdk.internal.utils.prep_stream_data(sentinel.not_a_string_or_bytes) + test = aws_encryption_sdk.internal.utils.prep_stream_data(io.BytesIO(b"some data")) - assert test is sentinel.not_a_string_or_bytes + assert isinstance(test, InsistentReaderBytesIO) @pytest.mark.parametrize("source", (u"some unicode data ловие", b"\x00\x01\x02")) def test_prep_stream_data_wrap(source): test = aws_encryption_sdk.internal.utils.prep_stream_data(source) + # Check the wrapped stream assert isinstance(test, io.BytesIO) + # Check the wrapping stream + assert isinstance(test, InsistentReaderBytesIO) class TestUtils(unittest.TestCase): From 268853a962ef929392b860d2ceced1fbf977a756 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 20:39:36 -0800 Subject: [PATCH 08/13] autoformat --- src/aws_encryption_sdk/streaming_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 6f88b55b2..4e6bd43a9 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -866,9 +866,7 @@ def _read_bytes(self, b): buffer_length = len(self.output_buffer) if 0 <= b <= buffer_length: - _LOGGER.debug( - "%s bytes requested less than or equal to current output buffer size %s", b, len(self.output_buffer) - ) + _LOGGER.debug("%d bytes requested less than or equal to current output buffer size %d", b, buffer_length) return if self._header.content_type == ContentType.FRAMED_DATA: From 626eb4e8147623f2315ccb00c3bb7e419ba0fef9 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 21:26:29 -0800 Subject: [PATCH 09/13] un-mock tests to work correctly with updated prep_data_stream implementation --- ...test_streaming_client_encryption_stream.py | 19 ++++++++++-------- .../test_streaming_client_stream_decryptor.py | 2 +- .../test_streaming_client_stream_encryptor.py | 20 +++++++++---------- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/test/unit/test_streaming_client_encryption_stream.py b/test/unit/test_streaming_client_encryption_stream.py index 98b0aad81..22cefffb0 100644 --- a/test/unit/test_streaming_client_encryption_stream.py +++ b/test/unit/test_streaming_client_encryption_stream.py @@ -20,6 +20,7 @@ import aws_encryption_sdk.exceptions from aws_encryption_sdk.internal.defaults import LINE_LENGTH +from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO from aws_encryption_sdk.key_providers.base import MasterKeyProvider from aws_encryption_sdk.streaming_client import _ClientConfig, _EncryptionStream @@ -107,17 +108,19 @@ def test_new_with_params(self): line_length=io.DEFAULT_BUFFER_SIZE, source_length=mock_int_sentinel, ) - assert mock_stream.config == MockClientConfig( - source=self.mock_source_stream, - key_provider=self.mock_key_provider, - mock_read_bytes=sentinel.read_bytes, - line_length=io.DEFAULT_BUFFER_SIZE, - source_length=mock_int_sentinel, - ) + + assert mock_stream.config.source == self.mock_source_stream + assert isinstance(mock_stream.config.source, InsistentReaderBytesIO) + assert mock_stream.config.key_provider is self.mock_key_provider + assert mock_stream.config.mock_read_bytes is sentinel.read_bytes + assert mock_stream.config.line_length == io.DEFAULT_BUFFER_SIZE + assert mock_stream.config.source_length is mock_int_sentinel + assert mock_stream.bytes_read == 0 assert mock_stream.output_buffer == b"" assert not mock_stream._message_prepped - assert mock_stream.source_stream is self.mock_source_stream + assert mock_stream.source_stream == self.mock_source_stream + assert isinstance(mock_stream.source_stream, InsistentReaderBytesIO) assert mock_stream._stream_length is mock_int_sentinel assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE diff --git a/test/unit/test_streaming_client_stream_decryptor.py b/test/unit/test_streaming_client_stream_decryptor.py index 3b824e199..50e981b16 100644 --- a/test/unit/test_streaming_client_stream_decryptor.py +++ b/test/unit/test_streaming_client_stream_decryptor.py @@ -261,7 +261,7 @@ def test_prep_non_framed(self): test_decryptor._prep_non_framed() self.mock_deserialize_non_framed_values.assert_called_once_with( - stream=self.mock_input_stream, header=self.mock_header, verifier=sentinel.verifier + stream=test_decryptor.source_stream, header=self.mock_header, verifier=sentinel.verifier ) assert test_decryptor.body_length == len(VALUES["data_128"]) self.mock_get_aad_content_string.assert_called_once_with( diff --git a/test/unit/test_streaming_client_stream_encryptor.py b/test/unit/test_streaming_client_stream_encryptor.py index f3747a98a..96a3f2a43 100644 --- a/test/unit/test_streaming_client_stream_encryptor.py +++ b/test/unit/test_streaming_client_stream_encryptor.py @@ -58,8 +58,6 @@ def setUp(self): self.mock_master_key = MagicMock(__class__=MasterKey) - self.mock_input_stream = MagicMock(__class__=io.IOBase) - self.mock_frame_length = MagicMock(__class__=int) self.mock_algorithm = MagicMock(__class__=Algorithm) @@ -178,7 +176,7 @@ def tearDown(self): def test_init(self): test_encryptor = StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider, frame_length=self.mock_frame_length, algorithm=self.mock_algorithm, @@ -190,7 +188,7 @@ def test_init(self): def test_init_non_framed_message_too_large(self): with six.assertRaisesRegex(self, SerializationError, "Source too large for non-framed message"): StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider, frame_length=0, algorithm=self.mock_algorithm, @@ -200,7 +198,7 @@ def test_init_non_framed_message_too_large(self): def test_prep_message_no_master_keys(self): self.mock_key_provider.master_keys_for_encryption.return_value = sentinel.primary_master_key, set() test_encryptor = StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider, frame_length=self.mock_frame_length, source_length=5, @@ -215,7 +213,7 @@ def test_prep_message_primary_master_key_not_in_master_keys(self): self.mock_master_keys_set, ) test_encryptor = StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider, frame_length=self.mock_frame_length, source_length=5, @@ -227,7 +225,7 @@ def test_prep_message_primary_master_key_not_in_master_keys(self): def test_prep_message_algorithm_change(self): self.mock_encryption_materials.algorithm = Algorithm.AES_256_GCM_IV12_TAG16 test_encryptor = StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), materials_manager=self.mock_materials_manager, algorithm=Algorithm.AES_128_GCM_IV12_TAG16, source_length=128, @@ -255,7 +253,7 @@ def test_prep_message_framed_message( ): mock_rostream.return_value = sentinel.plaintext_rostream test_encryptor = StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), materials_manager=self.mock_materials_manager, frame_length=self.mock_frame_length, source_length=5, @@ -359,7 +357,7 @@ def test_write_header(self): @patch("aws_encryption_sdk.streaming_client.non_framed_body_iv") def test_prep_non_framed(self, mock_non_framed_iv): self.mock_serialize_non_framed_open.return_value = b"1234567890" - test_encryptor = StreamEncryptor(source=self.mock_input_stream, key_provider=self.mock_key_provider) + test_encryptor = StreamEncryptor(source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider) test_encryptor.signer = sentinel.signer test_encryptor._encryption_materials = self.mock_encryption_materials test_encryptor._header = MagicMock() @@ -411,7 +409,7 @@ def test_read_bytes_to_non_framed_body_too_large(self): test_encryptor._read_bytes_to_non_framed_body(5) def test_read_bytes_to_non_framed_body_close(self): - test_encryptor = StreamEncryptor(source=self.mock_input_stream, key_provider=self.mock_key_provider) + test_encryptor = StreamEncryptor(source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider) test_encryptor.signer = MagicMock() test_encryptor._encryption_materials = self.mock_encryption_materials test_encryptor.encryptor = MagicMock() @@ -420,7 +418,9 @@ def test_read_bytes_to_non_framed_body_close(self): test_encryptor.encryptor.tag = sentinel.tag self.mock_serialize_non_framed_close.return_value = b"789" self.mock_serialize_footer.return_value = b"0-=" + test = test_encryptor._read_bytes_to_non_framed_body(len(self.plaintext) + 1) + test_encryptor.signer.update.assert_has_calls(calls=(call(b"123"), call(b"456")), any_order=False) assert test_encryptor.source_stream.closed test_encryptor.encryptor.finalize.assert_called_once_with() From 3fdacd77e53913f1b23746c68f7da07f8113e52f Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 21:26:49 -0800 Subject: [PATCH 10/13] add #24 change record to changelog --- CHANGELOG.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0942cba48..0324fc6ac 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,10 +8,12 @@ Changelog Minor ----- -* Add support to remove clients from :ref:`KMSMasterKeyProvider` client cache if they fail to connect to endpoint. +* Add support to remove clients from :class:`KMSMasterKeyProvider` client cache if they fail to connect to endpoint. `#86 `_ * Add support for SHA384 and SHA512 for use with RSA OAEP wrapping algorithms. `#56 `_ +* Fix ``streaming_client`` classes to properly interpret short reads in source streams. + `#24 `_ 1.3.7 -- 2018-09-20 =================== From de4d7f0dfc3bbc929255737ad4cc29dbe19623db Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 21:41:37 -0800 Subject: [PATCH 11/13] remove unnecessary caplog use --- test/functional/test_f_aws_encryption_sdk_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index b22a8eb8a..58dcb0958 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -628,8 +628,7 @@ def read(self, size=-1): 256, # 256: framed with inexact final frame size match ), ) -def test_incomplete_read_stream_cycle(caplog, frame_length): - caplog.set_level(logging.DEBUG) +def test_incomplete_read_stream_cycle(frame_length): chunk_size = 21 # Will never be an exact match for the frame size key_provider = fake_kms_key_provider() From f47f942bcee09a4fe46c00f87350583d72705b14 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 22:17:25 -0800 Subject: [PATCH 12/13] actually test InsistentReaderBytesIO --- test/unit/test_util_streams.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/unit/test_util_streams.py b/test/unit/test_util_streams.py index b896a3a3f..0bea65a23 100644 --- a/test/unit/test_util_streams.py +++ b/test/unit/test_util_streams.py @@ -17,7 +17,7 @@ from aws_encryption_sdk.exceptions import ActionNotAllowedError from aws_encryption_sdk.internal.str_ops import to_bytes, to_str -from aws_encryption_sdk.internal.utils.streams import ROStream, TeeStream +from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO, ROStream, TeeStream from .unit_test_utils import NothingButRead, SometimesIncompleteReaderIO @@ -67,7 +67,7 @@ def test_teestream_full(): @pytest.mark.parametrize("bytes_to_read", range(1, 102)) @pytest.mark.parametrize("source_length", (1, 11, 100)) def test_insistent_stream(source_length, bytes_to_read, stream_type, converter): - source = data(length=source_length, stream_type=stream_type, converter=converter) + source = InsistentReaderBytesIO(data(length=source_length, stream_type=stream_type, converter=converter)) test = source.read(bytes_to_read) From 74355b065971c9862acbf2d57bdf173712da4e09 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 12 Nov 2018 22:18:43 -0800 Subject: [PATCH 13/13] make InsistentReaderBytesIO properly handle a stream that closes while being read from --- src/aws_encryption_sdk/internal/utils/streams.py | 7 ++++++- test/unit/test_util_streams.py | 15 ++++++++++++++- test/unit/unit_test_utils.py | 13 ++++++++++--- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/aws_encryption_sdk/internal/utils/streams.py b/src/aws_encryption_sdk/internal/utils/streams.py index 8c50d4449..ef9244cad 100644 --- a/src/aws_encryption_sdk/internal/utils/streams.py +++ b/src/aws_encryption_sdk/internal/utils/streams.py @@ -81,7 +81,12 @@ def read(self, b=-1): remaining_bytes = b data = io.BytesIO() while True: - chunk = to_bytes(self.__wrapped__.read(remaining_bytes)) + try: + chunk = to_bytes(self.__wrapped__.read(remaining_bytes)) + except ValueError: + if self.__wrapped__.closed: + break + raise if not chunk: break diff --git a/test/unit/test_util_streams.py b/test/unit/test_util_streams.py index 0bea65a23..ab7b05152 100644 --- a/test/unit/test_util_streams.py +++ b/test/unit/test_util_streams.py @@ -19,7 +19,7 @@ from aws_encryption_sdk.internal.str_ops import to_bytes, to_str from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO, ROStream, TeeStream -from .unit_test_utils import NothingButRead, SometimesIncompleteReaderIO +from .unit_test_utils import ExactlyTwoReads, NothingButRead, SometimesIncompleteReaderIO pytestmark = [pytest.mark.unit, pytest.mark.local] @@ -74,3 +74,16 @@ def test_insistent_stream(source_length, bytes_to_read, stream_type, converter): assert (source_length >= bytes_to_read and len(test) == bytes_to_read) or ( source_length < bytes_to_read and len(test) == source_length ) + + +def test_insistent_stream_close_partway_through(): + raw = data(length=100) + source = ExactlyTwoReads(raw.getvalue()) + + wrapped = InsistentReaderBytesIO(source) + + test = b"" + test += wrapped.read(10) # actually reads 10 bytes + test += wrapped.read(10) # reads 5 bytes, stream is closed before third read can complete, truncating the result + + assert test == raw.getvalue()[:15] diff --git a/test/unit/unit_test_utils.py b/test/unit/unit_test_utils.py index b3ac25cee..7873456c1 100644 --- a/test/unit/unit_test_utils.py +++ b/test/unit/unit_test_utils.py @@ -55,13 +55,13 @@ def build_valid_kwargs_list(base, optional_kwargs): class SometimesIncompleteReaderIO(io.BytesIO): def __init__(self, *args, **kwargs): - self.__read_counter = 0 + self._read_counter = 0 super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs) def read(self, size=-1): """Every other read request, return fewer than the requested number of bytes if more than one byte requested.""" - self.__read_counter += 1 - if size > 1 and self.__read_counter % 2 == 0: + self._read_counter += 1 + if size > 1 and self._read_counter % 2 == 0: size //= 2 return super(SometimesIncompleteReaderIO, self).read(size) @@ -72,3 +72,10 @@ def __init__(self, data): def read(self, size=-1): return self._data.read(size) + + +class ExactlyTwoReads(SometimesIncompleteReaderIO): + def read(self, size=-1): + if self._read_counter >= 2: + self.close() + return super(ExactlyTwoReads, self).read(size)