diff --git a/build-tools/src/main/resources/software/amazon/awssdk/spotbugs-suppressions.xml b/build-tools/src/main/resources/software/amazon/awssdk/spotbugs-suppressions.xml index 48cd9e6fabd0..e950e01812b4 100644 --- a/build-tools/src/main/resources/software/amazon/awssdk/spotbugs-suppressions.xml +++ b/build-tools/src/main/resources/software/amazon/awssdk/spotbugs-suppressions.xml @@ -206,6 +206,13 @@ + + + + + + + diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/AwsCrtAsyncHttpClient.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/AwsCrtAsyncHttpClient.java index fd646685f8f3..9bb52e700a03 100644 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/AwsCrtAsyncHttpClient.java +++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/AwsCrtAsyncHttpClient.java @@ -322,6 +322,8 @@ public interface Builder extends SdkAsyncHttpClient.Builder execute(CrtRequestContext executionContext) { HttpRequest crtRequest = CrtRequestAdapter.toCrtRequest(executionContext); HttpStreamResponseHandler crtResponseHandler = - CrtResponseAdapter.toCrtResponseHandler(crtConn, requestFuture, executionContext); + CrtResponseAdapter.toCrtResponseHandler(crtConn, requestFuture, asyncRequest.responseHandler()); // Submit the request on the connection try { diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/request/CrtRequestBodyAdapter.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/request/CrtRequestBodyAdapter.java index ed716219d5b3..1e46a8ca2eb1 100644 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/request/CrtRequestBodyAdapter.java +++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/request/CrtRequestBodyAdapter.java @@ -19,24 +19,27 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.crt.http.HttpRequestBodyStream; import software.amazon.awssdk.http.async.SdkHttpContentPublisher; -import software.amazon.awssdk.utils.Validate; +import software.amazon.awssdk.utils.async.ByteBufferStoringSubscriber; +import software.amazon.awssdk.utils.async.ByteBufferStoringSubscriber.TransferResult; -/** - * Implements the CrtHttpStreamHandler API and converts CRT callbacks into calls to SDK AsyncExecuteRequest methods - */ @SdkInternalApi final class CrtRequestBodyAdapter implements HttpRequestBodyStream { - private final int windowSize; - private final CrtRequestBodySubscriber requestBodySubscriber; + private final SdkHttpContentPublisher requestPublisher; + private final ByteBufferStoringSubscriber requestBodySubscriber; - CrtRequestBodyAdapter(SdkHttpContentPublisher requestPublisher, int windowSize) { - this.windowSize = Validate.isPositive(windowSize, "windowSize is <= 0"); - this.requestBodySubscriber = new CrtRequestBodySubscriber(windowSize); + CrtRequestBodyAdapter(SdkHttpContentPublisher requestPublisher, int readLimit) { + this.requestPublisher = requestPublisher; + this.requestBodySubscriber = new ByteBufferStoringSubscriber(readLimit); requestPublisher.subscribe(requestBodySubscriber); } @Override public boolean sendRequestBody(ByteBuffer bodyBytesOut) { - return requestBodySubscriber.transferRequestBody(bodyBytesOut); + return requestBodySubscriber.transferTo(bodyBytesOut) == TransferResult.END_OF_STREAM; + } + + @Override + public long getLength() { + return requestPublisher.contentLength().orElse(0L); } } diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/request/CrtRequestBodySubscriber.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/request/CrtRequestBodySubscriber.java deleted file mode 100644 index 88d7f985feef..000000000000 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/request/CrtRequestBodySubscriber.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.http.crt.internal.request; - -import static software.amazon.awssdk.crt.utils.ByteBufferUtils.transferData; - -import java.nio.ByteBuffer; -import java.util.Queue; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.utils.Logger; -import software.amazon.awssdk.utils.Validate; - -/** - * Implements the Subscriber API to be be callable from AwsCrtAsyncHttpStreamAdapter.sendRequestBody() - */ -@SdkInternalApi -public final class CrtRequestBodySubscriber implements Subscriber { - private static final Logger log = Logger.loggerFor(CrtRequestBodySubscriber.class); - - private final int windowSize; - private final Queue queuedBuffers = new ConcurrentLinkedQueue<>(); - private final AtomicLong queuedByteCount = new AtomicLong(0); - private final AtomicBoolean isComplete = new AtomicBoolean(false); - private final AtomicReference error = new AtomicReference<>(null); - - private AtomicReference subscriptionRef = new AtomicReference<>(null); - - /** - * - * @param windowSize The number bytes to be queued before we stop proactively queuing data - */ - public CrtRequestBodySubscriber(int windowSize) { - Validate.isPositive(windowSize, "windowSize is <= 0"); - this.windowSize = windowSize; - } - - protected void requestDataIfNecessary() { - Subscription subscription = subscriptionRef.get(); - if (subscription == null) { - log.error(() -> "Subscription is null"); - return; - } - if (queuedByteCount.get() < windowSize) { - subscription.request(1); - } - } - - @Override - public void onSubscribe(Subscription s) { - Validate.paramNotNull(s, "s"); - - boolean wasFirstSubscription = subscriptionRef.compareAndSet(null, s); - - if (!wasFirstSubscription) { - log.error(() -> "Only one Subscription supported!"); - s.cancel(); - return; - } - - requestDataIfNecessary(); - } - - @Override - public void onNext(ByteBuffer byteBuffer) { - Validate.paramNotNull(byteBuffer, "byteBuffer"); - queuedBuffers.add(byteBuffer); - queuedByteCount.addAndGet(byteBuffer.remaining()); - requestDataIfNecessary(); - } - - @Override - public void onError(Throwable t) { - log.error(() -> "onError() received an error: " + t.getMessage()); - error.compareAndSet(null, t); - } - - @Override - public void onComplete() { - log.debug(() -> "AwsCrtRequestBodySubscriber Completed"); - isComplete.set(true); - } - - /** - * Transfers any queued data from the Request Body subscriptionRef to the output buffer - * @param out The output ByteBuffer - * @return true if Request Body is completely transferred, false otherwise - */ - public synchronized boolean transferRequestBody(ByteBuffer out) { - if (error.get() != null) { - throw new RuntimeException(error.get()); - } - - while (out.remaining() > 0 && !queuedBuffers.isEmpty()) { - ByteBuffer nextBuffer = queuedBuffers.peek(); - int amtTransferred = transferData(nextBuffer, out); - queuedByteCount.addAndGet(-amtTransferred); - - if (nextBuffer.remaining() == 0) { - queuedBuffers.remove(); - } - } - - boolean endOfStream = isComplete.get() && queuedBuffers.isEmpty(); - - if (!endOfStream) { - requestDataIfNecessary(); - } else { - log.debug(() -> "End Of RequestBody reached"); - } - - return endOfStream; - } -} diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java index c69870c929a1..370f4aebec47 100644 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java +++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java @@ -15,6 +15,7 @@ package software.amazon.awssdk.http.crt.internal.response; +import java.nio.ByteBuffer; import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.crt.CRT; @@ -26,10 +27,10 @@ import software.amazon.awssdk.crt.http.HttpStreamResponseHandler; import software.amazon.awssdk.http.HttpStatusFamily; import software.amazon.awssdk.http.SdkHttpResponse; -import software.amazon.awssdk.http.async.AsyncExecuteRequest; -import software.amazon.awssdk.http.crt.internal.CrtRequestContext; +import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler; import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Validate; +import software.amazon.awssdk.utils.async.SimplePublisher; /** * Implements the CrtHttpStreamHandler API and converts CRT callbacks into calls to SDK AsyncExecuteRequest methods @@ -39,97 +40,110 @@ public final class CrtResponseAdapter implements HttpStreamResponseHandler { private static final Logger log = Logger.loggerFor(CrtResponseAdapter.class); private final HttpClientConnection connection; - private final CompletableFuture responseComplete; - private final AsyncExecuteRequest sdkRequest; - private final SdkHttpResponse.Builder respBuilder = SdkHttpResponse.builder(); - private final int windowSize; - private CrtResponseBodyPublisher respBodyPublisher; + private final CompletableFuture completionFuture; + private final SdkAsyncHttpResponseHandler responseHandler; + private final SimplePublisher responsePublisher = new SimplePublisher<>(); - private CrtResponseAdapter(HttpClientConnection connection, - CompletableFuture responseComplete, - AsyncExecuteRequest sdkRequest, - int windowSize) { - this.connection = Validate.notNull(connection, "HttpConnection is null"); - this.responseComplete = Validate.notNull(responseComplete, "reqComplete Future is null"); - this.sdkRequest = Validate.notNull(sdkRequest, "AsyncExecuteRequest Future is null"); - this.windowSize = Validate.isPositive(windowSize, "windowSize is <= 0"); - } + private final SdkHttpResponse.Builder responseBuilder = SdkHttpResponse.builder(); - public static HttpStreamResponseHandler toCrtResponseHandler(HttpClientConnection connection, - CompletableFuture responseComplete, - CrtRequestContext request) { - return new CrtResponseAdapter(connection, responseComplete, request.sdkRequest(), request.readBufferSize()); + private CrtResponseAdapter(HttpClientConnection connection, + CompletableFuture completionFuture, + SdkAsyncHttpResponseHandler responseHandler) { + this.connection = Validate.paramNotNull(connection, "connection"); + this.completionFuture = Validate.paramNotNull(completionFuture, "completionFuture"); + this.responseHandler = Validate.paramNotNull(responseHandler, "responseHandler"); } - private void initRespBodyPublisherIfNeeded(HttpStream stream) { - if (respBodyPublisher == null) { - respBodyPublisher = new CrtResponseBodyPublisher(connection, stream, responseComplete, windowSize); - } + public static HttpStreamResponseHandler toCrtResponseHandler(HttpClientConnection crtConn, + CompletableFuture requestFuture, + SdkAsyncHttpResponseHandler responseHandler) { + return new CrtResponseAdapter(crtConn, requestFuture, responseHandler); } @Override - public void onResponseHeaders(HttpStream stream, int responseStatusCode, int blockType, HttpHeader[] nextHeaders) { - initRespBodyPublisherIfNeeded(stream); - - for (HttpHeader h : nextHeaders) { - respBuilder.appendHeader(h.getName(), h.getValue()); + public void onResponseHeaders(HttpStream stream, int responseStatusCode, int headerType, HttpHeader[] nextHeaders) { + if (headerType == HttpHeaderBlock.MAIN.getValue()) { + for (HttpHeader h : nextHeaders) { + responseBuilder.appendHeader(h.getName(), h.getValue()); + } } } @Override public void onResponseHeadersDone(HttpStream stream, int headerType) { if (headerType == HttpHeaderBlock.MAIN.getValue()) { - initRespBodyPublisherIfNeeded(stream); - - respBuilder.statusCode(stream.getResponseStatusCode()); - sdkRequest.responseHandler().onHeaders(respBuilder.build()); - sdkRequest.responseHandler().onStream(respBodyPublisher); + responseBuilder.statusCode(stream.getResponseStatusCode()); + responseHandler.onHeaders(responseBuilder.build()); + responseHandler.onStream(responsePublisher); } } @Override public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) { - initRespBodyPublisherIfNeeded(stream); + CompletableFuture writeFuture = responsePublisher.send(ByteBuffer.wrap(bodyBytesIn)); + + if (writeFuture.isDone() && !writeFuture.isCompletedExceptionally()) { + // Optimization: If write succeeded immediately, return non-zero to avoid the extra call back into the CRT. + return bodyBytesIn.length; + } - respBodyPublisher.queueBuffer(bodyBytesIn); - respBodyPublisher.publishToSubscribers(); + writeFuture.whenComplete((result, failure) -> { + if (failure != null) { + failResponseHandlerAndFuture(stream, failure); + return; + } + + stream.incrementWindow(bodyBytesIn.length); + }); - /* - * Intentionally zero. We manually manage the crt stream's window within the body publisher by updating with - * the exact amount we were able to push to the subcriber. - * - * See the call to stream.incrementWindow() in AwsCrtResponseBodyPublisher. - */ return 0; } @Override public void onResponseComplete(HttpStream stream, int errorCode) { - initRespBodyPublisherIfNeeded(stream); - - if (HttpStatusFamily.of(respBuilder.statusCode()) == HttpStatusFamily.SERVER_ERROR) { - connection.shutdown(); - } - if (errorCode == CRT.AWS_CRT_SUCCESS) { - log.debug(() -> "Response Completed Successfully"); - respBodyPublisher.setQueueComplete(); - respBodyPublisher.publishToSubscribers(); + onSuccessfulResponseComplete(stream); } else { - HttpException error = new HttpException(errorCode); - log.error(() -> "Response Encountered an Error.", error); - - // Invoke Error Callback on SdkAsyncHttpResponseHandler - try { - sdkRequest.responseHandler().onError(error); - } catch (Exception e) { - log.error(() -> String.format("SdkAsyncHttpResponseHandler %s threw an exception in onError: %s", - sdkRequest.responseHandler(), e)); + onFailedResponseComplete(stream, new HttpException(errorCode)); + } + } + + private void onSuccessfulResponseComplete(HttpStream stream) { + responsePublisher.complete().whenComplete((result, failure) -> { + if (failure != null) { + failResponseHandlerAndFuture(stream, failure); + return; + } + + if (HttpStatusFamily.of(responseBuilder.statusCode()) == HttpStatusFamily.SERVER_ERROR) { + connection.shutdown(); } - // Invoke Error Callback on any Subscriber's of the Response Body - respBodyPublisher.setError(error); - respBodyPublisher.publishToSubscribers(); + connection.close(); + stream.close(); + completionFuture.complete(null); + }); + } + + private void onFailedResponseComplete(HttpStream stream, HttpException error) { + log.error(() -> "HTTP response encountered an error.", error); + responsePublisher.error(error); + failResponseHandlerAndFuture(stream, error); + } + + private void failResponseHandlerAndFuture(HttpStream stream, Throwable error) { + callResponseHandlerOnError(error); + completionFuture.completeExceptionally(error); + connection.shutdown(); + connection.close(); + stream.close(); + } + + private void callResponseHandlerOnError(Throwable error) { + try { + responseHandler.onError(error); + } catch (RuntimeException e) { + log.warn(() -> "Exception raised from SdkAsyncHttpResponseHandler#onError.", e); } } } diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseBodyPublisher.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseBodyPublisher.java deleted file mode 100644 index 01085e18a4e5..000000000000 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseBodyPublisher.java +++ /dev/null @@ -1,333 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.http.crt.internal.response; - -import java.nio.ByteBuffer; -import java.util.Optional; -import java.util.Queue; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.LongUnaryOperator; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.crt.http.HttpClientConnection; -import software.amazon.awssdk.crt.http.HttpStream; -import software.amazon.awssdk.utils.Logger; -import software.amazon.awssdk.utils.Validate; - -/** - * Adapts an AWS Common Runtime Response Body stream from CrtHttpStreamHandler to a Publisher - */ -@SdkInternalApi -public final class CrtResponseBodyPublisher implements Publisher { - private static final Logger log = Logger.loggerFor(CrtResponseBodyPublisher.class); - private static final LongUnaryOperator DECREMENT_IF_GREATER_THAN_ZERO = x -> ((x > 0) ? (x - 1) : (x)); - - private final HttpClientConnection connection; - private final HttpStream stream; - private final CompletableFuture responseComplete; - private final AtomicLong outstandingRequests = new AtomicLong(0); - private final int windowSize; - private final AtomicBoolean isCancelled = new AtomicBoolean(false); - private final AtomicBoolean areNativeResourcesReleased = new AtomicBoolean(false); - private final AtomicBoolean isSubscriptionComplete = new AtomicBoolean(false); - private final AtomicBoolean queueComplete = new AtomicBoolean(false); - private final AtomicInteger mutualRecursionDepth = new AtomicInteger(0); - private final AtomicInteger queuedBytes = new AtomicInteger(0); - private final AtomicReference> subscriberRef = new AtomicReference<>(null); - private final Queue queuedBuffers = new ConcurrentLinkedQueue<>(); - private final AtomicReference error = new AtomicReference<>(null); - - /** - * Adapts a streaming AWS CRT Http Response Body to a Publisher - * @param stream The AWS CRT Http Stream for this Response - * @param windowSize The max allowed bytes to be queued. The sum of the sizes of all queued ByteBuffers should - * never exceed this value. - */ - public CrtResponseBodyPublisher(HttpClientConnection connection, HttpStream stream, - CompletableFuture responseComplete, int windowSize) { - this.connection = Validate.notNull(connection, "HttpConnection must not be null"); - this.stream = Validate.notNull(stream, "Stream must not be null"); - this.responseComplete = Validate.notNull(responseComplete, "ResponseComplete future must not be null"); - this.windowSize = Validate.isPositive(windowSize, "windowSize must be > 0"); - } - - /** - * Method for the users consuming the Http Response Body to register a subscriber. - * @param subscriber The Subscriber to register. - */ - @Override - public void subscribe(Subscriber subscriber) { - Validate.notNull(subscriber, "Subscriber must not be null"); - - boolean wasFirstSubscriber = subscriberRef.compareAndSet(null, subscriber); - - if (!wasFirstSubscriber) { - log.error(() -> "Only one subscriber allowed"); - - // onSubscribe must be called first before onError gets called, so give it a do-nothing Subscription - subscriber.onSubscribe(new Subscription() { - @Override - public void request(long n) { - // This is a dummy implementation to allow the onError call - } - - @Override - public void cancel() { - // This is a dummy implementation to allow the onError call - } - }); - subscriber.onError(new IllegalStateException("Only one subscriber allowed")); - } else { - subscriber.onSubscribe(new AwsCrtResponseBodySubscription(this)); - } - } - - /** - * Adds a Buffer to the Queue to be published to any Subscribers - * @param buffer The Buffer to be queued. - */ - public void queueBuffer(byte[] buffer) { - Validate.notNull(buffer, "ByteBuffer must not be null"); - - if (isCancelled.get()) { - // Immediately open HttpStream's IO window so it doesn't see any IO Back-pressure. - // AFAIK there's no way to abort an in-progress HttpStream, only free it's memory by calling close() - stream.incrementWindow(buffer.length); - return; - } - - queuedBuffers.add(buffer); - int totalBytesQueued = queuedBytes.addAndGet(buffer.length); - - if (totalBytesQueued > windowSize) { - throw new IllegalStateException("Queued more than Window Size: queued=" + totalBytesQueued - + ", window=" + windowSize); - } - } - - /** - * Function called by Response Body Subscribers to request more Response Body buffers. - * @param n The number of buffers requested. - */ - protected void request(long n) { - Validate.inclusiveBetween(1, Long.MAX_VALUE, n, "request"); - - // Check for overflow of outstanding Requests, and clamp to LONG_MAX. - long outstandingReqs; - if (n > (Long.MAX_VALUE - outstandingRequests.get())) { - outstandingRequests.set(Long.MAX_VALUE); - outstandingReqs = Long.MAX_VALUE; - } else { - outstandingReqs = outstandingRequests.addAndGet(n); - } - - /* - * Since we buffer, in the case where the subscriber came in after the publication has already begun, - * go ahead and flush what we have. - */ - publishToSubscribers(); - - log.trace(() -> "Subscriber Requested more Buffers. Outstanding Requests: " + outstandingReqs); - } - - public void setError(Throwable t) { - log.error(() -> "Error processing Response Body", t); - error.compareAndSet(null, t); - } - - protected void setCancelled() { - isCancelled.set(true); - /** - * subscriberRef must set to null due to ReactiveStream Spec stating references to Subscribers must be deleted - * when onCancel() is called. - */ - subscriberRef.set(null); - } - - private synchronized void releaseNativeResources() { - boolean alreadyReleased = areNativeResourcesReleased.getAndSet(true); - - if (!alreadyReleased) { - stream.close(); - connection.close(); - } - } - - /** - * Called when the final Buffer has been queued and no more data is expected. - */ - public void setQueueComplete() { - log.trace(() -> "Response Body Publisher queue marked as completed."); - queueComplete.set(true); - // We're done with the Native Resources, release them so they can be used by another request. - releaseNativeResources(); - } - - /** - * Completes the Subscription by calling either the .onError() or .onComplete() callbacks exactly once. - */ - protected void completeSubscriptionExactlyOnce() { - boolean alreadyComplete = isSubscriptionComplete.getAndSet(true); - - if (alreadyComplete) { - return; - } - - // Subscriber may have cancelled their subscription, in which case this may be null. - Optional> subscriber = Optional.ofNullable(subscriberRef.getAndSet(null)); - - Throwable throwable = error.get(); - - // We're done with the Native Resources, release them so they can be used by another request. - releaseNativeResources(); - - // Complete the Futures - if (throwable != null) { - log.error(() -> "Error before ResponseBodyPublisher could complete: " + throwable.getMessage()); - try { - subscriber.ifPresent(s -> s.onError(throwable)); - } catch (Exception e) { - log.warn(() -> "Failed to exceptionally complete subscriber future with: " + throwable.getMessage()); - } - responseComplete.completeExceptionally(throwable); - } else { - log.debug(() -> "ResponseBodyPublisher Completed Successfully"); - try { - subscriber.ifPresent(Subscriber::onComplete); - } catch (Exception e) { - log.warn(() -> "Failed to successfully complete subscriber future"); - } - responseComplete.complete(null); - } - } - - /** - * Publishes any queued data to any Subscribers if there is data queued and there is an outstanding Subscriber - * request for more data. Will also call onError() or onComplete() callbacks if needed. - * - * This method MUST be synchronized since it can be called simultaneously from both the Native EventLoop Thread and - * the User Thread. If this method wasn't synchronized, it'd be possible for each thread to dequeue a buffer by - * calling queuedBuffers.poll(), but then have the 2nd thread call subscriber.onNext(buffer) first, resulting in the - * subscriber seeing out-of-order data. To avoid this race condition, this method must be synchronized. - */ - protected void publishToSubscribers() { - boolean shouldComplete = true; - synchronized (this) { - if (error.get() == null) { - if (isSubscriptionComplete.get() || isCancelled.get()) { - log.debug(() -> "Subscription already completed or cancelled, can't publish updates to Subscribers."); - return; - } - - if (mutualRecursionDepth.get() > 0) { - /** - * If our depth is > 0, then we already made a call to publishToSubscribers() further up the stack that - * will continue publishing to subscribers, and this call should return without completing work to avoid - * infinite recursive loop between: "subscription.request() -> subscriber.onNext() -> subscription.request()" - */ - return; - } - - int totalAmountTransferred = 0; - - while (outstandingRequests.get() > 0 && !queuedBuffers.isEmpty()) { - byte[] buffer = queuedBuffers.poll(); - outstandingRequests.getAndUpdate(DECREMENT_IF_GREATER_THAN_ZERO); - int amount = buffer.length; - publishWithoutMutualRecursion(subscriberRef.get(), ByteBuffer.wrap(buffer)); - totalAmountTransferred += amount; - } - - if (totalAmountTransferred > 0) { - queuedBytes.addAndGet(-totalAmountTransferred); - - // We may have released the Native HttpConnection and HttpStream if they completed before the Subscriber - // has finished reading the data. - if (!areNativeResourcesReleased.get()) { - // Open HttpStream's IO window so HttpStream can keep track of IO back-pressure - // This is why it is correct to return 0 from AwsCrtAsyncHttpStreamAdapter::onResponseBody - stream.incrementWindow(totalAmountTransferred); - } - } - - shouldComplete = queueComplete.get() && queuedBuffers.isEmpty(); - } else { - shouldComplete = true; - } - } - - // Check if Complete, consider no subscriber as a completion. - if (shouldComplete) { - completeSubscriptionExactlyOnce(); - } - } - - /** - * This method is used to avoid a StackOverflow due to the potential infinite loop between - * "subscription.request() -> subscriber.onNext() -> subscription.request()" calls. We only call subscriber.onNext() - * if the recursion depth is zero, otherwise we return up to the stack frame with depth zero and continue publishing - * from there. - * @param subscriber The Subscriber to publish to. - * @param buffer The buffer to publish to the subscriber. - */ - private synchronized void publishWithoutMutualRecursion(Subscriber subscriber, ByteBuffer buffer) { - try { - /** - * Need to keep track of recursion depth between .onNext() -> .request() calls - */ - int depth = mutualRecursionDepth.getAndIncrement(); - if (depth == 0) { - subscriber.onNext(buffer); - } - } finally { - mutualRecursionDepth.decrementAndGet(); - } - } - - static class AwsCrtResponseBodySubscription implements Subscription { - private final CrtResponseBodyPublisher publisher; - - AwsCrtResponseBodySubscription(CrtResponseBodyPublisher publisher) { - this.publisher = publisher; - } - - @Override - public void request(long n) { - if (n <= 0) { - // Reactive Stream Spec requires us to call onError() callback instead of throwing Exception here. - publisher.setError(new IllegalArgumentException("Request is for <= 0 elements: " + n)); - publisher.publishToSubscribers(); - return; - } - - publisher.request(n); - publisher.publishToSubscribers(); - } - - @Override - public void cancel() { - publisher.setCancelled(); - } - } - -} diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/CrtRequestBodySubscriberReactiveStreamCompatTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/CrtRequestBodySubscriberReactiveStreamCompatTest.java deleted file mode 100644 index 57db737698e8..000000000000 --- a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/CrtRequestBodySubscriberReactiveStreamCompatTest.java +++ /dev/null @@ -1,66 +0,0 @@ -package software.amazon.awssdk.http.crt; - -import java.nio.ByteBuffer; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import org.reactivestreams.tck.SubscriberWhiteboxVerification; -import org.reactivestreams.tck.TestEnvironment; -import software.amazon.awssdk.http.crt.internal.request.CrtRequestBodySubscriber; - -public class CrtRequestBodySubscriberReactiveStreamCompatTest extends SubscriberWhiteboxVerification { - private static final int DEFAULT_STREAM_WINDOW_SIZE = 16 * 1024 * 1024; // 16 MB Total Buffer size - - public CrtRequestBodySubscriberReactiveStreamCompatTest() { - super(new TestEnvironment()); - } - - @Override - public Subscriber createSubscriber(WhiteboxSubscriberProbe probe) { - CrtRequestBodySubscriber actualSubscriber = new CrtRequestBodySubscriber(DEFAULT_STREAM_WINDOW_SIZE); - - // Pass Through calls to AwsCrtRequestBodySubscriber, but also register calls to the whitebox probe - Subscriber passthroughSubscriber = new Subscriber() { - @Override - public void onSubscribe(Subscription s) { - actualSubscriber.onSubscribe(s); - probe.registerOnSubscribe(new SubscriberPuppet() { - - @Override - public void triggerRequest(long elements) { - s.request(elements); - } - - @Override - public void signalCancel() { - s.cancel(); - } - }); - } - - @Override - public void onNext(ByteBuffer byteBuffer) { - actualSubscriber.onNext(byteBuffer); - probe.registerOnNext(byteBuffer); - } - - @Override - public void onError(Throwable t) { - actualSubscriber.onError(t); - probe.registerOnError(t); - } - - @Override - public void onComplete() { - actualSubscriber.onComplete(); - probe.registerOnComplete(); - } - }; - - return passthroughSubscriber; - } - - @Override - public ByteBuffer createElement(int element) { - return ByteBuffer.wrap(Integer.toString(element).getBytes()); - } -} diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/CrtResponseBodyPublisherReactiveStreamCompatTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/CrtResponseBodyPublisherReactiveStreamCompatTest.java deleted file mode 100644 index ad536bab1ccf..000000000000 --- a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/CrtResponseBodyPublisherReactiveStreamCompatTest.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.http.crt; - -import static org.mockito.Mockito.mock; - -import java.nio.ByteBuffer; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import org.reactivestreams.Publisher; -import org.reactivestreams.tck.PublisherVerification; -import org.reactivestreams.tck.TestEnvironment; -import software.amazon.awssdk.crt.http.HttpClientConnection; -import software.amazon.awssdk.crt.http.HttpStream; -import software.amazon.awssdk.http.crt.internal.response.CrtResponseBodyPublisher; -import software.amazon.awssdk.utils.Logger; - -public class CrtResponseBodyPublisherReactiveStreamCompatTest extends PublisherVerification { - private static final Logger log = Logger.loggerFor(CrtResponseBodyPublisherReactiveStreamCompatTest.class); - - public CrtResponseBodyPublisherReactiveStreamCompatTest() { - super(new TestEnvironment()); - } - - @Override - public Publisher createPublisher(long elements) { - HttpClientConnection connection = mock(HttpClientConnection.class); - HttpStream stream = mock(HttpStream.class); - CrtResponseBodyPublisher bodyPublisher = new CrtResponseBodyPublisher(connection, stream, new CompletableFuture<>(), Integer.MAX_VALUE); - - for (long i = 0; i < elements; i++) { - bodyPublisher.queueBuffer(UUID.randomUUID().toString().getBytes()); - } - - bodyPublisher.setQueueComplete(); - return bodyPublisher; - } - - // Some tests try to create INT_MAX elements, which causes OutOfMemory Exceptions. Lower the max allowed number of - // queued buffers to 1024. - @Override - public long maxElementsFromPublisher() { - return 1024; - } - - @Override - public Publisher createFailedPublisher() { - return null; - } -} diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/ByteBufferStoringSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/ByteBufferStoringSubscriber.java new file mode 100644 index 000000000000..dcf06cce4fb5 --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/ByteBufferStoringSubscriber.java @@ -0,0 +1,198 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static software.amazon.awssdk.utils.async.StoringSubscriber.EventType.ON_NEXT; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.utils.Validate; +import software.amazon.awssdk.utils.async.StoringSubscriber.Event; + +/** + * An implementation of {@link Subscriber} that stores {@link ByteBuffer} events it receives for retrieval. + * + *

Stored bytes can be read via {@link #transferTo(ByteBuffer)}. + */ +@SdkProtectedApi +public class ByteBufferStoringSubscriber implements Subscriber { + /** + * The minimum amount of data (in bytes) that should be buffered in memory at a time. The subscriber will request new byte + * buffers from upstream until the bytes received equals or exceeds this value. + */ + private final long minimumBytesBuffered; + + /** + * The amount of data (in bytes) currently stored in this subscriber. The subscriber will request more data when this value + * is below the {@link #minimumBytesBuffered}. + */ + private final AtomicLong bytesBuffered = new AtomicLong(0L); + + /** + * A delegate subscriber that we use to store the buffered bytes in the order they are received. + */ + private final StoringSubscriber storingSubscriber; + + /** + * The active subscription. Set when {@link #onSubscribe(Subscription)} is invoked. + */ + private Subscription subscription; + + /** + * Create a subscriber that stores at least {@code minimumBytesBuffered} in memory for retrieval. + */ + public ByteBufferStoringSubscriber(long minimumBytesBuffered) { + this.minimumBytesBuffered = Validate.isPositive(minimumBytesBuffered, "Data buffer minimum must be positive"); + this.storingSubscriber = new StoringSubscriber<>(Integer.MAX_VALUE); + } + + /** + * Transfer the data stored by this subscriber into the provided byte buffer. + * + *

If the data stored by this subscriber exceeds {@code out}'s {@code limit}, then {@code out} will be filled. If the data + * stored by this subscriber is less than {@code out}'s {@code limit}, then all stored data will be written to {@code out}. + * + *

If {@link #onError(Throwable)} was called on this subscriber, as much data as is available will be transferred into + * {@code out} before the provided exception is thrown (as a {@link RuntimeException}). + * + *

If {@link #onComplete()} was called on this subscriber, as much data as is available will be transferred into + * {@code out}, and this will return {@link TransferResult#END_OF_STREAM}. + * + *

Note: This method MUST NOT be called concurrently. Other methods on this class may be called concurrently with this + * one. + */ + public TransferResult transferTo(ByteBuffer out) { + int transferred = 0; + + Optional> next = storingSubscriber.peek(); + + while (out.hasRemaining()) { + if (!next.isPresent() || next.get().type() != ON_NEXT) { + break; + } + + transferred += transfer(next.get().value(), out); + next = storingSubscriber.peek(); + } + + addBufferedDataAmount(-transferred); + + if (!next.isPresent()) { + return TransferResult.SUCCESS; + } + + switch (next.get().type()) { + case ON_COMPLETE: + return TransferResult.END_OF_STREAM; + case ON_ERROR: + throw next.get().runtimeError(); + case ON_NEXT: + return TransferResult.SUCCESS; + default: + throw new IllegalStateException("Unknown stored type: " + next.get().type()); + } + } + + private int transfer(ByteBuffer in, ByteBuffer out) { + int amountToTransfer = Math.min(in.remaining(), out.remaining()); + + ByteBuffer truncatedIn = in.duplicate(); + truncatedIn.limit(truncatedIn.position() + amountToTransfer); + + out.put(truncatedIn); + in.position(truncatedIn.position()); + + if (!in.hasRemaining()) { + storingSubscriber.poll(); + } + + return amountToTransfer; + } + + @Override + public void onSubscribe(Subscription s) { + storingSubscriber.onSubscribe(new DemandIgnoringSubscription(s)); + subscription = s; + subscription.request(1); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + storingSubscriber.onNext(byteBuffer.duplicate()); + addBufferedDataAmount(byteBuffer.remaining()); + } + + @Override + public void onError(Throwable t) { + storingSubscriber.onError(t); + } + + @Override + public void onComplete() { + storingSubscriber.onComplete(); + } + + private void addBufferedDataAmount(long amountToAdd) { + long currentDataBuffered = bytesBuffered.addAndGet(amountToAdd); + maybeRequestMore(currentDataBuffered); + } + + private void maybeRequestMore(long currentDataBuffered) { + if (currentDataBuffered < minimumBytesBuffered) { + subscription.request(1); + } + } + + /** + * The result of {@link #transferTo(ByteBuffer)}. + */ + public enum TransferResult { + /** + * Data was successfully transferred to {@code out}, and the end of stream has been reached. No future calls to + * {@link #transferTo(ByteBuffer)} will yield additional data. + */ + END_OF_STREAM, + + /** + * Data was successfully transferred to {@code out}, but the end of stream has not been reached. Future calls to + * {@link #transferTo(ByteBuffer)} may yield additional data. + */ + SUCCESS + } + + private static final class DemandIgnoringSubscription implements Subscription { + private final Subscription delegate; + + private DemandIgnoringSubscription(Subscription delegate) { + this.delegate = delegate; + } + + @Override + public void request(long n) { + // Ignore demand requests from downstream, they want too much. + // We feed them the amount that we want. + } + + @Override + public void cancel() { + delegate.cancel(); + } + } +} diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java b/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java new file mode 100644 index 000000000000..b83ad5a1149f --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java @@ -0,0 +1,499 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static java.util.Arrays.asList; +import static software.amazon.awssdk.utils.async.SimplePublisher.QueueEntry.Type.CANCEL; +import static software.amazon.awssdk.utils.async.SimplePublisher.QueueEntry.Type.ON_COMPLETE; +import static software.amazon.awssdk.utils.async.SimplePublisher.QueueEntry.Type.ON_ERROR; +import static software.amazon.awssdk.utils.async.SimplePublisher.QueueEntry.Type.ON_NEXT; + +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CopyOnWriteArraySet; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; + +/** + * A {@link Publisher} to which callers can {@link #send(Object)} messages, simplifying the process of implementing a publisher. + * + *

Operations + * + *

The {@code SimplePublisher} supports three simplified operations: + *

    + *
  1. {@link #send(Object)} for sending messages
  2. + *
  3. {@link #complete()} for indicating the successful end of messages
  4. + *
  5. {@link #error(Throwable)} for indicating the unsuccessful end of messages
  6. + *
+ * + * Each of these operations returns a {@link CompletableFuture} for indicating when the message has been successfully sent. + * + *

Callers are expected to invoke a series of {@link #send(Object)}s followed by a single {@link #complete()} or + * {@link #error(Throwable)}. See the documentation on each operation for more details. + * + *

This publisher will store an unbounded number of messages. It is recommended that callers limit the number of in-flight + * {@link #send(Object)} operations in order to bound the amount of memory used by this publisher. + */ +@SdkProtectedApi +public final class SimplePublisher implements Publisher { + private static final Logger log = Logger.loggerFor(SimplePublisher.class); + + /** + * Track the amount of outstanding demand requested by the active subscriber. + */ + private final AtomicLong outstandingDemand = new AtomicLong(); + + /** + * The queue of events to be processed, in the order they should be processed. + * + *

All logic within this publisher is represented using events in this queue. This ensures proper ordering of events + * processing and simplified reasoning about thread safety. + */ + private final Queue> eventQueue = new ConcurrentLinkedQueue<>(); + + /** + * When processing the {@link #eventQueue}, these are the entries that should be skipped (and failed). This is used to + * safely drain the queue when there are urgent events needing processing, like a {@link Subscription#cancel()}. + */ + private final Set entryTypesToFail = new CopyOnWriteArraySet<>(); + + /** + * Whether the {@link #eventQueue} is currently being processed. Only one thread may read events from the queue at a time. + */ + private final AtomicBoolean processingQueue = new AtomicBoolean(false); + + /** + * An exception that should be raised to any failed {@link #send(Object)}, {@link #complete()} or {@link #error(Throwable)} + * operations. This is used to stop accepting messages after the downstream subscription is cancelled or after the + * caller sends a {@code complete()} or {@code #error()}. + * + *

This is a supplier to avoid the cost of creating an exception in the successful code path. + */ + private final AtomicReference> rejectException = new AtomicReference<>(); + + /** + * The subscriber provided via {@link #subscribe(Subscriber)}. This publisher only supports a single subscriber. + */ + private Subscriber subscriber; + + /** + * Send a message using this publisher. + * + *

Messages sent using this publisher will eventually be sent to a downstream subscriber, in the order they were + * written. When the message is sent to the subscriber, the returned future will be completed successfully. + * + *

This method may be invoked concurrently when the order of messages is not important. + * + *

In the time between when this method is invoked and the returned future is not completed, this publisher stores the + * request message in memory. Callers are recommended to limit the number of sends in progress at a time to bound the + * amount of memory used by this publisher. + * + *

The returned future will be completed exceptionally if the downstream subscriber cancels the subscription, or + * if the {@code send} call was performed after a {@link #complete()} or {@link #error(Throwable)} call. + * + * @param value The message to send. Must not be null. + * @return A future that is completed when the message is sent to the subscriber. + */ + public CompletableFuture send(T value) { + log.trace(() -> "Received send() with " + value); + + OnNextQueueEntry entry = new OnNextQueueEntry<>(value); + try { + Validate.notNull(value, "Null cannot be written."); + validateRejectState(); + eventQueue.add(entry); + processEventQueue(); + } catch (RuntimeException t) { + entry.resultFuture.completeExceptionally(t); + } + return entry.resultFuture; + } + + /** + * Indicate that no more {@link #send(Object)} calls will be made, and that stream of messages is completed successfully. + * + *

This can be called before any in-flight {@code send} calls are complete. Such messages will be processed before the + * stream is treated as complete. The returned future will be completed successfully when the {@code complete} is sent to + * the downstream subscriber. + * + *

After this method is invoked, any future {@link #send(Object)}, {@code complete()} or {@link #error(Throwable)} + * calls will be completed exceptionally and not be processed. + * + *

The returned future will be completed exceptionally if the downstream subscriber cancels the subscription, or + * if the {@code complete} call was performed after a {@code complete} or {@link #error(Throwable)} call. + * + * @return A future that is completed when the complete has been sent to the downstream subscriber. + */ + public CompletableFuture complete() { + log.trace(() -> "Received complete()"); + + OnCompleteQueueEntry entry = new OnCompleteQueueEntry<>(); + + try { + validateRejectState(); + setRejectExceptionOrThrow(() -> new IllegalStateException("complete() has been invoked")); + eventQueue.add(entry); + processEventQueue(); + } catch (RuntimeException t) { + entry.resultFuture.completeExceptionally(t); + } + return entry.resultFuture; + } + + /** + * Indicate that no more {@link #send(Object)} calls will be made, and that streaming of messages has failed. + * + *

This can be called before any in-flight {@code send} calls are complete. Such messages will be processed before the + * stream is treated as being in-error. The returned future will be completed successfully when the {@code error} is + * sent to the downstream subscriber. + * + *

After this method is invoked, any future {@link #send(Object)}, {@link #complete()} or {@code #error(Throwable)} + * calls will be completed exceptionally and not be processed. + * + *

The returned future will be completed exceptionally if the downstream subscriber cancels the subscription, or + * if the {@code complete} call was performed after a {@link #complete()} or {@code error} call. + * + * @param error The error to send. + * @return A future that is completed when the exception has been sent to the downstream subscriber. + */ + public CompletableFuture error(Throwable error) { + log.trace(() -> "Received error() with " + error, error); + + OnErrorQueueEntry entry = new OnErrorQueueEntry<>(error); + + try { + validateRejectState(); + setRejectExceptionOrThrow(() -> new IllegalStateException("error() has been invoked")); + eventQueue.add(entry); + processEventQueue(); + } catch (RuntimeException t) { + entry.resultFuture.completeExceptionally(t); + } + return entry.resultFuture; + } + + /** + * A method called by the downstream subscriber in order to subscribe to the publisher. + */ + @Override + public void subscribe(Subscriber s) { + if (subscriber != null) { + s.onSubscribe(new NoOpSubscription()); + s.onError(new IllegalStateException("Only one subscription may be active at a time.")); + } + this.subscriber = s; + s.onSubscribe(new SubscriptionImpl()); + processEventQueue(); + } + + /** + * Process the messages in the event queue. This is invoked after every operation on the publisher that changes the state + * of the event queue. + * + *

Internally, this method will only be executed by one thread at a time. Any calls to this method will another thread + * is processing the queue will return immediately. This ensures: (1) thread safety in queue processing, (2) mutual recursion + * between onSubscribe/onNext with {@link Subscription#request(long)} are impossible. + */ + private void processEventQueue() { + do { + if (!processingQueue.compareAndSet(false, true)) { + // Some other thread is processing the queue, so we don't need to. + return; + } + + try { + doProcessQueue(); + } catch (Throwable e) { + panicAndDie(e); + break; + } finally { + processingQueue.set(false); + } + + // Once releasing the processing-queue flag, we need to double-check that the queue still doesn't need to be + // processed, because new messages might have come in since we decided to release the flag. + } while (shouldProcessQueueEntry(eventQueue.peek())); + } + + /** + * Pop events off of the queue and process them in the order they are given, returning when we can no longer process the + * event at the head of the queue. + * + *

Invoked only from within the {@link #processEventQueue()} method with the {@link #processingQueue} flag held. + */ + private void doProcessQueue() { + while (true) { + QueueEntry entry = eventQueue.peek(); + + if (!shouldProcessQueueEntry(entry)) { + // We're done processing entries. + return; + } + + if (entryTypesToFail.contains(entry.type())) { + // We're supposed to skip this entry type. Fail it and move on. + entry.resultFuture.completeExceptionally(rejectException.get().get()); + } else { + switch (entry.type()) { + case ON_NEXT: + OnNextQueueEntry onNextEntry = (OnNextQueueEntry) entry; + + log.trace(() -> "Calling onNext() with " + onNextEntry.value); + subscriber.onNext(onNextEntry.value); + long newDemand = outstandingDemand.decrementAndGet(); + log.trace(() -> "Decreased demand to " + newDemand); + break; + case ON_COMPLETE: + entryTypesToFail.addAll(asList(ON_NEXT, ON_COMPLETE, ON_ERROR)); + log.trace(() -> "Calling onComplete()"); + subscriber.onComplete(); + break; + case ON_ERROR: + OnErrorQueueEntry onErrorEntry = (OnErrorQueueEntry) entry; + + entryTypesToFail.addAll(asList(ON_NEXT, ON_COMPLETE, ON_ERROR)); + log.trace(() -> "Calling onError() with " + onErrorEntry.failure, onErrorEntry.failure); + subscriber.onError(onErrorEntry.failure); + break; + case CANCEL: + subscriber = null; // Allow subscriber to be garbage collected after cancellation. + break; + default: + // Should never happen. Famous last words? + throw new IllegalStateException("Unknown entry type: " + entry.type()); + } + + entry.resultFuture.complete(null); + } + + eventQueue.remove(); + } + } + + /** + * Return true if we should process the provided queue entry. + */ + private boolean shouldProcessQueueEntry(QueueEntry entry) { + if (subscriber == null) { + // We don't have a subscriber yet. + return false; + } + + if (entry == null) { + // The queue is empty. + return false; + } + + if (entry.type() != ON_NEXT) { + // This event isn't an on-next event, so we don't need subscriber demand in order to process it. + return true; + } + + if (entryTypesToFail.contains(ON_NEXT)) { + // This is an on-next call (decided above), but we're failing on-next calls. Go ahead and fail it. + return true; + } + + // This is an on-next event and we're not failing on-next events, so make sure we have demand available before + // processing it. + return outstandingDemand.get() > 0; + } + + /** + * Invoked from within {@link #processEventQueue()} when we can't process the queue for some reason. This is likely + * caused by a downstream subscriber throwing an exception from {@code onNext}, which it should never do. + * + *

Here we try our best to fail all of the entries in the queue, so that no callers have "stuck" futures. + */ + private void panicAndDie(Throwable cause) { + try { + // Create exception here instead of in supplier to preserve a more-useful stack trace. + RuntimeException failure = new IllegalStateException("Encountered fatal error in publisher", cause); + rejectException.compareAndSet(null, () -> failure); + entryTypesToFail.addAll(asList(QueueEntry.Type.values())); + subscriber.onError(cause instanceof Error ? cause : failure); + + while (true) { + QueueEntry entry = eventQueue.poll(); + if (entry == null) { + break; + } + entry.resultFuture.completeExceptionally(failure); + } + } catch (Throwable t) { + t.addSuppressed(cause); + log.error(() -> "Failed while processing a failure. This could result in stuck futures.", t); + } + } + + /** + * Ensure that {@link #rejectException} is null. If it is not, throw the exception. + */ + private void validateRejectState() { + if (rejectException.get() != null) { + throw rejectException.get().get(); + } + } + + /** + * Set the {@link #rejectException}, if it is null. If it is not, throw the exception. + */ + private void setRejectExceptionOrThrow(Supplier rejectedException) { + if (!rejectException.compareAndSet(null, rejectedException)) { + throw rejectException.get().get(); + } + } + + /** + * The subscription passed to the first {@link #subscriber} that subscribes to this publisher. This allows the downstream + * subscriber to request for more {@code onNext} calls or to {@code cancel} the stream of messages. + */ + private class SubscriptionImpl implements Subscription { + @Override + public void request(long n) { + log.trace(() -> "Received request() with " + n); + if (n <= 0) { + // Create exception here instead of in supplier to preserve a more-useful stack trace. + IllegalArgumentException failure = new IllegalArgumentException("A downstream publisher requested an invalid " + + "amount of data: " + n); + rejectException.compareAndSet(null, () -> failure); + eventQueue.add(new OnErrorQueueEntry<>(failure)); + entryTypesToFail.addAll(asList(ON_NEXT, ON_COMPLETE)); + processEventQueue(); + } else { + long newDemand = outstandingDemand.updateAndGet(current -> { + if (Long.MAX_VALUE - current < n) { + return Long.MAX_VALUE; + } + + return current + n; + }); + log.trace(() -> "Increased demand to " + newDemand); + processEventQueue(); + } + } + + @Override + public void cancel() { + log.trace(() -> "Received cancel()"); + + // Create exception here instead of in supplier to preserve a more-useful stack trace. + IllegalStateException failure = new IllegalStateException("A downstream publisher has cancelled the subscription."); + rejectException.compareAndSet(null, () -> failure); + eventQueue.add(new CancelQueueEntry<>()); + entryTypesToFail.addAll(asList(ON_NEXT, ON_COMPLETE, ON_ERROR)); + processEventQueue(); + } + } + + /** + * An entry in the {@link #eventQueue}. + */ + abstract static class QueueEntry { + /** + * The future that was returned to a {@link #send(Object)}, {@link #complete()} or {@link #error(Throwable)} message. + */ + protected final CompletableFuture resultFuture = new CompletableFuture<>(); + + /** + * Retrieve the type of this queue entry. + */ + protected abstract Type type(); + + protected enum Type { + ON_NEXT, + ON_COMPLETE, + ON_ERROR, + CANCEL + } + } + + /** + * An entry added when we get a {@link #send(Object)} call. + */ + private static final class OnNextQueueEntry extends QueueEntry { + private final T value; + + private OnNextQueueEntry(T value) { + this.value = value; + } + + @Override + protected Type type() { + return ON_NEXT; + } + } + + /** + * An entry added when we get a {@link #complete()} call. + */ + private static final class OnCompleteQueueEntry extends QueueEntry { + @Override + protected Type type() { + return ON_COMPLETE; + } + } + + /** + * An entry added when we get an {@link #error(Throwable)} call. + */ + private static final class OnErrorQueueEntry extends QueueEntry { + private final Throwable failure; + + private OnErrorQueueEntry(Throwable failure) { + this.failure = failure; + } + + @Override + protected Type type() { + return ON_ERROR; + } + } + + /** + * An entry added when we get a {@link SubscriptionImpl#cancel()} call. + */ + private static final class CancelQueueEntry extends QueueEntry { + @Override + protected Type type() { + return CANCEL; + } + } + + /** + * A subscription that does nothing. This is used for signaling {@code onError} to subscribers that subscribe to this + * publisher for the second time. Only one subscriber is supported. + */ + private static final class NoOpSubscription implements Subscription { + @Override + public void request(long n) { + } + + @Override + public void cancel() { + } + } +} diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/StoringSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/StoringSubscriber.java new file mode 100644 index 000000000000..15c2a13da00c --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/StoringSubscriber.java @@ -0,0 +1,193 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Optional; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.utils.Validate; + +/** + * An implementation of {@link Subscriber} that stores the events it receives for retrieval. + * + *

Events can be observed via {@link #peek()} and {@link #drop()}. The number of events stored is limited by the + * {@code maxElements} configured at construction. + */ +@SdkProtectedApi +public class StoringSubscriber implements Subscriber { + /** + * The maximum number of events that can be stored in this subscriber. The number of events in {@link #events} may be + * slightly higher once {@link #onComplete()} and {@link #onError(Throwable)} events are added. + */ + private final int maxEvents; + + /** + * The events stored in this subscriber. The maximum size of this queue is approximately {@link #maxEvents}. + */ + private final Queue> events; + + /** + * The active subscription. Set when {@link #onSubscribe(Subscription)} is invoked. + */ + private Subscription subscription; + + /** + * Create a subscriber that stores up to {@code maxElements} events for retrieval. + */ + public StoringSubscriber(int maxEvents) { + Validate.isPositive(maxEvents, "Max elements must be positive."); + this.maxEvents = maxEvents; + this.events = new ConcurrentLinkedQueue<>(); + } + + /** + * Check the first event stored in this subscriber. + * + *

This will return empty if no events are currently available (outstanding demand has not yet + * been filled). + */ + public Optional> peek() { + return Optional.ofNullable(events.peek()); + } + + /** + * Remove and return the first event stored in this subscriber. + * + *

This will return empty if no events are currently available (outstanding demand has not yet + * been filled). + */ + public Optional> poll() { + Event result = events.poll(); + if (result != null) { + subscription.request(1); + return Optional.of(result); + } + return Optional.empty(); + } + + @Override + public void onSubscribe(Subscription subscription) { + if (this.subscription != null) { + subscription.cancel(); + } + + this.subscription = subscription; + subscription.request(maxEvents); + } + + @Override + public void onNext(T t) { + Validate.notNull(t, "onNext(null) is not allowed."); + + try { + events.add(Event.value(t)); + } catch (RuntimeException e) { + subscription.cancel(); + onError(new IllegalStateException("Failed to store element.", e)); + } + } + + @Override + public void onComplete() { + events.add(Event.complete()); + } + + @Override + public void onError(Throwable throwable) { + events.add(Event.error(throwable)); + } + + /** + * An event stored for later retrieval by this subscriber. + * + *

Stored events are one of the follow {@link #type()}s: + *

    + *
  • {@code VALUE} - A value received by {@link #onNext(Object)}, available via {@link #value()}.
  • + *
  • {@code COMPLETE} - Indicating {@link #onComplete()} was called.
  • + *
  • {@code ERROR} - Indicating {@link #onError(Throwable)} was called. The exception is available via + * {@link #runtimeError()}
  • + *
  • {@code EMPTY} - Indicating that no events remain in the queue (but more from upstream may be given later).
  • + *
+ */ + public static final class Event { + private final EventType type; + private final T value; + private final Throwable error; + + private Event(EventType type, T value, Throwable error) { + this.type = type; + this.value = value; + this.error = error; + } + + private static Event complete() { + return new Event<>(EventType.ON_COMPLETE, null, null); + } + + private static Event error(Throwable error) { + return new Event<>(EventType.ON_ERROR, null, error); + } + + private static Event value(T value) { + return new Event<>(EventType.ON_NEXT, value, null); + } + + /** + * Retrieve the {@link EventType} of this event. + */ + public EventType type() { + return type; + } + + /** + * The value stored in this {@code VALUE} type. Null for all other event types. + */ + public T value() { + return value; + } + + /** + * The error stored in this {@code ERROR} type. Null for all other event types. If a checked exception was received via + * {@link #onError(Throwable)}, this will return a {@code RuntimeException} with the checked exception as its cause. + */ + public RuntimeException runtimeError() { + if (type != EventType.ON_ERROR) { + return null; + } + + if (error instanceof RuntimeException) { + return (RuntimeException) error; + } + + if (error instanceof IOException) { + return new UncheckedIOException((IOException) error); + } + + return new RuntimeException(error); + } + } + + public enum EventType { + ON_NEXT, + ON_COMPLETE, + ON_ERROR + } +} diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/ByteBufferStoringSubscriberTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/ByteBufferStoringSubscriberTckTest.java new file mode 100644 index 000000000000..0dc9c6229511 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/ByteBufferStoringSubscriberTckTest.java @@ -0,0 +1,72 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.nio.ByteBuffer; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class ByteBufferStoringSubscriberTckTest extends SubscriberWhiteboxVerification { + protected ByteBufferStoringSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber(WhiteboxSubscriberProbe probe) { + return new ByteBufferStoringSubscriber(16) { + @Override + public void onError(Throwable throwable) { + super.onError(throwable); + probe.registerOnError(throwable); + } + + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(subscription); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long elements) { + subscription.request(elements); + } + + @Override + public void signalCancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(ByteBuffer next) { + super.onNext(next); + probe.registerOnNext(next); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + }; + } + + @Override + public ByteBuffer createElement(int element) { + return ByteBuffer.wrap(new byte[0]); + } +} \ No newline at end of file diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/ByteBufferStoringSubscriberTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/ByteBufferStoringSubscriberTest.java new file mode 100644 index 000000000000..798098e3c585 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/ByteBufferStoringSubscriberTest.java @@ -0,0 +1,308 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import java.nio.ByteBuffer; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.utils.async.ByteBufferStoringSubscriber.TransferResult; + +public class ByteBufferStoringSubscriberTest { + @Test + public void constructorCalled_withNonPositiveSize_throwsException() { + assertThatCode(() -> new ByteBufferStoringSubscriber(1)).doesNotThrowAnyException(); + assertThatCode(() -> new ByteBufferStoringSubscriber(Integer.MAX_VALUE)).doesNotThrowAnyException(); + + assertThatThrownBy(() -> new ByteBufferStoringSubscriber(0)).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> new ByteBufferStoringSubscriber(-1)).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> new ByteBufferStoringSubscriber(Integer.MIN_VALUE)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void doesNotRequestMoreThanMaxBytes() { + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(3); + Subscription subscription = mock(Subscription.class); + + subscriber.onSubscribe(subscription); + verify(subscription).request(1); + + subscriber.onNext(fullByteBufferOfSize(2)); + verify(subscription, times(2)).request(1); + + subscriber.onNext(fullByteBufferOfSize(0)); + verify(subscription, times(3)).request(1); + + subscriber.onNext(fullByteBufferOfSize(1)); + verifyNoMoreInteractions(subscription); + } + + @Test + public void canStoreMoreThanMaxBytesButWontAskForMoreUntilBelowMax() { + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(3); + Subscription subscription = mock(Subscription.class); + + subscriber.onSubscribe(subscription); + verify(subscription).request(1); + + subscriber.onNext(fullByteBufferOfSize(1)); // After: Storing 1 + verify(subscription, times(2)).request(1); // It should request more + + subscriber.onNext(fullByteBufferOfSize(50)); // After: Storing 51 + subscriber.transferTo(emptyByteBufferOfSize(48)); // After: Storing 3 + verifyNoMoreInteractions(subscription); // It should NOT request more + + subscriber.transferTo(emptyByteBufferOfSize(1)); // After: Storing 2 + verify(subscription, times(3)).request(1); // It should request more + } + + @Test + public void noDataTransferredIfNoDataBuffered() { + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(2); + subscriber.onSubscribe(mock(Subscription.class)); + + ByteBuffer out = emptyByteBufferOfSize(1); + + assertThat(subscriber.transferTo(out)).isEqualTo(TransferResult.SUCCESS); + assertThat(out.remaining()).isEqualTo(1); + } + + @Test + public void noDataTransferredIfComplete() { + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(2); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onComplete(); + + ByteBuffer out = emptyByteBufferOfSize(1); + + assertThat(subscriber.transferTo(out)).isEqualTo(TransferResult.END_OF_STREAM); + assertThat(out.remaining()).isEqualTo(1); + } + + @Test + public void noDataTransferredIfError() { + RuntimeException error = new RuntimeException(); + + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(2); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onError(error); + + ByteBuffer out = emptyByteBufferOfSize(1); + + assertThatThrownBy(() -> subscriber.transferTo(out)).isEqualTo(error); + assertThat(out.remaining()).isEqualTo(1); + } + + @Test + public void checkedExceptionsAreWrapped() { + Exception error = new Exception(); + + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(2); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onError(error); + + ByteBuffer out = emptyByteBufferOfSize(1); + + assertThatThrownBy(() -> subscriber.transferTo(out)).hasCause(error); + assertThat(out.remaining()).isEqualTo(1); + } + + @Test + public void completeIsReportedEvenWithExactOutSize() { + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(2); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onNext(fullByteBufferOfSize(2)); + subscriber.onComplete(); + + ByteBuffer out = emptyByteBufferOfSize(2); + assertThat(subscriber.transferTo(out)).isEqualTo(TransferResult.END_OF_STREAM); + assertThat(out.remaining()).isEqualTo(0); + } + + @Test + public void completeIsReportedEvenWithExtraOutSize() { + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(2); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onNext(fullByteBufferOfSize(2)); + subscriber.onComplete(); + + ByteBuffer out = emptyByteBufferOfSize(3); + assertThat(subscriber.transferTo(out)).isEqualTo(TransferResult.END_OF_STREAM); + assertThat(out.remaining()).isEqualTo(1); + } + + @Test + public void errorIsReportedEvenWithExactOutSize() { + RuntimeException error = new RuntimeException(); + + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(2); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onNext(fullByteBufferOfSize(2)); + subscriber.onError(error); + + ByteBuffer out = emptyByteBufferOfSize(2); + assertThatThrownBy(() -> subscriber.transferTo(out)).isEqualTo(error); + assertThat(out.remaining()).isEqualTo(0); + } + + @Test + public void errorIsReportedEvenWithExtraOutSize() { + RuntimeException error = new RuntimeException(); + + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(2); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onNext(fullByteBufferOfSize(2)); + subscriber.onError(error); + + ByteBuffer out = emptyByteBufferOfSize(3); + assertThatThrownBy(() -> subscriber.transferTo(out)).isEqualTo(error); + assertThat(out.remaining()).isEqualTo(1); + } + + @Test + public void dataIsDeliveredInTheRightOrder() { + ByteBuffer buffer1 = fullByteBufferOfSize(1); + ByteBuffer buffer2 = fullByteBufferOfSize(1); + ByteBuffer buffer3 = fullByteBufferOfSize(1); + + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(3); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onNext(buffer1); + subscriber.onNext(buffer2); + subscriber.onNext(buffer3); + subscriber.onComplete(); + + ByteBuffer out = emptyByteBufferOfSize(4); + subscriber.transferTo(out); + + out.flip(); + assertThat(out.get()).isEqualTo(buffer1.get()); + assertThat(out.get()).isEqualTo(buffer2.get()); + assertThat(out.get()).isEqualTo(buffer3.get()); + assertThat(out.hasRemaining()).isFalse(); + } + + @Test + @Timeout(30) + public void stochastic_subscriberSeemsThreadSafe() throws Throwable { + ExecutorService producer = Executors.newFixedThreadPool(1); + ExecutorService consumer = Executors.newFixedThreadPool(1); + try { + ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(50); + + AtomicBoolean testRunning = new AtomicBoolean(true); + AtomicInteger messageNumber = new AtomicInteger(0); + + AtomicReference producerFailure = new AtomicReference<>(); + Subscription subscription = new Subscription() { + @Override + public void request(long n) { + producer.submit(() -> { + try { + for (int i = 0; i < n; i++) { + ByteBuffer buffer = ByteBuffer.allocate(4); + buffer.putInt(messageNumber.getAndIncrement()); + buffer.flip(); + subscriber.onNext(buffer); + } + } catch (Throwable t) { + producerFailure.set(t); + } + }); + } + + @Override + public void cancel() { + producerFailure.set(new AssertionError("Cancel not expected.")); + } + }; + + subscriber.onSubscribe(subscription); + + Future consumerFuture = consumer.submit(() -> { + ByteBuffer carryOver = ByteBuffer.allocate(4); + + int expectedMessageNumber = 0; + while (testRunning.get()) { + Thread.sleep(1); + + ByteBuffer out = ByteBuffer.allocate(4 + expectedMessageNumber); + subscriber.transferTo(out); + + out.flip(); + + if (carryOver.position() > 0) { + int oldOutLimit = out.limit(); + out.limit(carryOver.remaining()); + carryOver.put(out); + out.limit(oldOutLimit); + + carryOver.flip(); + assertThat(carryOver.getInt()).isEqualTo(expectedMessageNumber); + ++expectedMessageNumber; + carryOver.clear(); + } + + while (out.remaining() >= 4) { + assertThat(out.getInt()).isEqualTo(expectedMessageNumber); + ++expectedMessageNumber; + } + + if (out.hasRemaining()) { + carryOver.put(out); + } + } + return null; + }); + + Thread.sleep(5_000); + testRunning.set(false); + consumerFuture.get(); + if (producerFailure.get() != null) { + throw producerFailure.get(); + } + assertThat(messageNumber.get()).isGreaterThan(10); // ensure we actually tested something + } finally { + producer.shutdownNow(); + consumer.shutdownNow(); + } + } + + private ByteBuffer fullByteBufferOfSize(int size) { + byte[] data = new byte[size]; + ThreadLocalRandom.current().nextBytes(data); + return ByteBuffer.wrap(data); + } + + private ByteBuffer emptyByteBufferOfSize(int size) { + return ByteBuffer.allocate(size); + } +} \ No newline at end of file diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/SimplePublisherTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/SimplePublisherTckTest.java new file mode 100644 index 000000000000..fc2ae1e95da2 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/SimplePublisherTckTest.java @@ -0,0 +1,48 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import org.reactivestreams.Publisher; +import org.reactivestreams.tck.PublisherVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class SimplePublisherTckTest extends PublisherVerification { + public SimplePublisherTckTest() { + super(new TestEnvironment()); + } + + @Override + public Publisher createPublisher(long elements) { + SimplePublisher publisher = new SimplePublisher<>(); + for (int i = 0; i < elements; i++) { + publisher.send(i); + } + publisher.complete(); + return publisher; + } + + @Override + public Publisher createFailedPublisher() { + SimplePublisher publisher = new SimplePublisher<>(); + publisher.error(new RuntimeException()); + return publisher; + } + + @Override + public long maxElementsFromPublisher() { + return 256L; + } +} \ No newline at end of file diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/SimplePublisherTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/SimplePublisherTest.java new file mode 100644 index 000000000000..383df2aca71c --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/SimplePublisherTest.java @@ -0,0 +1,551 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.async.StoringSubscriber.Event; +import software.amazon.awssdk.utils.async.StoringSubscriber.EventType; + +public class SimplePublisherTest { + /** + * This class has tests that try to break things for a fixed period of time, and then make sure nothing broke. + * This flag controls how long those tests run. Longer values provider a better guarantee of catching an issue, but + * increase the build time. 5 seconds seems okay for now, but if a flaky test is found try increasing the duration to make + * it reproduce more reliably. + */ + private static final Duration STOCHASTIC_TEST_DURATION = Duration.ofSeconds(5); + + @Test + public void immediateSuccessWorks() { + SimplePublisher publisher = new SimplePublisher<>(); + StoringSubscriber subscriber = new StoringSubscriber<>(1); + publisher.subscribe(subscriber); + publisher.complete(); + + assertThat(subscriber.poll().get().type()).isEqualTo(EventType.ON_COMPLETE); + assertThat(subscriber.poll()).isNotPresent(); + } + + @Test + public void immediateFailureWorks() { + RuntimeException error = new RuntimeException(); + + SimplePublisher publisher = new SimplePublisher<>(); + StoringSubscriber subscriber = new StoringSubscriber<>(1); + publisher.subscribe(subscriber); + publisher.error(error); + + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_ERROR); + assertThat(subscriber.peek().get().runtimeError()).isEqualTo(error); + + subscriber.poll(); + + assertThat(subscriber.poll()).isNotPresent(); + } + + @Test + public void writeAfterCompleteFails() { + SimplePublisher publisher = new SimplePublisher<>(); + publisher.complete(); + assertThat(publisher.send(5)).isCompletedExceptionally(); + } + + @Test + public void writeAfterErrorFails() { + SimplePublisher publisher = new SimplePublisher<>(); + publisher.error(new Throwable()); + assertThat(publisher.send(5)).isCompletedExceptionally(); + } + + @Test + public void completeAfterCompleteFails() { + SimplePublisher publisher = new SimplePublisher<>(); + publisher.complete(); + assertThat(publisher.complete()).isCompletedExceptionally(); + } + + @Test + public void completeAfterErrorFails() { + SimplePublisher publisher = new SimplePublisher<>(); + publisher.error(new Throwable()); + assertThat(publisher.complete()).isCompletedExceptionally(); + } + + @Test + public void errorAfterCompleteFails() { + SimplePublisher publisher = new SimplePublisher<>(); + publisher.complete(); + assertThat(publisher.error(new Throwable())).isCompletedExceptionally(); + } + + @Test + public void errorAfterErrorFails() { + SimplePublisher publisher = new SimplePublisher<>(); + publisher.error(new Throwable()); + assertThat(publisher.error(new Throwable())).isCompletedExceptionally(); + } + + @Test + public void oneDemandWorks() { + SimplePublisher publisher = new SimplePublisher<>(); + StoringSubscriber subscriber = new StoringSubscriber<>(1); + publisher.subscribe(subscriber); + + publisher.send(1); + publisher.send(2); + publisher.complete(); + + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_NEXT); + assertThat(subscriber.peek().get().value()).isEqualTo(1); + + subscriber.poll(); + + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_NEXT); + assertThat(subscriber.peek().get().value()).isEqualTo(2); + + subscriber.poll(); + + assertThat(subscriber.poll().get().type()).isEqualTo(EventType.ON_COMPLETE); + assertThat(subscriber.poll()).isNotPresent(); + } + + @Test + public void highDemandWorks() { + SimplePublisher publisher = new SimplePublisher<>(); + ControllableSubscriber subscriber = new ControllableSubscriber<>(); + publisher.subscribe(subscriber); + subscriber.subscription.request(Long.MAX_VALUE); + + publisher.send(1); + subscriber.subscription.request(Long.MAX_VALUE); + publisher.send(2); + subscriber.subscription.request(Long.MAX_VALUE); + publisher.complete(); + subscriber.subscription.request(Long.MAX_VALUE); + + assertThat(subscriber.eventQueue.peek().get().type()).isEqualTo(EventType.ON_NEXT); + assertThat(subscriber.eventQueue.peek().get().value()).isEqualTo(1); + + subscriber.eventQueue.poll(); + + assertThat(subscriber.eventQueue.peek().get().type()).isEqualTo(EventType.ON_NEXT); + assertThat(subscriber.eventQueue.peek().get().value()).isEqualTo(2); + + subscriber.eventQueue.poll(); + + assertThat(subscriber.eventQueue.peek().get().type()).isEqualTo(EventType.ON_COMPLETE); + + subscriber.eventQueue.poll(); + + assertThat(subscriber.eventQueue.poll()).isNotPresent(); + } + + @Test + public void writeFuturesDoNotCompleteUntilAfterOnNext() { + SimplePublisher publisher = new SimplePublisher<>(); + ControllableSubscriber subscriber = new ControllableSubscriber<>(); + publisher.subscribe(subscriber); + + CompletableFuture writeFuture = publisher.send(5); + + assertThat(subscriber.eventQueue.peek()).isNotPresent(); + assertThat(writeFuture).isNotCompleted(); + + subscriber.subscription.request(1); + + assertThat(subscriber.eventQueue.peek().get().type()).isEqualTo(EventType.ON_NEXT); + assertThat(subscriber.eventQueue.peek().get().value()).isEqualTo(5); + assertThat(writeFuture).isCompletedWithValue(null); + } + + @Test + public void completeFuturesDoNotCompleteUntilAfterOnComplete() { + SimplePublisher publisher = new SimplePublisher<>(); + ControllableSubscriber subscriber = new ControllableSubscriber<>(); + + publisher.subscribe(subscriber); + publisher.send(5); + CompletableFuture completeFuture = publisher.complete(); + + assertThat(subscriber.eventQueue.peek()).isNotPresent(); + assertThat(completeFuture).isNotCompleted(); + + subscriber.subscription.request(1); + subscriber.eventQueue.poll(); // Drop the 5 value + + assertThat(subscriber.eventQueue.peek().get().type()).isEqualTo(EventType.ON_COMPLETE); + assertThat(completeFuture).isCompletedWithValue(null); + } + + @Test + public void errorFuturesDoNotCompleteUntilAfterOnError() { + RuntimeException error = new RuntimeException(); + + SimplePublisher publisher = new SimplePublisher<>(); + ControllableSubscriber subscriber = new ControllableSubscriber<>(); + + publisher.subscribe(subscriber); + publisher.send(5); + CompletableFuture errorFuture = publisher.error(error); + + assertThat(subscriber.eventQueue.peek()).isNotPresent(); + assertThat(errorFuture).isNotCompleted(); + + subscriber.subscription.request(1); + subscriber.eventQueue.poll(); // Drop the 5 value + + assertThat(subscriber.eventQueue.peek().get().type()).isEqualTo(EventType.ON_ERROR); + assertThat(subscriber.eventQueue.peek().get().runtimeError()).isEqualTo(error); + assertThat(errorFuture).isCompletedWithValue(null); + } + + @Test + public void completeBeforeSubscribeIsDeliveredOnSubscribe() { + SimplePublisher publisher = new SimplePublisher<>(); + StoringSubscriber subscriber = new StoringSubscriber<>(Integer.MAX_VALUE); + + publisher.complete(); + publisher.subscribe(subscriber); + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_COMPLETE); + } + + @Test + public void errorBeforeSubscribeIsDeliveredOnSubscribe() { + SimplePublisher publisher = new SimplePublisher<>(); + StoringSubscriber subscriber = new StoringSubscriber<>(Integer.MAX_VALUE); + + RuntimeException error = new RuntimeException(); + publisher.error(error); + publisher.subscribe(subscriber); + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_ERROR); + assertThat(subscriber.peek().get().runtimeError()).isEqualTo(error); + } + + @Test + public void writeBeforeSubscribeIsDeliveredOnSubscribe() { + SimplePublisher publisher = new SimplePublisher<>(); + StoringSubscriber subscriber = new StoringSubscriber<>(Integer.MAX_VALUE); + + publisher.send(5); + publisher.subscribe(subscriber); + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_NEXT); + assertThat(subscriber.peek().get().value()).isEqualTo(5); + } + + @Test + public void cancelFailsAnyInFlightFutures() { + SimplePublisher publisher = new SimplePublisher<>(); + ControllableSubscriber subscriber = new ControllableSubscriber<>(); + + publisher.subscribe(subscriber); + CompletableFuture writeFuture = publisher.send(5); + CompletableFuture completeFuture = publisher.complete(); + + subscriber.subscription.cancel(); + + assertThat(writeFuture).isCompletedExceptionally(); + assertThat(completeFuture).isCompletedExceptionally(); + } + + @Test + public void newCallsAfterCancelFail() { + SimplePublisher publisher = new SimplePublisher<>(); + ControllableSubscriber subscriber = new ControllableSubscriber<>(); + + publisher.subscribe(subscriber); + subscriber.subscription.cancel(); + + assertThat(publisher.send(5)).isCompletedExceptionally(); + assertThat(publisher.complete()).isCompletedExceptionally(); + assertThat(publisher.error(new Throwable())).isCompletedExceptionally(); + } + + @Test + public void negativeDemandSkipsOutstandingMessages() { + SimplePublisher publisher = new SimplePublisher<>(); + ControllableSubscriber subscriber = new ControllableSubscriber<>(); + + publisher.subscribe(subscriber); + CompletableFuture sendFuture = publisher.send(0); + CompletableFuture completeFuture = publisher.complete(); + subscriber.subscription.request(-1); + + assertThat(sendFuture).isCompletedExceptionally(); + assertThat(completeFuture).isCompletedExceptionally(); + assertThat(subscriber.eventQueue.poll().get().type()).isEqualTo(EventType.ON_ERROR); + } + + @Test + public void evilDownstreamPublisherThrowingInOnNextStillCancelsInFlightFutures() { + SimplePublisher publisher = new SimplePublisher<>(); + ControllableSubscriber subscriber = new ControllableSubscriber<>(); + subscriber.failureInOnNext = new RuntimeException(); + + CompletableFuture writeFuture = publisher.send(5); + CompletableFuture completeFuture = publisher.complete(); + + publisher.subscribe(subscriber); + subscriber.subscription.request(1); + + assertThat(writeFuture).isCompletedExceptionally(); + assertThat(completeFuture).isCompletedExceptionally(); + } + + @Test + public void stochastic_onNext_singleProducerSeemsThreadSafe() throws Exception { + // Single-producer is interesting because we can validate the ordering of messages, unlike with multi-producer. + seemsThreadSafeWithProducerCount(1); + } + + @Test + public void stochastic_onNext_multiProducerSeemsThreadSafe() throws Exception { + seemsThreadSafeWithProducerCount(3); + } + + @Test + public void stochastic_completeAndError_seemThreadSafe() throws Exception { + assertTimeoutPreemptively(STOCHASTIC_TEST_DURATION.plusSeconds(5), () -> { + Instant start = Instant.now(); + Instant end = start.plus(STOCHASTIC_TEST_DURATION); + + ExecutorService executor = Executors.newCachedThreadPool(); + + while (end.isAfter(Instant.now())) { + SimplePublisher publisher = new SimplePublisher<>(); + ControllableSubscriber subscriber = new ControllableSubscriber<>(); + publisher.subscribe(subscriber); + subscriber.subscription.request(1); + + AtomicBoolean scenarioStart = new AtomicBoolean(false); + CountDownLatch allAreWaiting = new CountDownLatch(3); + + Runnable waitForStart = () -> { + allAreWaiting.countDown(); + while (!scenarioStart.get()) { + Thread.yield(); + } + }; + + Future writeCall = executor.submit(() -> { + waitForStart.run(); + publisher.send(0).join(); + }); + + Future completeCall = executor.submit(() -> { + waitForStart.run(); + publisher.complete().join(); + }); + + Future errorCall = executor.submit(() -> { + Throwable t = new Throwable(); + waitForStart.run(); + publisher.error(t).join(); + }); + + allAreWaiting.await(); + scenarioStart.set(true); + + List> failures = new ArrayList<>(); + addIfFailed(failures, "write", writeCall); + boolean writeSucceeded = failures.isEmpty(); + + addIfFailed(failures, "complete", completeCall); + addIfFailed(failures, "error", errorCall); + + int expectedFailures = writeSucceeded ? 1 : 2; + assertThat(failures).hasSize(expectedFailures); + } + }); + } + + private void addIfFailed(List> failures, String callName, Future call) { + try { + call.get(); + } catch (Throwable t) { + failures.add(Pair.of(callName, t)); + } + } + + private void seemsThreadSafeWithProducerCount(int producerCount) throws InterruptedException, ExecutionException { + assertTimeoutPreemptively(STOCHASTIC_TEST_DURATION.plusSeconds(5), () -> { + AtomicBoolean runProducers = new AtomicBoolean(true); + AtomicBoolean runConsumers = new AtomicBoolean(true); + AtomicInteger completesReceived = new AtomicInteger(0); + + AtomicLong messageCount = new AtomicLong(0); + AtomicLong messageReceiveCount = new AtomicLong(0); + + Semaphore productionLimiter = new Semaphore(101); + Semaphore requestLimiter = new Semaphore(57); + ExecutorService executor = Executors.newFixedThreadPool(2 + producerCount); + + SimplePublisher publisher = new SimplePublisher<>(); + ControllableSubscriber subscriber = new ControllableSubscriber<>(); + publisher.subscribe(subscriber); + + // Producer tasks + List> producers = new ArrayList<>(); + for (int i = 0; i < producerCount; i++) { + producers.add(executor.submit(() -> { + while (runProducers.get()) { + productionLimiter.acquire(); + publisher.send(messageCount.getAndIncrement()); + } + publisher.complete(); // All but one producer sending this will fail. + return null; + })); + } + + // Requester Task + Future requester = executor.submit(() -> { + while (runConsumers.get()) { + requestLimiter.acquire(); + subscriber.subscription.request(1); + } + return null; + }); + + // Consumer Task + Future consumer = executor.submit(() -> { + int expectedEvent = 0; + while (runConsumers.get() || subscriber.eventQueue.peek().isPresent()) { + Optional> event = subscriber.eventQueue.poll(); + + if (!event.isPresent()) { + continue; + } + + // When we only have 1 producer, we can verify the messages are in order. + if (producerCount == 1 && event.get().type() == EventType.ON_NEXT) { + assertThat(event.get().value()).isEqualTo(expectedEvent); + expectedEvent++; + } + + if (event.get().type() == EventType.ON_NEXT) { + messageReceiveCount.incrementAndGet(); + productionLimiter.release(); + requestLimiter.release(); + } + + if (event.get().type() == EventType.ON_COMPLETE) { + completesReceived.incrementAndGet(); + } + } + }); + + Thread.sleep(STOCHASTIC_TEST_DURATION.toMillis()); + + // Shut down producers + runProducers.set(false); + productionLimiter.release(producerCount); + for (Future producer : producers) { + producer.get(); + } + + // Shut down consumers + runConsumers.set(false); + requestLimiter.release(); + requester.get(); + consumer.get(); + + assertThat(completesReceived.get()).isEqualTo(1); + assertThat(messageReceiveCount.get()).isEqualTo(messageCount.get()); + + // Make sure we actually tested something + assertThat(messageCount.get()).isGreaterThan(10); + }); + } + + private class ControllableSubscriber implements Subscriber { + private final StoringSubscriber eventQueue = new StoringSubscriber<>(Integer.MAX_VALUE); + private Subscription subscription; + private RuntimeException failureInOnNext; + + @Override + public void onSubscribe(Subscription s) { + this.subscription = new ControllableSubscription(s); + + // Give the event queue a subscription we just ignore. We are the captain of the subscription! + eventQueue.onSubscribe(new Subscription() { + @Override + public void request(long n) { + } + + @Override + public void cancel() { + } + }); + } + + @Override + public void onNext(T o) { + if (failureInOnNext != null) { + throw failureInOnNext; + } + eventQueue.onNext(o); + } + + @Override + public void onError(Throwable t) { + eventQueue.onError(t); + } + + @Override + public void onComplete() { + eventQueue.onComplete(); + } + + private class ControllableSubscription implements Subscription { + private final Subscription delegate; + + private ControllableSubscription(Subscription s) { + delegate = s; + } + + @Override + public void request(long n) { + delegate.request(n); + } + + @Override + public void cancel() { + delegate.cancel(); + } + } + } + +} \ No newline at end of file diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/StoringSubscriberTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/StoringSubscriberTckTest.java new file mode 100644 index 000000000000..8d23a23f9367 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/StoringSubscriberTckTest.java @@ -0,0 +1,71 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class StoringSubscriberTckTest extends SubscriberWhiteboxVerification { + protected StoringSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber(WhiteboxSubscriberProbe probe) { + return new StoringSubscriber(16) { + @Override + public void onError(Throwable throwable) { + super.onError(throwable); + probe.registerOnError(throwable); + } + + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(subscription); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long elements) { + subscription.request(elements); + } + + @Override + public void signalCancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(Integer nextItems) { + super.onNext(nextItems); + probe.registerOnNext(nextItems); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + }; + } + + @Override + public Integer createElement(int element) { + return element; + } +} \ No newline at end of file diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/StoringSubscriberTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/StoringSubscriberTest.java new file mode 100644 index 000000000000..569022d97cfd --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/StoringSubscriberTest.java @@ -0,0 +1,198 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.utils.async.StoringSubscriber.Event; +import software.amazon.awssdk.utils.async.StoringSubscriber.EventType; + +public class StoringSubscriberTest { + @Test + public void constructorCalled_withNonPositiveSize_throwsException() { + assertThatCode(() -> new StoringSubscriber<>(1)).doesNotThrowAnyException(); + assertThatCode(() -> new StoringSubscriber<>(Integer.MAX_VALUE)).doesNotThrowAnyException(); + + assertThatThrownBy(() -> new StoringSubscriber<>(0)).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> new StoringSubscriber<>(-1)).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> new StoringSubscriber<>(Integer.MIN_VALUE)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void doesNotStoreMoreThanMaxElements() { + StoringSubscriber subscriber = new StoringSubscriber<>(2); + Subscription subscription = mock(Subscription.class); + + subscriber.onSubscribe(subscription); + verify(subscription).request(2); + + subscriber.onNext(0); + subscriber.onNext(0); + subscriber.peek(); + verifyNoMoreInteractions(subscription); + + subscriber.poll(); + subscriber.poll(); + verify(subscription, times(2)).request(1); + + assertThat(subscriber.peek()).isNotPresent(); + verifyNoMoreInteractions(subscription); + } + + @Test + public void returnsEmptyEventWithOutstandingDemand() { + StoringSubscriber subscriber = new StoringSubscriber<>(2); + subscriber.onSubscribe(mock(Subscription.class)); + assertThat(subscriber.peek()).isNotPresent(); + } + + @Test + public void returnsCompleteOnComplete() { + StoringSubscriber subscriber = new StoringSubscriber<>(2); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onComplete(); + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_COMPLETE); + } + + @Test + public void returnsErrorOnError() { + RuntimeException error = new RuntimeException(); + StoringSubscriber subscriber = new StoringSubscriber<>(2); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onError(error); + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_ERROR); + assertThat(subscriber.peek().get().runtimeError()).isEqualTo(error); + } + + @Test + public void errorWrapsCheckedExceptions() { + Exception error = new Exception(); + StoringSubscriber subscriber = new StoringSubscriber<>(2); + subscriber.onSubscribe(mock(Subscription.class)); + subscriber.onError(error); + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_ERROR); + assertThat(subscriber.peek().get().runtimeError()).hasCause(error); + } + + @Test + public void deliversMessagesInTheCorrectOrder() { + StoringSubscriber subscriber = new StoringSubscriber<>(2); + Subscription subscription = mock(Subscription.class); + + subscriber.onSubscribe(subscription); + subscriber.onNext(1); + subscriber.onNext(2); + subscriber.onComplete(); + + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_NEXT); + assertThat(subscriber.peek().get().value()).isEqualTo(1); + subscriber.poll(); + + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_NEXT); + assertThat(subscriber.peek().get().value()).isEqualTo(2); + subscriber.poll(); + + assertThat(subscriber.peek().get().type()).isEqualTo(EventType.ON_COMPLETE); + subscriber.poll(); + + assertThat(subscriber.peek()).isNotPresent(); + } + + @Test + @Timeout(30) + public void stochastic_subscriberSeemsThreadSafe() throws Throwable { + ExecutorService producer = Executors.newFixedThreadPool(1); + ExecutorService consumer = Executors.newFixedThreadPool(1); + try { + StoringSubscriber subscriber = new StoringSubscriber<>(10); + + AtomicBoolean testRunning = new AtomicBoolean(true); + AtomicInteger messageNumber = new AtomicInteger(0); + + AtomicReference producerFailure = new AtomicReference<>(); + Subscription subscription = new Subscription() { + @Override + public void request(long n) { + producer.submit(() -> { + try { + for (int i = 0; i < n; i++) { + subscriber.onNext(messageNumber.getAndIncrement()); + } + } catch (Throwable t) { + producerFailure.set(t); + } + }); + } + + @Override + public void cancel() { + producerFailure.set(new AssertionError("Cancel not expected.")); + } + }; + + subscriber.onSubscribe(subscription); + + Future consumerFuture = consumer.submit(() -> { + int expectedMessageNumber = 0; + while (testRunning.get()) { + Thread.sleep(1); + + Optional> current = subscriber.peek(); + Optional> current2 = subscriber.peek(); + + if (current.isPresent()) { + assertThat(current.get()).isSameAs(current2.get()); + Event event = current.get(); + + assertThat(event.type()).isEqualTo(EventType.ON_NEXT); + assertThat(event.value()).isEqualTo(expectedMessageNumber); + expectedMessageNumber++; + } + + subscriber.poll(); + } + return null; + }); + + Thread.sleep(5_000); + testRunning.set(false); + consumerFuture.get(); + if (producerFailure.get() != null) { + throw producerFailure.get(); + } + assertThat(messageNumber.get()).isGreaterThan(10); // ensure we actually tested something + } finally { + producer.shutdownNow(); + consumer.shutdownNow(); + } + } +} \ No newline at end of file