Skip to content

Fix handling of partial reads #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ Changelog
Minor
-----

* Add support to remove clients from :ref:`KMSMasterKeyProvider` client cache if they fail to connect to endpoint.
* Add support to remove clients from :class:`KMSMasterKeyProvider` client cache if they fail to connect to endpoint.
`#86 <https://github.com/aws/aws-encryption-sdk-python/pull/86>`_
* Add support for SHA384 and SHA512 for use with RSA OAEP wrapping algorithms.
`#56 <https://github.com/aws/aws-encryption-sdk-python/issues/56>`_
* Fix ``streaming_client`` classes to properly interpret short reads in source streams.
`#24 <https://github.com/aws/aws-encryption-sdk-python/issues/24>`_

1.3.7 -- 2018-09-20
===================
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ branch = True
show_missing = True

[tool:pytest]
log_level = DEBUG
markers =
local: superset of unit and functional (does not require network access)
unit: mark test as a unit test (does not require network access)
Expand Down
10 changes: 7 additions & 3 deletions src/aws_encryption_sdk/internal/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from aws_encryption_sdk.internal.str_ops import to_bytes
from aws_encryption_sdk.structures import EncryptedDataKey

from .streams import InsistentReaderBytesIO

_LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -132,12 +134,14 @@ def prep_stream_data(data):

:param data: Input data
:returns: Prepared stream
:rtype: io.BytesIO
:rtype: InsistentReaderBytesIO
"""
if isinstance(data, (six.string_types, six.binary_type)):
return io.BytesIO(to_bytes(data))
stream = io.BytesIO(to_bytes(data))
else:
stream = data

return data
return InsistentReaderBytesIO(stream)


def source_data_key_length_check(source_data_key, algorithm):
Expand Down
41 changes: 41 additions & 0 deletions src/aws_encryption_sdk/internal/utils/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Helper stream utility objects for AWS Encryption SDK."""
import io

from wrapt import ObjectProxy

from aws_encryption_sdk.exceptions import ActionNotAllowedError
from aws_encryption_sdk.internal.str_ops import to_bytes


class ROStream(ObjectProxy):
Expand Down Expand Up @@ -56,3 +59,41 @@ def read(self, b=None):
data = self.__wrapped__.read(b)
self.__tee.write(data)
return data


class InsistentReaderBytesIO(ObjectProxy):
"""Wrapper around a readable stream that insists on reading exactly the requested
number of bytes. It will keep trying to read bytes from the wrapped stream until
either the requested number of bytes are available or the wrapped stream has
nothing more to return.

:param wrapped: File-like object
"""

def read(self, b=-1):
"""Keep reading from source stream until either the source stream is done
or the requested number of bytes have been obtained.

:param int b: number of bytes to read
:return: All bytes read from wrapped stream
:rtype: bytes
"""
remaining_bytes = b
data = io.BytesIO()
while True:
try:
chunk = to_bytes(self.__wrapped__.read(remaining_bytes))
except ValueError:
if self.__wrapped__.closed:
break
raise

if not chunk:
break

data.write(chunk)
remaining_bytes -= len(chunk)

if remaining_bytes <= 0:
break
return data.getvalue()
47 changes: 33 additions & 14 deletions src/aws_encryption_sdk/streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,24 +202,28 @@ def readable(self):
# Open streams are currently always readable.
return not self.closed

def read(self, b=None):
def read(self, b=-1):
"""Returns either the requested number of bytes or the entire stream.

:param int b: Number of bytes to read
:returns: Processed (encrypted or decrypted) bytes from source stream
:rtype: bytes
"""
# Any negative value for b is interpreted as a full read
if b is not None and b < 0:
b = None
# None is also accepted for legacy compatibility
if b is None or b < 0:
b = -1

_LOGGER.debug("Stream read called, requesting %s bytes", b)
output = io.BytesIO()

if not self._message_prepped:
self._prep_message()

if self.closed:
raise ValueError("I/O operation on closed file")
if b:

if b >= 0:
self._read_bytes(b)
output.write(self.output_buffer[:b])
self.output_buffer = self.output_buffer[b:]
Expand All @@ -228,6 +232,7 @@ def read(self, b=None):
self._read_bytes(LINE_LENGTH)
output.write(self.output_buffer)
self.output_buffer = b""

