Skip to content

Commit cbea075

Browse files
committed
fix merge
1 parent 504d08b commit cbea075

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed

src/aws_encryption_sdk/streaming_client.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,16 +209,20 @@ def read(self, b=-1):
209209
:returns: Processed (encrypted or decrypted) bytes from source stream
210210
:rtype: bytes
211211
"""
212-
# NoneType value for b is interpretted as a full read for legacy compatibility
213-
if b is None:
212+
# Any negative value for b is interpreted as a full read
213+
# None is also accepted for legacy compatibility
214+
if b is None or b < 0:
214215
b = -1
215216

216217
_LOGGER.debug("Stream read called, requesting %s bytes", b)
217218
output = io.BytesIO()
219+
218220
if not self._message_prepped:
219221
self._prep_message()
222+
220223
if self.closed:
221224
raise ValueError("I/O operation on closed file")
225+
222226
if b:
223227
self._read_bytes(b)
224228
output.write(self.output_buffer[:b])
@@ -228,6 +232,7 @@ def read(self, b=-1):
228232
self._read_bytes(LINE_LENGTH)
229233
output.write(self.output_buffer)
230234
self.output_buffer = b""
235+
231236
self.bytes_read += output.tell()
232237
_LOGGER.debug("Returning %s bytes of %s bytes requested", output.tell(), b)
233238
return output.getvalue()
@@ -511,14 +516,18 @@ def _read_bytes_to_non_framed_body(self, b):
511516
_LOGGER.debug("Closing encryptor after receiving only %s bytes of %s bytes requested", plaintext, b)
512517
self.source_stream.close()
513518
closing = self.encryptor.finalize()
519+
514520
if self.signer is not None:
515521
self.signer.update(closing)
522+
516523
closing += aws_encryption_sdk.internal.formatting.serialize.serialize_non_framed_close(
517524
tag=self.encryptor.tag, signer=self.signer
518525
)
526+
519527
if self.signer is not None:
520528
closing += aws_encryption_sdk.internal.formatting.serialize.serialize_footer(self.signer)
521529
return ciphertext + closing
530+
522531
return ciphertext
523532

524533
def _read_bytes_to_framed_body(self, b):
@@ -535,9 +544,11 @@ def _read_bytes_to_framed_body(self, b):
535544
plaintext = self.source_stream.read(b)
536545
_LOGGER.debug("%s bytes read from source", len(plaintext))
537546
finalize = False
547+
538548
if len(plaintext) < b:
539549
_LOGGER.debug("Final plaintext read from source")
540550
finalize = True
551+
541552
output = b""
542553
final_frame_written = False
543554

@@ -776,10 +787,13 @@ def _read_bytes_from_non_framed_body(self, b):
776787
bytes_to_read = self.body_end - self.source_stream.tell()
777788
_LOGGER.debug("%s bytes requested; reading %s bytes", b, bytes_to_read)
778789
ciphertext = self.source_stream.read(bytes_to_read)
790+
779791
if len(self.output_buffer) + len(ciphertext) < self.body_length:
780792
raise SerializationError("Total message body contents less than specified in body description")
793+
781794
if self.verifier is not None:
782795
self.verifier.update(ciphertext)
796+
783797
plaintext = self.decryptor.update(ciphertext)
784798
plaintext += self.decryptor.finalize()
785799
aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag(

test/functional/test_f_aws_encryption_sdk_client.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,75 @@ def test_stream_decryptor_readable():
598598
assert handler.readable()
599599
handler.read()
600600
assert not handler.readable()
601+
602+
603+
def exact_length_plaintext(length):
604+
plaintext = b""
605+
while len(plaintext) < length:
606+
plaintext += VALUES["plaintext_128"]
607+
return plaintext[:length]
608+
609+
610+
class SometimesIncompleteReaderIO(io.BytesIO):
611+
def __init__(self, *args, **kwargs):
612+
self.__read_counter = 0
613+
super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs)
614+
615+
def read(self, size=-1):
616+
"""Every other read request, return fewer than the requested number of bytes if more than one byte requested."""
617+
self.__read_counter += 1
618+
if size > 1 and self.__read_counter % 2 == 0:
619+
size //= 2
620+
return super(SometimesIncompleteReaderIO, self).read(size)
621+
622+
623+
@pytest.mark.parametrize(
624+
"frame_length",
625+
(
626+
0, # 0: unframed
627+
128, # 128: framed with exact final frame size match
628+
256, # 256: framed with inexact final frame size match
629+
),
630+
)
631+
def test_incomplete_read_stream_cycle(caplog, frame_length):
632+
caplog.set_level(logging.DEBUG)
633+
chunk_size = 21 # Will never be an exact match for the frame size
634+
key_provider = fake_kms_key_provider()
635+
636+
plaintext = exact_length_plaintext(384)
637+
ciphertext = b""
638+
cycle_count = 0
639+
with aws_encryption_sdk.stream(
640+
mode="encrypt",
641+
source=SometimesIncompleteReaderIO(plaintext),
642+
key_provider=key_provider,
643+
frame_length=frame_length,
644+
) as encryptor:
645+
while True:
646+
cycle_count += 1
647+
chunk = encryptor.read(chunk_size)
648+
if not chunk:
649+
break
650+
ciphertext += chunk
651+
if cycle_count > len(VALUES["plaintext_128"]):
652+
raise aws_encryption_sdk.exceptions.AWSEncryptionSDKClientError(
653+
"Unexpected error encrypting message: infinite loop detected."
654+
)
655+
656+
decrypted = b""
657+
cycle_count = 0
658+
with aws_encryption_sdk.stream(
659+
mode="decrypt", source=SometimesIncompleteReaderIO(ciphertext), key_provider=key_provider
660+
) as decryptor:
661+
while True:
662+
cycle_count += 1
663+
chunk = decryptor.read(chunk_size)
664+
if not chunk:
665+
break
666+
decrypted += chunk
667+
if cycle_count > len(VALUES["plaintext_128"]):
668+
raise aws_encryption_sdk.exceptions.AWSEncryptionSDKClientError(
669+
"Unexpected error encrypting message: infinite loop detected."
670+
)
671+
672+
assert ciphertext != decrypted == VALUES["plaintext_128"]

0 commit comments

Comments
 (0)