diff --git a/.changes/next-release/bugfix-AmazonS3-497d9da.json b/.changes/next-release/bugfix-AmazonS3-497d9da.json new file mode 100644 index 000000000000..d02dcdda8cc0 --- /dev/null +++ b/.changes/next-release/bugfix-AmazonS3-497d9da.json @@ -0,0 +1,6 @@ +{ + "category": "Amazon S3", + "contributor": "", + "type": "bugfix", + "description": "Truncate the async request body when the content-length is shorter than the request body, instead of raising a \"Data read has a different checksum\" exception." +} diff --git a/.changes/next-release/bugfix-AmazonS3-ce33798.json b/.changes/next-release/bugfix-AmazonS3-ce33798.json new file mode 100644 index 000000000000..897e9d65b4bf --- /dev/null +++ b/.changes/next-release/bugfix-AmazonS3-ce33798.json @@ -0,0 +1,6 @@ +{ + "category": "Amazon S3", + "contributor": "", + "type": "bugfix", + "description": "Raise an exception instead of hanging when a put-object content-length exceeds the data written by the async request body." +} diff --git a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/NettyRequestExecutor.java b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/NettyRequestExecutor.java index 79809e06d58e..2ded744ed573 100644 --- a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/NettyRequestExecutor.java +++ b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/NettyRequestExecutor.java @@ -536,8 +536,14 @@ public void onError(Throwable t) { @Override public void onComplete() { if (!done) { - done = true; - subscriber.onComplete(); + Long expectedContentLength = requestContentLength.orElse(null); + if (expectedContentLength != null && written < expectedContentLength) { + onError(new IllegalStateException("Request content was only " + written + " bytes, but the specified " + + "content-length was " + expectedContentLength + " bytes.")); + } else { + done = true; + subscriber.onComplete(); + } } } }); diff --git a/pom.xml b/pom.xml index b12f23c7ac6d..389f9554d58b 100644 --- a/pom.xml +++ b/pom.xml @@ -522,9 +522,7 @@ *.internal.* software.amazon.awssdk.thirdparty.* - software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler - software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler - + software.amazon.awssdk.services.s3.checksums.ChecksumCalculatingAsyncRequestBody software.amazon.awssdk.protocols.core.OperationInfo diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumCalculatingAsyncRequestBody.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumCalculatingAsyncRequestBody.java index ff1d7475e1b7..187a5a1a9107 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumCalculatingAsyncRequestBody.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumCalculatingAsyncRequestBody.java @@ -17,20 +17,26 @@ import java.nio.ByteBuffer; import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.checksums.SdkChecksum; +import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.utils.BinaryUtils; @SdkInternalApi public class ChecksumCalculatingAsyncRequestBody implements AsyncRequestBody { - + private final Long contentLength; private final AsyncRequestBody wrapped; private final SdkChecksum sdkChecksum; - public ChecksumCalculatingAsyncRequestBody(AsyncRequestBody wrapped, SdkChecksum sdkChecksum) { + public ChecksumCalculatingAsyncRequestBody(SdkHttpRequest request, AsyncRequestBody wrapped, SdkChecksum sdkChecksum) { + this.contentLength = request.firstMatchingHeader("Content-Length") + .map(Long::parseLong) + .orElse(wrapped.contentLength() + .orElse(null)); this.wrapped = wrapped; this.sdkChecksum = sdkChecksum; } @@ -48,18 +54,21 @@ public String contentType() { @Override public void subscribe(Subscriber s) { sdkChecksum.reset(); - wrapped.subscribe(new ChecksumCalculatingSubscriber(s, sdkChecksum)); + wrapped.subscribe(new ChecksumCalculatingSubscriber(s, sdkChecksum, contentLength)); } private static final class ChecksumCalculatingSubscriber implements Subscriber { - + private final AtomicLong contentRead = new AtomicLong(0); private final Subscriber wrapped; private final SdkChecksum checksum; + private final Long contentLength; ChecksumCalculatingSubscriber(Subscriber wrapped, - SdkChecksum sdkChecksum) { + SdkChecksum sdkChecksum, + Long contentLength) { this.wrapped = wrapped; this.checksum = sdkChecksum; + this.contentLength = contentLength; } @Override @@ -69,11 +78,34 @@ public void onSubscribe(Subscription s) { @Override public void onNext(ByteBuffer byteBuffer) { - byte[] buf = BinaryUtils.copyBytesFrom(byteBuffer); - checksum.update(buf, 0, buf.length); + int amountToReadFromByteBuffer = getAmountToReadFromByteBuffer(byteBuffer); + + if (amountToReadFromByteBuffer > 0) { + byte[] buf = BinaryUtils.copyBytesFrom(byteBuffer, amountToReadFromByteBuffer); + checksum.update(buf, 0, amountToReadFromByteBuffer); + } + + wrapped.onNext(byteBuffer); } + private int getAmountToReadFromByteBuffer(ByteBuffer byteBuffer) { + // If content length is null, we should include everything in the checksum because the stream is essentially + // unbounded. + if (contentLength == null) { + return byteBuffer.remaining(); + } + + long amountReadSoFar = contentRead.getAndAdd(byteBuffer.remaining()); + long amountRemaining = Math.max(0, contentLength - amountReadSoFar); + + if (amountRemaining > byteBuffer.remaining()) { + return byteBuffer.remaining(); + } else { + return Math.toIntExact(amountRemaining); + } + } + @Override public void onError(Throwable t) { wrapped.onError(t); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/handlers/AsyncChecksumValidationInterceptor.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/handlers/AsyncChecksumValidationInterceptor.java index af691aae4ee5..4220bb4fd2bf 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/handlers/AsyncChecksumValidationInterceptor.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/handlers/AsyncChecksumValidationInterceptor.java @@ -51,7 +51,9 @@ public Optional modifyAsyncHttpContent(Context.ModifyHttpReque SdkChecksum checksum = new Md5Checksum(); executionAttributes.putAttribute(ASYNC_RECORDING_CHECKSUM, true); executionAttributes.putAttribute(CHECKSUM, checksum); - return Optional.of(new ChecksumCalculatingAsyncRequestBody(context.asyncRequestBody().get(), checksum)); + return Optional.of(new ChecksumCalculatingAsyncRequestBody(context.httpRequest(), + context.asyncRequestBody().get(), + checksum)); } return context.asyncRequestBody(); diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/ContentLengthMismatchTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/ContentLengthMismatchTest.java new file mode 100644 index 000000000000..89870fab2758 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/ContentLengthMismatchTest.java @@ -0,0 +1,159 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.functionaltests; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.any; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.putRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.github.tomakehurst.wiremock.junit.WireMockRule; +import java.net.URI; +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import org.junit.Rule; +import org.junit.Test; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + +public class ContentLengthMismatchTest { + @Rule + public WireMockRule wireMock = new WireMockRule(0); + + private S3AsyncClientBuilder getAsyncClientBuilder() { + return S3AsyncClient.builder() + .region(Region.US_EAST_1) + .endpointOverride(endpoint()) + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret"))); + } + + private URI endpoint() { + return URI.create("http://localhost:" + wireMock.port()); + } + + @Test + public void checksumDoesNotExceedContentLengthHeaderForPuts() { + String bucket = "Example-Bucket"; + String key = "Example-Object"; + String content = "Hello, World!"; + String eTag = "65A8E27D8879283831B664BD8B7F0AD4"; + + stubFor(any(anyUrl()).willReturn(aResponse().withStatus(200).withHeader("ETag", eTag))); + + S3AsyncClient s3Client = getAsyncClientBuilder().build(); + + PutObjectResponse response = + s3Client.putObject(r -> r.bucket(bucket).key(key).contentLength((long) content.length()), + AsyncRequestBody.fromString(content + " Extra stuff!")) + .join(); + + verify(putRequestedFor(anyUrl()).withRequestBody(equalTo(content))); + assertThat(response.eTag()).isEqualTo(eTag); + } + @Test + public void checksumDoesNotExceedAsyncRequestBodyLengthForPuts() { + String bucket = "Example-Bucket"; + String key = "Example-Object"; + String content = "Hello, World!"; + String eTag = "65A8E27D8879283831B664BD8B7F0AD4"; + + stubFor(any(anyUrl()).willReturn(aResponse().withStatus(200).withHeader("ETag", eTag))); + + S3AsyncClient s3Client = getAsyncClientBuilder().build(); + + PutObjectResponse response = + s3Client.putObject(r -> r.bucket(bucket).key(key), + new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.of((long) content.length()); + } + + @Override + public void subscribe(Subscriber subscriber) { + AsyncRequestBody.fromString(content + " Extra stuff!").subscribe(subscriber); + } + }) + .join(); + + verify(putRequestedFor(anyUrl()).withRequestBody(equalTo(content))); + assertThat(response.eTag()).isEqualTo(eTag); + } + + @Test + public void contentShorterThanContentLengthHeaderFails() { + String bucket = "Example-Bucket"; + String key = "Example-Object"; + + S3AsyncClient s3Client = getAsyncClientBuilder().build(); + + AsyncRequestBody requestBody = new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber subscriber) { + AsyncRequestBody.fromString("A").subscribe(subscriber); + } + }; + + assertThatThrownBy(() -> s3Client.putObject(r -> r.bucket(bucket).key(key).contentLength(2L), requestBody) + .get(10, TimeUnit.SECONDS)) + .isInstanceOf(ExecutionException.class) + .hasMessageContaining("content-length"); + } + + @Test + public void contentShorterThanRequestBodyLengthFails() { + String bucket = "Example-Bucket"; + String key = "Example-Object"; + + S3AsyncClient s3Client = getAsyncClientBuilder().build(); + + AsyncRequestBody requestBody = new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.of(2L); + } + + @Override + public void subscribe(Subscriber subscriber) { + AsyncRequestBody.fromString("A").subscribe(subscriber); + } + }; + + assertThatThrownBy(() -> s3Client.putObject(r -> r.bucket(bucket).key(key), requestBody) + .get(10, TimeUnit.SECONDS)) + .isInstanceOf(ExecutionException.class) + .hasMessageContaining("content-length"); + } + +} diff --git a/utils/src/main/java/software/amazon/awssdk/utils/BinaryUtils.java b/utils/src/main/java/software/amazon/awssdk/utils/BinaryUtils.java index 6724b773be05..e7fd8c015e1d 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/BinaryUtils.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/BinaryUtils.java @@ -206,4 +206,28 @@ public static byte[] copyBytesFrom(ByteBuffer bb) { return dst; } + /** + * This behaves identically to {@link #copyBytesFrom(ByteBuffer)}, except + * that the readLimit acts as a limit to the number of bytes that should be + * read from the byte buffer. + */ + public static byte[] copyBytesFrom(ByteBuffer bb, int readLimit) { + if (bb == null) { + return null; + } + + int numBytesToRead = Math.min(readLimit, bb.limit() - bb.position()); + + if (bb.hasArray()) { + return Arrays.copyOfRange( + bb.array(), + bb.arrayOffset() + bb.position(), + bb.arrayOffset() + bb.position() + numBytesToRead); + } + + byte[] dst = new byte[numBytesToRead]; + bb.asReadOnlyBuffer().get(dst); + return dst; + } + }