self.bytes_read += output.tell()
_LOGGER.debug("Returning %s bytes of %s bytes requested", output.tell(), b)
return output.getvalue()
Expand Down Expand Up @@ -511,14 +516,18 @@ def _read_bytes_to_non_framed_body(self, b):
_LOGGER.debug("Closing encryptor after receiving only %s bytes of %s bytes requested", plaintext, b)
self.source_stream.close()
closing = self.encryptor.finalize()

if self.signer is not None:
self.signer.update(closing)

closing += aws_encryption_sdk.internal.formatting.serialize.serialize_non_framed_close(
tag=self.encryptor.tag, signer=self.signer
)

if self.signer is not None:
closing += aws_encryption_sdk.internal.formatting.serialize.serialize_footer(self.signer)
return ciphertext + closing

return ciphertext

def _read_bytes_to_framed_body(self, b):
Expand All @@ -530,14 +539,22 @@ def _read_bytes_to_framed_body(self, b):
"""
_LOGGER.debug("collecting %s bytes", b)
_b = b
b = int(math.ceil(b / float(self.config.frame_length)) * self.config.frame_length)
_LOGGER.debug("%s bytes requested; reading %s bytes after normalizing to frame length", _b, b)

if b > 0:
_frames_to_read = math.ceil(b / float(self.config.frame_length))
b = int(_frames_to_read * self.config.frame_length)
_LOGGER.debug("%d bytes requested; reading %d bytes after normalizing to frame length", _b, b)

plaintext = self.source_stream.read(b)
_LOGGER.debug("%s bytes read from source", len(plaintext))
plaintext_length = len(plaintext)
_LOGGER.debug("%d bytes read from source", plaintext_length)

finalize = False
if len(plaintext) < b:

if b < 0 or plaintext_length < b:
_LOGGER.debug("Final plaintext read from source")
finalize = True

output = b""
final_frame_written = False

Expand Down Expand Up @@ -583,8 +600,8 @@ def _read_bytes(self, b):
:param int b: Number of bytes to read
:raises NotSupportedError: if content type is not supported
"""
_LOGGER.debug("%s bytes requested from stream with content type: %s", b, self.content_type)
if b <= len(self.output_buffer) or self.source_stream.closed:
_LOGGER.debug("%d bytes requested from stream with content type: %s", b, self.content_type)
if 0 <= b <= len(self.output_buffer) or self.source_stream.closed:
_LOGGER.debug("No need to read from source stream or source stream closed")
return

Expand Down Expand Up @@ -776,10 +793,13 @@ def _read_bytes_from_non_framed_body(self, b):
bytes_to_read = self.body_end - self.source_stream.tell()
_LOGGER.debug("%s bytes requested; reading %s bytes", b, bytes_to_read)
ciphertext = self.source_stream.read(bytes_to_read)

if len(self.output_buffer) + len(ciphertext) < self.body_length:
raise SerializationError("Total message body contents less than specified in body description")

if self.verifier is not None:
self.verifier.update(ciphertext)

plaintext = self.decryptor.update(ciphertext)
plaintext += self.decryptor.finalize()
aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag(
Expand Down Expand Up @@ -844,10 +864,9 @@ def _read_bytes(self, b):
_LOGGER.debug("Source stream closed")
return

if b <= len(self.output_buffer):
_LOGGER.debug(
"%s bytes requested less than or equal to current output buffer size %s", b, len(self.output_buffer)
)
buffer_length = len(self.output_buffer)
if 0 <= b <= buffer_length:
_LOGGER.debug("%d bytes requested less than or equal to current output buffer size %d", b, buffer_length)
return

if self._header.content_type == ContentType.FRAMED_DATA:
Expand Down
71 changes: 71 additions & 0 deletions test/functional/test_f_aws_encryption_sdk_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,74 @@ def test_stream_decryptor_readable():
assert handler.readable()
handler.read()
assert not handler.readable()


def exact_length_plaintext(length):
plaintext = b""
while len(plaintext) < length:
plaintext += VALUES["plaintext_128"]
return plaintext[:length]


class SometimesIncompleteReaderIO(io.BytesIO):
def __init__(self, *args, **kwargs):
self.__read_counter = 0
super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs)

