Skip to content

Fix SDK behavior when request content-length does not match the data length returned by the publisher. #2788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AmazonS3-497d9da.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"category": "Amazon S3",
"contributor": "",
"type": "bugfix",
"description": "Truncate the async request body when the content-length is shorter than the request body, instead of raising a \"Data read has a different checksum\" exception."
}
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AmazonS3-ce33798.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"category": "Amazon S3",
"contributor": "",
"type": "bugfix",
"description": "Raise an exception instead of hanging when a put-object content-length exceeds the data written by the async request body."
}
Original file line number Diff line number Diff line change
Expand Up @@ -536,8 +536,14 @@ public void onError(Throwable t) {
@Override
public void onComplete() {
if (!done) {
done = true;
subscriber.onComplete();
Long expectedContentLength = requestContentLength.orElse(null);
if (expectedContentLength != null && written < expectedContentLength) {
onError(new IllegalStateException("Request content was only " + written + " bytes, but the specified "
+ "content-length was " + expectedContentLength + " bytes."));
} else {
done = true;
subscriber.onComplete();
}
}
}
});
Expand Down
4 changes: 1 addition & 3 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -522,9 +522,7 @@
<excludes>
<exclude>*.internal.*</exclude>
<exclude>software.amazon.awssdk.thirdparty.*</exclude>
<exclude>software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler</exclude>
<exclude>software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler</exclude>

<exclude>software.amazon.awssdk.services.s3.checksums.ChecksumCalculatingAsyncRequestBody</exclude>
<exclude>software.amazon.awssdk.protocols.core.OperationInfo</exclude>
</excludes>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,26 @@

import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.checksums.SdkChecksum;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.utils.BinaryUtils;

