From 41ee7a4f710ba812333b6f2f622da7c6cafcb4c8 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 23 Jan 2024 12:05:16 +0000 Subject: [PATCH 1/2] Ensure SocketStream recalculates the timeout on each read JAVA-5298 --- .../com/mongodb/internal/connection/SocketStream.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/SocketStream.java b/driver-core/src/main/com/mongodb/internal/connection/SocketStream.java index b9841d9a379..9441f70cf68 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SocketStream.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SocketStream.java @@ -171,16 +171,16 @@ public void write(final List buffers, final OperationContext operationC @Override public ByteBuf read(final int numBytes, final OperationContext operationContext) throws IOException { - int readTimeoutMS = (int) operationContext.getTimeoutContext().getReadTimeoutMS(); - if (readTimeoutMS > 0) { - socket.setSoTimeout(readTimeoutMS); - } try { ByteBuf buffer = bufferProvider.getBuffer(numBytes); try { int totalBytesRead = 0; byte[] bytes = buffer.array(); while (totalBytesRead < buffer.limit()) { + int readTimeoutMS = (int) operationContext.getTimeoutContext().getReadTimeoutMS(); + if (readTimeoutMS > 0) { + socket.setSoTimeout(readTimeoutMS); + } int bytesRead = inputStream.read(bytes, totalBytesRead, buffer.limit() - totalBytesRead); if (bytesRead == -1) { throw new MongoSocketReadException("Prematurely reached end of stream", getAddress()); From 5e6c6ff9561967e6616274686dd3aa2b8b699cb8 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 1 Feb 2024 10:15:14 +0000 Subject: [PATCH 2/2] Ensure TimeoutContext throws if calculating a timed out value. --- .../com/mongodb/internal/TimeoutContext.java | 3 +++ .../mongodb/internal/TimeoutContextTest.java | 25 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/driver-core/src/main/com/mongodb/internal/TimeoutContext.java b/driver-core/src/main/com/mongodb/internal/TimeoutContext.java index 04b4da1509e..a2b23ee72e6 100644 --- a/driver-core/src/main/com/mongodb/internal/TimeoutContext.java +++ b/driver-core/src/main/com/mongodb/internal/TimeoutContext.java @@ -213,6 +213,9 @@ public TimeoutContext withAdditionalReadTimeout(final int additionalReadTimeout) private long timeoutRemainingMS() { assertNotNull(timeout); + if (timeout.hasExpired()) { + throw createMongoTimeoutException("The operation timeout has expired."); + } return timeout.isInfinite() ? 0 : timeout.remaining(MILLISECONDS); } diff --git a/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java b/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java index 3f1c4a58eaa..018e11712e0 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java @@ -32,6 +32,7 @@ import static com.mongodb.ClusterFixture.sleep; import static java.util.Arrays.asList; import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -197,6 +198,30 @@ Collection timeoutContextTest() { () -> assertTrue(smallTimeout.hasExpired()) ); }), + dynamicTest("throws when calculating timeout if expired", () -> { + TimeoutContext smallTimeout = new TimeoutContext(TIMEOUT_SETTINGS.withTimeoutMS(1)); + TimeoutContext longTimeout = + new TimeoutContext(TIMEOUT_SETTINGS.withTimeoutMS(9999999)); + TimeoutContext noTimeout = new TimeoutContext(TIMEOUT_SETTINGS); + sleep(100); + assertAll( + () -> assertThrows(MongoOperationTimeoutException.class, smallTimeout::getReadTimeoutMS), + () -> assertThrows(MongoOperationTimeoutException.class, smallTimeout::getWriteTimeoutMS), + () -> assertThrows(MongoOperationTimeoutException.class, smallTimeout::getMaxTimeMS), + () -> assertThrows(MongoOperationTimeoutException.class, smallTimeout::getMaxCommitTimeMS), + () -> assertThrows(MongoOperationTimeoutException.class, () -> smallTimeout.timeoutOrAlternative(1)), + () -> assertDoesNotThrow(longTimeout::getReadTimeoutMS), + () -> assertDoesNotThrow(longTimeout::getWriteTimeoutMS), + () -> assertDoesNotThrow(longTimeout::getMaxTimeMS), + () -> assertDoesNotThrow(longTimeout::getMaxCommitTimeMS), + () -> assertDoesNotThrow(() -> longTimeout.timeoutOrAlternative(1)), + () -> assertDoesNotThrow(noTimeout::getReadTimeoutMS), + () -> assertDoesNotThrow(noTimeout::getWriteTimeoutMS), + () -> assertDoesNotThrow(noTimeout::getMaxTimeMS), + () -> assertDoesNotThrow(noTimeout::getMaxCommitTimeMS), + () -> assertDoesNotThrow(() -> noTimeout.timeoutOrAlternative(1)) + ); + }), dynamicTest("validates minRoundTripTime for maxTimeMS", () -> { Supplier supplier = () -> new TimeoutContext(TIMEOUT_SETTINGS.withTimeoutMS(100)); assertAll(