def read(self, size=-1):
"""Every other read request, return fewer than the requested number of bytes if more than one byte requested."""
self.__read_counter += 1
if size > 1 and self.__read_counter % 2 == 0:
size //= 2
return super(SometimesIncompleteReaderIO, self).read(size)


@pytest.mark.parametrize(
"frame_length",
(
0, # 0: unframed
128, # 128: framed with exact final frame size match
256, # 256: framed with inexact final frame size match
),
)
def test_incomplete_read_stream_cycle(frame_length):
chunk_size = 21 # Will never be an exact match for the frame size
key_provider = fake_kms_key_provider()

plaintext = exact_length_plaintext(384)
ciphertext = b""
cycle_count = 0
with aws_encryption_sdk.stream(
mode="encrypt",
source=SometimesIncompleteReaderIO(plaintext),
key_provider=key_provider,
frame_length=frame_length,
) as encryptor:
while True:
cycle_count += 1
chunk = encryptor.read(chunk_size)
if not chunk:
break
ciphertext += chunk
if cycle_count > len(VALUES["plaintext_128"]):
raise aws_encryption_sdk.exceptions.AWSEncryptionSDKClientError(
"Unexpected error encrypting message: infinite loop detected."
)

decrypted = b""
cycle_count = 0
with aws_encryption_sdk.stream(
mode="decrypt", source=SometimesIncompleteReaderIO(ciphertext), key_provider=key_provider
) as decryptor:
while True:
cycle_count += 1
chunk = decryptor.read(chunk_size)
if not chunk:
break
decrypted += chunk
if cycle_count > len(VALUES["plaintext_128"]):
raise aws_encryption_sdk.exceptions.AWSEncryptionSDKClientError(
"Unexpected error encrypting message: infinite loop detected."
)

assert ciphertext != decrypted == plaintext
19 changes: 11 additions & 8 deletions test/unit/test_streaming_client_encryption_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import aws_encryption_sdk.exceptions
from aws_encryption_sdk.internal.defaults import LINE_LENGTH
from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO
from aws_encryption_sdk.key_providers.base import MasterKeyProvider
from aws_encryption_sdk.streaming_client import _ClientConfig, _EncryptionStream

Expand Down Expand Up @@ -107,17 +108,19 @@ def test_new_with_params(self):
line_length=io.DEFAULT_BUFFER_SIZE,
source_length=mock_int_sentinel,
)
assert mock_stream.config == MockClientConfig(
source=self.mock_source_stream,
key_provider=self.mock_key_provider,
mock_read_bytes=sentinel.read_bytes,
line_length=io.DEFAULT_BUFFER_SIZE,
source_length=mock_int_sentinel,
)

assert mock_stream.config.source == self.mock_source_stream
assert isinstance(mock_stream.config.source, InsistentReaderBytesIO)
assert mock_stream.config.key_provider is self.mock_key_provider
assert mock_stream.config.mock_read_bytes is sentinel.read_bytes
assert mock_stream.config.line_length == io.DEFAULT_BUFFER_SIZE
assert mock_stream.config.source_length is mock_int_sentinel

assert mock_stream.bytes_read == 0
assert mock_stream.output_buffer == b""
assert not mock_stream._message_prepped
assert mock_stream.source_stream is self.mock_source_stream
assert mock_stream.source_stream == self.mock_source_stream
assert isinstance(mock_stream.source_stream, InsistentReaderBytesIO)
assert mock_stream._stream_length is mock_int_sentinel
assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE

Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_streaming_client_stream_decryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_prep_non_framed(self):
test_decryptor._prep_non_framed()

self.mock_deserialize_non_framed_values.assert_called_once_with(
stream=self.mock_input_stream, header=self.mock_header, verifier=sentinel.verifier
stream=test_decryptor.source_stream, header=self.mock_header, verifier=sentinel.verifier
)
assert test_decryptor.body_length == len(VALUES["data_128"])
self.mock_get_aad_content_string.assert_called_once_with(
Expand Down
Loading