Skip to content

Commit 4b0fb9f

Browse files
committed
handle partial reads from source streams
* Add InsistentReaderBytesIO wrapper that will keep reading from the wrapped stream until empty or enough bytes have been collected * update prep_stream_data to make sure that all input data is wrapped in InsistentReaderBytesIO
1 parent cbea075 commit 4b0fb9f

File tree

5 files changed

+101
-6
lines changed

5 files changed

+101
-6
lines changed

src/aws_encryption_sdk/internal/utils/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from aws_encryption_sdk.internal.str_ops import to_bytes
2424
from aws_encryption_sdk.structures import EncryptedDataKey
2525

26+
from .streams import InsistentReaderBytesIO
27+
2628
_LOGGER = logging.getLogger(__name__)
2729

2830

@@ -132,12 +134,14 @@ def prep_stream_data(data):
132134
133135
:param data: Input data
134136
:returns: Prepared stream
135-
:rtype: io.BytesIO
137+
:rtype: InsistentReaderBytesIO
136138
"""
137139
if isinstance(data, (six.string_types, six.binary_type)):
138-
return io.BytesIO(to_bytes(data))
140+
stream = io.BytesIO(to_bytes(data))
141+
else:
142+
stream = data
139143

140-
return data
144+
return InsistentReaderBytesIO(stream)
141145

142146

143147
def source_data_key_length_check(source_data_key, algorithm):

src/aws_encryption_sdk/internal/utils/streams.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Helper stream utility objects for AWS Encryption SDK."""
14+
import io
15+
1416
from wrapt import ObjectProxy
1517

1618
from aws_encryption_sdk.exceptions import ActionNotAllowedError
19+
from aws_encryption_sdk.internal.str_ops import to_bytes
1720

1821

1922
class ROStream(ObjectProxy):
@@ -56,3 +59,36 @@ def read(self, b=None):
5659
data = self.__wrapped__.read(b)
5760
self.__tee.write(data)
5861
return data
62+
63+
64+
class InsistentReaderBytesIO(ObjectProxy):
    """Wrapper around a readable stream that insists on reading exactly the requested
    number of bytes. It will keep trying to read bytes from the wrapped stream until
    either the requested number of bytes are available or the wrapped stream has
    nothing more to return.

    :param wrapped: File-like object
    """

    def read(self, b=-1):
        """Keep reading from source stream until either the source stream is done
        or the requested number of bytes have been obtained.

        Every chunk returned by the wrapped stream is passed through ``to_bytes``
        before being collected.

        :param int b: number of bytes to read; ``None`` is treated like ``-1``.
            For a negative value a single read request is forwarded to the
            wrapped stream and whatever it returns is handed back.
        :return: All bytes read from wrapped stream
        :rtype: bytes
        """
        # ``None`` is a valid "no limit" size for file-like objects, but it cannot
        # take part in the arithmetic below, so normalise it to -1 up front.
        remaining_bytes = -1 if b is None else b
        data = io.BytesIO()
        while True:
            chunk = to_bytes(self.__wrapped__.read(remaining_bytes))

            # An empty chunk means the wrapped stream has nothing more to give.
            if not chunk:
                break

            data.write(chunk)
            remaining_bytes -= len(chunk)

            # Stop once the request is satisfied. A negative request always
            # lands here after its first non-empty chunk.
            if remaining_bytes <= 0:
                break
        return data.getvalue()

test/functional/test_f_aws_encryption_sdk_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,4 +669,4 @@ def test_incomplete_read_stream_cycle(caplog, frame_length):
669669
"Unexpected error encrypting message: infinite loop detected."
670670
)
671671

672-
assert ciphertext != decrypted == VALUES["plaintext_128"]
672+
assert ciphertext != decrypted == plaintext

test/unit/test_util_streams.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,25 @@
1616
import pytest
1717

1818
from aws_encryption_sdk.exceptions import ActionNotAllowedError
19+
from aws_encryption_sdk.internal.str_ops import to_bytes, to_str
1920
from aws_encryption_sdk.internal.utils.streams import ROStream, TeeStream
2021

22+
from .unit_test_utils import NothingButRead, SometimesIncompleteReaderIO
23+
2124
pytestmark = [pytest.mark.unit, pytest.mark.local]
2225

2326

