diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0942cba48..0324fc6ac 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,10 +8,12 @@ Changelog Minor ----- -* Add support to remove clients from :ref:`KMSMasterKeyProvider` client cache if they fail to connect to endpoint. +* Add support to remove clients from :class:`KMSMasterKeyProvider` client cache if they fail to connect to endpoint. `#86 `_ * Add support for SHA384 and SHA512 for use with RSA OAEP wrapping algorithms. `#56 `_ +* Fix ``streaming_client`` classes to properly interpret short reads in source streams. + `#24 `_ 1.3.7 -- 2018-09-20 =================== diff --git a/setup.cfg b/setup.cfg index 366846d99..038fc5924 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,6 +11,7 @@ branch = True show_missing = True [tool:pytest] +log_level = DEBUG markers = local: superset of unit and functional (does not require network access) unit: mark test as a unit test (does not require network access) diff --git a/src/aws_encryption_sdk/internal/utils/__init__.py b/src/aws_encryption_sdk/internal/utils/__init__.py index 065722d0d..1e7400c3a 100644 --- a/src/aws_encryption_sdk/internal/utils/__init__.py +++ b/src/aws_encryption_sdk/internal/utils/__init__.py @@ -23,6 +23,8 @@ from aws_encryption_sdk.internal.str_ops import to_bytes from aws_encryption_sdk.structures import EncryptedDataKey +from .streams import InsistentReaderBytesIO + _LOGGER = logging.getLogger(__name__) @@ -132,12 +134,14 @@ def prep_stream_data(data): :param data: Input data :returns: Prepared stream - :rtype: io.BytesIO + :rtype: InsistentReaderBytesIO """ if isinstance(data, (six.string_types, six.binary_type)): - return io.BytesIO(to_bytes(data)) + stream = io.BytesIO(to_bytes(data)) + else: + stream = data - return data + return InsistentReaderBytesIO(stream) def source_data_key_length_check(source_data_key, algorithm): diff --git a/src/aws_encryption_sdk/internal/utils/streams.py b/src/aws_encryption_sdk/internal/utils/streams.py index b1bf5953c..ef9244cad 100644 --- a/src/aws_encryption_sdk/internal/utils/streams.py +++ b/src/aws_encryption_sdk/internal/utils/streams.py @@ -11,9 +11,12 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Helper stream utility objects for AWS Encryption SDK.""" +import io + from wrapt import ObjectProxy from aws_encryption_sdk.exceptions import ActionNotAllowedError +from aws_encryption_sdk.internal.str_ops import to_bytes class ROStream(ObjectProxy): @@ -56,3 +59,41 @@ def read(self, b=None): data = self.__wrapped__.read(b) self.__tee.write(data) return data + + +class InsistentReaderBytesIO(ObjectProxy): + """Wrapper around a readable stream that insists on reading exactly the requested + number of bytes. It will keep trying to read bytes from the wrapped stream until + either the requested number of bytes are available or the wrapped stream has + nothing more to return. + + :param wrapped: File-like object + """ + + def read(self, b=-1): + """Keep reading from source stream until either the source stream is done + or the requested number of bytes have been obtained. + + :param int b: number of bytes to read + :return: All bytes read from wrapped stream + :rtype: bytes + """ + remaining_bytes = b + data = io.BytesIO() + while True: + try: + chunk = to_bytes(self.__wrapped__.read(remaining_bytes)) + except ValueError: + if self.__wrapped__.closed: + break + raise + + if not chunk: + break + + data.write(chunk) + remaining_bytes -= len(chunk) + + if remaining_bytes <= 0: + break + return data.getvalue() diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 05f3a8a01..4e6bd43a9 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -202,7 +202,7 @@ def readable(self): # Open streams are currently always readable. return not self.closed - def read(self, b=None): + def read(self, b=-1): """Returns either the requested number of bytes or the entire stream. :param int b: Number of bytes to read @@ -210,16 +210,20 @@ def read(self, b=None): :rtype: bytes """ # Any negative value for b is interpreted as a full read - if b is not None and b < 0: - b = None + # None is also accepted for legacy compatibility + if b is None or b < 0: + b = -1 _LOGGER.debug("Stream read called, requesting %s bytes", b) output = io.BytesIO() + if not self._message_prepped: self._prep_message() + if self.closed: raise ValueError("I/O operation on closed file") - if b: + + if b >= 0: self._read_bytes(b) output.write(self.output_buffer[:b]) self.output_buffer = self.output_buffer[b:] @@ -228,6 +232,7 @@ def read(self, b=None): self._read_bytes(LINE_LENGTH) output.write(self.output_buffer) self.output_buffer = b"" + self.bytes_read += output.tell() _LOGGER.debug("Returning %s bytes of %s bytes requested", output.tell(), b) return output.getvalue() @@ -511,14 +516,18 @@ def _read_bytes_to_non_framed_body(self, b): _LOGGER.debug("Closing encryptor after receiving only %s bytes of %s bytes requested", plaintext, b) self.source_stream.close() closing = self.encryptor.finalize() + if self.signer is not None: self.signer.update(closing) + closing += aws_encryption_sdk.internal.formatting.serialize.serialize_non_framed_close( tag=self.encryptor.tag, signer=self.signer ) + if self.signer is not None: closing += aws_encryption_sdk.internal.formatting.serialize.serialize_footer(self.signer) return ciphertext + closing + return ciphertext def _read_bytes_to_framed_body(self, b): @@ -530,14 +539,22 @@ def _read_bytes_to_framed_body(self, b): """ _LOGGER.debug("collecting %s bytes", b) _b = b - b = int(math.ceil(b / float(self.config.frame_length)) * self.config.frame_length) - _LOGGER.debug("%s bytes requested; reading %s bytes after normalizing to frame length", _b, b) + + if b > 0: + _frames_to_read = math.ceil(b / float(self.config.frame_length)) + b = int(_frames_to_read * self.config.frame_length) + _LOGGER.debug("%d bytes requested; reading %d bytes after normalizing to frame length", _b, b) + plaintext = self.source_stream.read(b) - _LOGGER.debug("%s bytes read from source", len(plaintext)) + plaintext_length = len(plaintext) + _LOGGER.debug("%d bytes read from source", plaintext_length) + finalize = False - if len(plaintext) < b: + + if b < 0 or plaintext_length < b: _LOGGER.debug("Final plaintext read from source") finalize = True + output = b"" final_frame_written = False @@ -583,8 +600,8 @@ def _read_bytes(self, b): :param int b: Number of bytes to read :raises NotSupportedError: if content type is not supported """ - _LOGGER.debug("%s bytes requested from stream with content type: %s", b, self.content_type) - if b <= len(self.output_buffer) or self.source_stream.closed: + _LOGGER.debug("%d bytes requested from stream with content type: %s", b, self.content_type) + if 0 <= b <= len(self.output_buffer) or self.source_stream.closed: _LOGGER.debug("No need to read from source stream or source stream closed") return @@ -776,10 +793,13 @@ def _read_bytes_from_non_framed_body(self, b): bytes_to_read = self.body_end - self.source_stream.tell() _LOGGER.debug("%s bytes requested; reading %s bytes", b, bytes_to_read) ciphertext = self.source_stream.read(bytes_to_read) + if len(self.output_buffer) + len(ciphertext) < self.body_length: raise SerializationError("Total message body contents less than specified in body description") + if self.verifier is not None: self.verifier.update(ciphertext) + plaintext = self.decryptor.update(ciphertext) plaintext += self.decryptor.finalize() aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag( @@ -844,10 +864,9 @@ def _read_bytes(self, b): _LOGGER.debug("Source stream closed") return - if b <= len(self.output_buffer): - _LOGGER.debug( - "%s bytes requested less than or equal to current output buffer size %s", b, len(self.output_buffer) - ) + buffer_length = len(self.output_buffer) + if 0 <= b <= buffer_length: + _LOGGER.debug("%d bytes requested less than or equal to current output buffer size %d", b, buffer_length) return if self._header.content_type == ContentType.FRAMED_DATA: diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index c76f8d3f0..58dcb0958 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -598,3 +598,74 @@ def test_stream_decryptor_readable(): assert handler.readable() handler.read() assert not handler.readable() + + +def exact_length_plaintext(length): + plaintext = b"" + while len(plaintext) < length: + plaintext += VALUES["plaintext_128"] + return plaintext[:length] + + +class SometimesIncompleteReaderIO(io.BytesIO): + def __init__(self, *args, **kwargs): + self.__read_counter = 0 + super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs) + + def read(self, size=-1): + """Every other read request, return fewer than the requested number of bytes if more than one byte requested.""" + self.__read_counter += 1 + if size > 1 and self.__read_counter % 2 == 0: + size //= 2 + return super(SometimesIncompleteReaderIO, self).read(size) + + +@pytest.mark.parametrize( + "frame_length", + ( + 0, # 0: unframed + 128, # 128: framed with exact final frame size match + 256, # 256: framed with inexact final frame size match + ), +) +def test_incomplete_read_stream_cycle(frame_length): + chunk_size = 21 # Will never be an exact match for the frame size + key_provider = fake_kms_key_provider() + + plaintext = exact_length_plaintext(384) + ciphertext = b"" + cycle_count = 0 + with aws_encryption_sdk.stream( + mode="encrypt", + source=SometimesIncompleteReaderIO(plaintext), + key_provider=key_provider, + frame_length=frame_length, + ) as encryptor: + while True: + cycle_count += 1 + chunk = encryptor.read(chunk_size) + if not chunk: + break + ciphertext += chunk + if cycle_count > len(VALUES["plaintext_128"]): + raise aws_encryption_sdk.exceptions.AWSEncryptionSDKClientError( + "Unexpected error encrypting message: infinite loop detected." + ) + + decrypted = b"" + cycle_count = 0 + with aws_encryption_sdk.stream( + mode="decrypt", source=SometimesIncompleteReaderIO(ciphertext), key_provider=key_provider + ) as decryptor: + while True: + cycle_count += 1 + chunk = decryptor.read(chunk_size) + if not chunk: + break + decrypted += chunk + if cycle_count > len(VALUES["plaintext_128"]): + raise aws_encryption_sdk.exceptions.AWSEncryptionSDKClientError( + "Unexpected error encrypting message: infinite loop detected." + ) + + assert ciphertext != decrypted == plaintext diff --git a/test/unit/test_streaming_client_encryption_stream.py b/test/unit/test_streaming_client_encryption_stream.py index 98b0aad81..22cefffb0 100644 --- a/test/unit/test_streaming_client_encryption_stream.py +++ b/test/unit/test_streaming_client_encryption_stream.py @@ -20,6 +20,7 @@ import aws_encryption_sdk.exceptions from aws_encryption_sdk.internal.defaults import LINE_LENGTH +from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO from aws_encryption_sdk.key_providers.base import MasterKeyProvider from aws_encryption_sdk.streaming_client import _ClientConfig, _EncryptionStream @@ -107,17 +108,19 @@ def test_new_with_params(self): line_length=io.DEFAULT_BUFFER_SIZE, source_length=mock_int_sentinel, ) - assert mock_stream.config == MockClientConfig( - source=self.mock_source_stream, - key_provider=self.mock_key_provider, - mock_read_bytes=sentinel.read_bytes, - line_length=io.DEFAULT_BUFFER_SIZE, - source_length=mock_int_sentinel, - ) + + assert mock_stream.config.source == self.mock_source_stream + assert isinstance(mock_stream.config.source, InsistentReaderBytesIO) + assert mock_stream.config.key_provider is self.mock_key_provider + assert mock_stream.config.mock_read_bytes is sentinel.read_bytes + assert mock_stream.config.line_length == io.DEFAULT_BUFFER_SIZE + assert mock_stream.config.source_length is mock_int_sentinel + assert mock_stream.bytes_read == 0 assert mock_stream.output_buffer == b"" assert not mock_stream._message_prepped - assert mock_stream.source_stream is self.mock_source_stream + assert mock_stream.source_stream == self.mock_source_stream + assert isinstance(mock_stream.source_stream, InsistentReaderBytesIO) assert mock_stream._stream_length is mock_int_sentinel assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE diff --git a/test/unit/test_streaming_client_stream_decryptor.py b/test/unit/test_streaming_client_stream_decryptor.py index 3b824e199..50e981b16 100644 --- a/test/unit/test_streaming_client_stream_decryptor.py +++ b/test/unit/test_streaming_client_stream_decryptor.py @@ -261,7 +261,7 @@ def test_prep_non_framed(self): test_decryptor._prep_non_framed() self.mock_deserialize_non_framed_values.assert_called_once_with( - stream=self.mock_input_stream, header=self.mock_header, verifier=sentinel.verifier + stream=test_decryptor.source_stream, header=self.mock_header, verifier=sentinel.verifier ) assert test_decryptor.body_length == len(VALUES["data_128"]) self.mock_get_aad_content_string.assert_called_once_with( diff --git a/test/unit/test_streaming_client_stream_encryptor.py b/test/unit/test_streaming_client_stream_encryptor.py index f3747a98a..96a3f2a43 100644 --- a/test/unit/test_streaming_client_stream_encryptor.py +++ b/test/unit/test_streaming_client_stream_encryptor.py @@ -58,8 +58,6 @@ def setUp(self): self.mock_master_key = MagicMock(__class__=MasterKey) - self.mock_input_stream = MagicMock(__class__=io.IOBase) - self.mock_frame_length = MagicMock(__class__=int) self.mock_algorithm = MagicMock(__class__=Algorithm) @@ -178,7 +176,7 @@ def tearDown(self): def test_init(self): test_encryptor = StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider, frame_length=self.mock_frame_length, algorithm=self.mock_algorithm, @@ -190,7 +188,7 @@ def test_init(self): def test_init_non_framed_message_too_large(self): with six.assertRaisesRegex(self, SerializationError, "Source too large for non-framed message"): StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider, frame_length=0, algorithm=self.mock_algorithm, @@ -200,7 +198,7 @@ def test_init_non_framed_message_too_large(self): def test_prep_message_no_master_keys(self): self.mock_key_provider.master_keys_for_encryption.return_value = sentinel.primary_master_key, set() test_encryptor = StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider, frame_length=self.mock_frame_length, source_length=5, @@ -215,7 +213,7 @@ def test_prep_message_primary_master_key_not_in_master_keys(self): self.mock_master_keys_set, ) test_encryptor = StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider, frame_length=self.mock_frame_length, source_length=5, @@ -227,7 +225,7 @@ def test_prep_message_primary_master_key_not_in_master_keys(self): def test_prep_message_algorithm_change(self): self.mock_encryption_materials.algorithm = Algorithm.AES_256_GCM_IV12_TAG16 test_encryptor = StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), materials_manager=self.mock_materials_manager, algorithm=Algorithm.AES_128_GCM_IV12_TAG16, source_length=128, @@ -255,7 +253,7 @@ def test_prep_message_framed_message( ): mock_rostream.return_value = sentinel.plaintext_rostream test_encryptor = StreamEncryptor( - source=self.mock_input_stream, + source=io.BytesIO(self.plaintext), materials_manager=self.mock_materials_manager, frame_length=self.mock_frame_length, source_length=5, @@ -359,7 +357,7 @@ def test_write_header(self): @patch("aws_encryption_sdk.streaming_client.non_framed_body_iv") def test_prep_non_framed(self, mock_non_framed_iv): self.mock_serialize_non_framed_open.return_value = b"1234567890" - test_encryptor = StreamEncryptor(source=self.mock_input_stream, key_provider=self.mock_key_provider) + test_encryptor = StreamEncryptor(source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider) test_encryptor.signer = sentinel.signer test_encryptor._encryption_materials = self.mock_encryption_materials test_encryptor._header = MagicMock() @@ -411,7 +409,7 @@ def test_read_bytes_to_non_framed_body_too_large(self): test_encryptor._read_bytes_to_non_framed_body(5) def test_read_bytes_to_non_framed_body_close(self): - test_encryptor = StreamEncryptor(source=self.mock_input_stream, key_provider=self.mock_key_provider) + test_encryptor = StreamEncryptor(source=io.BytesIO(self.plaintext), key_provider=self.mock_key_provider) test_encryptor.signer = MagicMock() test_encryptor._encryption_materials = self.mock_encryption_materials test_encryptor.encryptor = MagicMock() @@ -420,7 +418,9 @@ def test_read_bytes_to_non_framed_body_close(self): test_encryptor.encryptor.tag = sentinel.tag self.mock_serialize_non_framed_close.return_value = b"789" self.mock_serialize_footer.return_value = b"0-=" + test = test_encryptor._read_bytes_to_non_framed_body(len(self.plaintext) + 1) + test_encryptor.signer.update.assert_has_calls(calls=(call(b"123"), call(b"456")), any_order=False) assert test_encryptor.source_stream.closed test_encryptor.encryptor.finalize.assert_called_once_with() diff --git a/test/unit/test_util_streams.py b/test/unit/test_util_streams.py index 15a79ad10..ab7b05152 100644 --- a/test/unit/test_util_streams.py +++ b/test/unit/test_util_streams.py @@ -16,13 +16,25 @@ import pytest from aws_encryption_sdk.exceptions import ActionNotAllowedError -from aws_encryption_sdk.internal.utils.streams import ROStream, TeeStream +from aws_encryption_sdk.internal.str_ops import to_bytes, to_str +from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO, ROStream, TeeStream + +from .unit_test_utils import ExactlyTwoReads, NothingButRead, SometimesIncompleteReaderIO pytestmark = [pytest.mark.unit, pytest.mark.local] -def data(): - return io.BytesIO(b"asdijfhoaisjdfoiasjdfoijawef") +def data(length=None, stream_type=io.BytesIO, converter=to_bytes): + source = b"asdijfhoaisjdfoiasjdfoijawef" + chunk_length = 100 + + if length is None: + length = len(source) + + while len(source) < length: + source += source[:chunk_length] + + return stream_type(converter(source[:length])) def test_rostream(): @@ -41,3 +53,37 @@ def test_teestream_full(): raw_read = test_tee.read() assert data().getvalue() == raw_read == new_tee.getvalue() + + +@pytest.mark.parametrize( + "stream_type, converter", + ( + (io.BytesIO, to_bytes), + (SometimesIncompleteReaderIO, to_bytes), + (io.StringIO, to_str), + (NothingButRead, to_bytes), + ), +) +@pytest.mark.parametrize("bytes_to_read", range(1, 102)) +@pytest.mark.parametrize("source_length", (1, 11, 100)) +def test_insistent_stream(source_length, bytes_to_read, stream_type, converter): + source = InsistentReaderBytesIO(data(length=source_length, stream_type=stream_type, converter=converter)) + + test = source.read(bytes_to_read) + + assert (source_length >= bytes_to_read and len(test) == bytes_to_read) or ( + source_length < bytes_to_read and len(test) == source_length + ) + + +def test_insistent_stream_close_partway_through(): + raw = data(length=100) + source = ExactlyTwoReads(raw.getvalue()) + + wrapped = InsistentReaderBytesIO(source) + + test = b"" + test += wrapped.read(10) # actually reads 10 bytes + test += wrapped.read(10) # reads 5 bytes, stream is closed before third read can complete, truncating the result + + assert test == raw.getvalue()[:15] diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index d74499787..c30247522 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -23,6 +23,7 @@ import aws_encryption_sdk.internal.utils from aws_encryption_sdk.exceptions import InvalidDataKeyError, SerializationError, UnknownIdentityError from aws_encryption_sdk.internal.defaults import MAX_FRAME_SIZE, MESSAGE_ID_LENGTH +from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO from aws_encryption_sdk.structures import DataKey, EncryptedDataKey, MasterKeyInfo, RawDataKey from .test_values import VALUES @@ -31,16 +32,19 @@ def test_prep_stream_data_passthrough(): - test = aws_encryption_sdk.internal.utils.prep_stream_data(sentinel.not_a_string_or_bytes) + test = aws_encryption_sdk.internal.utils.prep_stream_data(io.BytesIO(b"some data")) - assert test is sentinel.not_a_string_or_bytes + assert isinstance(test, InsistentReaderBytesIO) @pytest.mark.parametrize("source", (u"some unicode data ловие", b"\x00\x01\x02")) def test_prep_stream_data_wrap(source): test = aws_encryption_sdk.internal.utils.prep_stream_data(source) + # Check the wrapped stream assert isinstance(test, io.BytesIO) + # Check the wrapping stream + assert isinstance(test, InsistentReaderBytesIO) class TestUtils(unittest.TestCase): diff --git a/test/unit/unit_test_utils.py b/test/unit/unit_test_utils.py index 5c7c1d49d..7873456c1 100644 --- a/test/unit/unit_test_utils.py +++ b/test/unit/unit_test_utils.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """Utility functions to handle common test framework functions.""" import copy +import io import itertools @@ -50,3 +51,31 @@ def build_valid_kwargs_list(base, optional_kwargs): _kwargs.update(dict(valid_options)) valid_kwargs.append(_kwargs) return valid_kwargs + + +class SometimesIncompleteReaderIO(io.BytesIO): + def __init__(self, *args, **kwargs): + self._read_counter = 0 + super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs) + + def read(self, size=-1): + """Every other read request, return fewer than the requested number of bytes if more than one byte requested.""" + self._read_counter += 1 + if size > 1 and self._read_counter % 2 == 0: + size //= 2 + return super(SometimesIncompleteReaderIO, self).read(size) + + +class NothingButRead(object): + def __init__(self, data): + self._data = io.BytesIO(data) + + def read(self, size=-1): + return self._data.read(size) + + +class ExactlyTwoReads(SometimesIncompleteReaderIO): + def read(self, size=-1): + if self._read_counter >= 2: + self.close() + return super(ExactlyTwoReads, self).read(size)