Skip to content

Commit bbc2d10

Browse files
authored
Ensure SocketStream recalculates the timeout on each read (#1299)
JAVA-5298
1 parent 5ca354f commit bbc2d10

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

driver-core/src/main/com/mongodb/internal/TimeoutContext.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ public TimeoutContext withAdditionalReadTimeout(final int additionalReadTimeout)
213213

214214
private long timeoutRemainingMS() {
215215
assertNotNull(timeout);
216+
if (timeout.hasExpired()) {
217+
throw createMongoTimeoutException("The operation timeout has expired.");
218+
}
216219
return timeout.isInfinite() ? 0 : timeout.remaining(MILLISECONDS);
217220
}
218221

driver-core/src/main/com/mongodb/internal/connection/SocketStream.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,16 @@ public void write(final List<ByteBuf> buffers, final OperationContext operationC
171171

172172
@Override
173173
public ByteBuf read(final int numBytes, final OperationContext operationContext) throws IOException {
174-
int readTimeoutMS = (int) operationContext.getTimeoutContext().getReadTimeoutMS();
175-
if (readTimeoutMS > 0) {
176-
socket.setSoTimeout(readTimeoutMS);
177-
}
178174
try {
179175
ByteBuf buffer = bufferProvider.getBuffer(numBytes);
180176
try {
181177
int totalBytesRead = 0;
182178
byte[] bytes = buffer.array();
183179
while (totalBytesRead < buffer.limit()) {
180+
int readTimeoutMS = (int) operationContext.getTimeoutContext().getReadTimeoutMS();
181+
if (readTimeoutMS > 0) {
182+
socket.setSoTimeout(readTimeoutMS);
183+
}
184184
int bytesRead = inputStream.read(bytes, totalBytesRead, buffer.limit() - totalBytesRead);
185185
if (bytesRead == -1) {
186186
throw new MongoSocketReadException("Prematurely reached end of stream", getAddress());

driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import static com.mongodb.ClusterFixture.sleep;
3333
import static java.util.Arrays.asList;
3434
import static org.junit.jupiter.api.Assertions.assertAll;
35+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
3536
import static org.junit.jupiter.api.Assertions.assertEquals;
3637
import static org.junit.jupiter.api.Assertions.assertFalse;
3738
import static org.junit.jupiter.api.Assertions.assertThrows;
@@ -197,6 +198,30 @@ Collection<DynamicTest> timeoutContextTest() {
197198
() -> assertTrue(smallTimeout.hasExpired())
198199
);
199200
}),
201+
dynamicTest("throws when calculating timeout if expired", () -> {
202+
TimeoutContext smallTimeout = new TimeoutContext(TIMEOUT_SETTINGS.withTimeoutMS(1));
203+
TimeoutContext longTimeout =
204+
new TimeoutContext(TIMEOUT_SETTINGS.withTimeoutMS(9999999));
205+
TimeoutContext noTimeout = new TimeoutContext(TIMEOUT_SETTINGS);
206+
sleep(100);
207+
assertAll(
208+
() -> assertThrows(MongoOperationTimeoutException.class, smallTimeout::getReadTimeoutMS),
209+
() -> assertThrows(MongoOperationTimeoutException.class, smallTimeout::getWriteTimeoutMS),
210+
() -> assertThrows(MongoOperationTimeoutException.class, smallTimeout::getMaxTimeMS),
211+
() -> assertThrows(MongoOperationTimeoutException.class, smallTimeout::getMaxCommitTimeMS),
212+
() -> assertThrows(MongoOperationTimeoutException.class, () -> smallTimeout.timeoutOrAlternative(1)),
213+
() -> assertDoesNotThrow(longTimeout::getReadTimeoutMS),
214+
() -> assertDoesNotThrow(longTimeout::getWriteTimeoutMS),
215+
() -> assertDoesNotThrow(longTimeout::getMaxTimeMS),
216+
() -> assertDoesNotThrow(longTimeout::getMaxCommitTimeMS),
217+
() -> assertDoesNotThrow(() -> longTimeout.timeoutOrAlternative(1)),
218+
() -> assertDoesNotThrow(noTimeout::getReadTimeoutMS),
219+
() -> assertDoesNotThrow(noTimeout::getWriteTimeoutMS),
220+
() -> assertDoesNotThrow(noTimeout::getMaxTimeMS),
221+
() -> assertDoesNotThrow(noTimeout::getMaxCommitTimeMS),
222+
() -> assertDoesNotThrow(() -> noTimeout.timeoutOrAlternative(1))
223+
);
224+
}),
200225
dynamicTest("validates minRoundTripTime for maxTimeMS", () -> {
201226
Supplier<TimeoutContext> supplier = () -> new TimeoutContext(TIMEOUT_SETTINGS.withTimeoutMS(100));
202227
assertAll(

0 commit comments

Comments
 (0)