@SdkInternalApi
public class ChecksumCalculatingAsyncRequestBody implements AsyncRequestBody {

private final Long contentLength;
private final AsyncRequestBody wrapped;
private final SdkChecksum sdkChecksum;

public ChecksumCalculatingAsyncRequestBody(AsyncRequestBody wrapped, SdkChecksum sdkChecksum) {
public ChecksumCalculatingAsyncRequestBody(SdkHttpRequest request, AsyncRequestBody wrapped, SdkChecksum sdkChecksum) {
this.contentLength = request.firstMatchingHeader("Content-Length")
.map(Long::parseLong)
.orElse(wrapped.contentLength()
.orElse(null));
this.wrapped = wrapped;
this.sdkChecksum = sdkChecksum;
}
Expand All @@ -48,18 +54,21 @@ public String contentType() {
@Override
public void subscribe(Subscriber<? super ByteBuffer> s) {
sdkChecksum.reset();
wrapped.subscribe(new ChecksumCalculatingSubscriber(s, sdkChecksum));
wrapped.subscribe(new ChecksumCalculatingSubscriber(s, sdkChecksum, contentLength));
}

private static final class ChecksumCalculatingSubscriber implements Subscriber<ByteBuffer> {

private final AtomicLong contentRead = new AtomicLong(0);
private final Subscriber<? super ByteBuffer> wrapped;
private final SdkChecksum checksum;
private final Long contentLength;

ChecksumCalculatingSubscriber(Subscriber<? super ByteBuffer> wrapped,
SdkChecksum sdkChecksum) {
SdkChecksum sdkChecksum,
Long contentLength) {
this.wrapped = wrapped;
this.checksum = sdkChecksum;
this.contentLength = contentLength;
}

@Override
Expand All @@ -69,11 +78,34 @@ public void onSubscribe(Subscription s) {

@Override
public void onNext(ByteBuffer byteBuffer) {
byte[] buf = BinaryUtils.copyBytesFrom(byteBuffer);
checksum.update(buf, 0, buf.length);
int amountToReadFromByteBuffer = getAmountToReadFromByteBuffer(byteBuffer);

if (amountToReadFromByteBuffer > 0) {
byte[] buf = BinaryUtils.copyBytesFrom(byteBuffer, amountToReadFromByteBuffer);
checksum.update(buf, 0, amountToReadFromByteBuffer);
}


wrapped.onNext(byteBuffer);
}

private int getAmountToReadFromByteBuffer(ByteBuffer byteBuffer) {
// If content length is null, we should include everything in the checksum because the stream is essentially
// unbounded.
if (contentLength == null) {
return byteBuffer.remaining();
}

long amountReadSoFar = contentRead.getAndAdd(byteBuffer.remaining());
long amountRemaining = Math.max(0, contentLength - amountReadSoFar);

if (amountRemaining > byteBuffer.remaining()) {
return byteBuffer.remaining();
} else {
return Math.toIntExact(amountRemaining);
}
}

@Override
public void onError(Throwable t) {
wrapped.onError(t);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ public Optional<AsyncRequestBody> modifyAsyncHttpContent(Context.ModifyHttpReque
SdkChecksum checksum = new Md5Checksum();
executionAttributes.putAttribute(ASYNC_RECORDING_CHECKSUM, true);
executionAttributes.putAttribute(CHECKSUM, checksum);
return Optional.of(new ChecksumCalculatingAsyncRequestBody(context.asyncRequestBody().get(), checksum));
return Optional.of(new ChecksumCalculatingAsyncRequestBody(context.httpRequest(),
context.asyncRequestBody().get(),
checksum));
}

return context.asyncRequestBody();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.s3.functionaltests;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.any;
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.putRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.github.tomakehurst.wiremock.junit.WireMockRule;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.junit.Rule;
import org.junit.Test;
import org.reactivestreams.Subscriber;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3AsyncClientBuilder;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;

public class ContentLengthMismatchTest {
@Rule
public WireMockRule wireMock = new WireMockRule(0);

private S3AsyncClientBuilder getAsyncClientBuilder() {
return S3AsyncClient.builder()
.region(Region.US_EAST_1)
.endpointOverride(endpoint())
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret")));
}

private URI endpoint() {
return URI.create("http://localhost:" + wireMock.port());
}

@Test
public void checksumDoesNotExceedContentLengthHeaderForPuts() {
String bucket = "Example-Bucket";
String key = "Example-Object";
String content = "Hello, World!";
String eTag = "65A8E27D8879283831B664BD8B7F0AD4";

stubFor(any(anyUrl()).willReturn(aResponse().withStatus(200).withHeader("ETag", eTag)));

S3AsyncClient s3Client = getAsyncClientBuilder().build();

PutObjectResponse response =
s3Client.putObject(r -> r.bucket(bucket).key(key).contentLength((long) content.length()),
AsyncRequestBody.fromString(content + " Extra stuff!"))
.join();

verify(putRequestedFor(anyUrl()).withRequestBody(equalTo(content)));
assertThat(response.eTag()).isEqualTo(eTag);
}
@Test
public void checksumDoesNotExceedAsyncRequestBodyLengthForPuts() {
String bucket = "Example-Bucket";
String key = "Example-Object";
String content = "Hello, World!";
String eTag = "65A8E27D8879283831B664BD8B7F0AD4";

stubFor(any(anyUrl()).willReturn(aResponse().withStatus(200).withHeader("ETag", eTag)));

S3AsyncClient s3Client = getAsyncClientBuilder().build();

PutObjectResponse response =
s3Client.putObject(r -> r.bucket(bucket).key(key),
new AsyncRequestBody() {
@Override
public Optional<Long> contentLength() {
return Optional.of((long) content.length());
}

@Override
public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
AsyncRequestBody.fromString(content + " Extra stuff!").subscribe(subscriber);
}
})
.join();

verify(putRequestedFor(anyUrl()).withRequestBody(equalTo(content)));
assertThat(response.eTag()).isEqualTo(eTag);
}

@Test
public void contentShorterThanContentLengthHeaderFails() {
String bucket = "Example-Bucket";
String key = "Example-Object";

S3AsyncClient s3Client = getAsyncClientBuilder().build();

AsyncRequestBody requestBody = new AsyncRequestBody() {
@Override
public Optional<Long> contentLength() {
return Optional.empty();
}

@Override
public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
AsyncRequestBody.fromString("A").subscribe(subscriber);
}
};

assertThatThrownBy(() -> s3Client.putObject(r -> r.bucket(bucket).key(key).contentLength(2L), requestBody)
.get(10, TimeUnit.SECONDS))
.isInstanceOf(ExecutionException.class)
.hasMessageContaining("content-length");
}

@Test
public void contentShorterThanRequestBodyLengthFails() {
String bucket = "Example-Bucket";
String key = "Example-Object";

S3AsyncClient s3Client = getAsyncClientBuilder().build();

AsyncRequestBody requestBody = new AsyncRequestBody() {
@Override
public Optional<Long> contentLength() {
return Optional.of(2L);
}

@Override
public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
AsyncRequestBody.fromString("A").subscribe(subscriber);
}
};

assertThatThrownBy(() -> s3Client.putObject(r -> r.bucket(bucket).key(key), requestBody)
.get(10, TimeUnit.SECONDS))
.isInstanceOf(ExecutionException.class)
.hasMessageContaining("content-length");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,28 @@ public static byte[] copyBytesFrom(ByteBuffer bb) {
return dst;
}

/**
* This behaves identically to {@link #copyBytesFrom(ByteBuffer)}, except
* that the readLimit acts as a limit to the number of bytes that should be
* read from the byte buffer.
*/
public static byte[] copyBytesFrom(ByteBuffer bb, int readLimit) {
if (bb == null) {
return null;
}

int numBytesToRead = Math.min(readLimit, bb.limit() - bb.position());

if (bb.hasArray()) {
return Arrays.copyOfRange(
bb.array(),
bb.arrayOffset() + bb.position(),
bb.arrayOffset() + bb.position() + numBytesToRead);
}

byte[] dst = new byte[numBytesToRead];
bb.asReadOnlyBuffer().get(dst);
return dst;
}

}