24-
def data():
25-
return io.BytesIO(b"asdijfhoaisjdfoiasjdfoijawef")
27+
def data(length=None, stream_type=io.BytesIO, converter=to_bytes):
    """Build a source stream of test data.

    :param int length: Number of bytes of payload wanted (default: length of the seed value)
    :param stream_type: Callable used to construct the stream from the converted payload
    :param converter: Callable applied to the payload before it is handed to ``stream_type``
    :returns: Whatever ``stream_type`` returns for the converted payload
    """
    seed = b"asdijfhoaisjdfoiasjdfoijawef"
    wanted = len(seed) if length is None else length

    # Grow the payload by re-appending (at most) its first 100 bytes until
    # there is enough material to slice the requested length from.
    payload = seed
    while len(payload) < wanted:
        payload = payload + payload[:100]

    trimmed = payload[:wanted]
    return stream_type(converter(trimmed))
2638

2739

2840
def test_rostream():
@@ -41,3 +53,24 @@ def test_teestream_full():
4153
raw_read = test_tee.read()
4254

4355
assert data().getvalue() == raw_read == new_tee.getvalue()
56+
57+
58+
@pytest.mark.parametrize(
    "stream_type, converter",
    (
        (io.BytesIO, to_bytes),
        (SometimesIncompleteReaderIO, to_bytes),
        (io.StringIO, to_str),
        (NothingButRead, to_bytes),
    ),
)
@pytest.mark.parametrize("bytes_to_read", (1, 4, 7, 99, 500))
@pytest.mark.parametrize("source_length", (1, 11, 100))
def test_insistent_stream(source_length, bytes_to_read, stream_type, converter):
    """A read through InsistentReaderBytesIO returns exactly the requested number
    of bytes, or everything the source holds when it holds fewer than requested.
    """
    # Imported locally so this test is self-contained; the module-level import
    # from this package only brings in ROStream and TeeStream.
    from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO

    # Wrap the raw source: without the wrapper this test would only exercise
    # the fixture streams themselves, never the class under test.
    source = InsistentReaderBytesIO(data(length=source_length, stream_type=stream_type, converter=converter))

    test = source.read(bytes_to_read)

    assert len(test) == min(source_length, bytes_to_read)

test/unit/unit_test_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""Utility functions to handle common test framework functions."""
1414
import copy
15+
import io
1516
import itertools
1617

1718

@@ -50,3 +51,24 @@ def build_valid_kwargs_list(base, optional_kwargs):
5051
_kwargs.update(dict(valid_options))
5152
valid_kwargs.append(_kwargs)
5253
return valid_kwargs
54+
55+
56+
class SometimesIncompleteReaderIO(io.BytesIO):
    """BytesIO variant that simulates partial reads.

    Every second call to :meth:`read` that asks for more than one byte is
    served with only half of the requested size.
    """

    def __init__(self, *args, **kwargs):
        # Counts calls to read() so that every other one can be shortened.
        self.__read_counter = 0
        super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs)

    def read(self, size=-1):
        """Every other read request, return fewer than the requested number of bytes if more than one byte requested.

        :param int size: Number of bytes requested; ``None`` or a negative value
            reads everything that is left and is never shortened
        :returns: Bytes read from the underlying buffer
        :rtype: bytes
        """
        self.__read_counter += 1
        # ``None`` is a legal size for BytesIO.read but cannot be compared to an
        # int on Python 3, so it has to be excluded before the ``> 1`` check.
        if size is not None and size > 1 and self.__read_counter % 2 == 0:
            size //= 2
        return super(SometimesIncompleteReaderIO, self).read(size)
67+
68+
69+
class NothingButRead(object):
    """Stream stand-in whose only public method is ``read``.

    The payload is held in an internal ``io.BytesIO``; none of that buffer's
    other methods are re-exported on this class.
    """

    def __init__(self, data):
        """Store ``data`` in the internal buffer.

        :param bytes data: Payload to serve through :meth:`read`
        """
        buffer = io.BytesIO(data)
        self._data = buffer

    def read(self, size=-1):
        """Read from the internal buffer.

        :param int size: Number of bytes to request from the internal buffer
        :returns: Whatever the internal buffer returns for that request
        :rtype: bytes
        """
        chunk = self._data.read(size)
        return chunk

0 commit comments

Comments
 (0)