diff --git a/src/main/java/org/dataloader/DataLoaderHelper.java b/src/main/java/org/dataloader/DataLoaderHelper.java index 30dc05d..c78f7f1 100644 --- a/src/main/java/org/dataloader/DataLoaderHelper.java +++ b/src/main/java/org/dataloader/DataLoaderHelper.java @@ -3,6 +3,8 @@ import org.dataloader.impl.CompletableFutureKit; import org.dataloader.stats.StatisticsCollector; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.LinkedHashSet; @@ -12,6 +14,12 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -61,6 +69,8 @@ Object getCallContext() { private final CacheMap> futureCache; private final List>> loaderQueue; private final StatisticsCollector stats; + private Instant lastDispatchTime; + private final ScheduledExecutorService executorService; DataLoaderHelper(DataLoader dataLoader, Object batchLoadFunction, DataLoaderOptions loaderOptions, CacheMap> futureCache, StatisticsCollector stats) { this.dataLoader = dataLoader; @@ -69,6 +79,8 @@ Object getCallContext() { this.futureCache = futureCache; this.loaderQueue = new ArrayList<>(); this.stats = stats; + this.lastDispatchTime = Instant.now(); + this.executorService = Executors.newSingleThreadScheduledExecutor(); } Optional> getIfPresent(K key) { @@ -136,6 +148,10 @@ Object getCacheKey(K key) { } DispatchResult dispatch() { + return dispatch(false); + } + + DispatchResult dispatch(boolean forced) { boolean batchingEnabled = loaderOptions.batchingEnabled(); // // we copy the pre-loaded set of futures ready for dispatch @@ -143,11 +159,21 @@ DispatchResult dispatch() { final List callContexts = new ArrayList<>(); final List> queuedFutures = new ArrayList<>(); synchronized (dataLoader) { + final long timeSinceLastDispatch = Duration.between(lastDispatchTime, Instant.now()).toMillis(); + + if (batchingEnabled && !forced && loaderQueue.size() < loaderOptions.minBatchSize() && timeSinceLastDispatch < loaderOptions.maxWaitInMillis()) { + executorService.schedule(() -> { dispatch(true); }, + loaderOptions.maxWaitInMillis() - timeSinceLastDispatch, + TimeUnit.MILLISECONDS); + return new DispatchResult<>(CompletableFuture.completedFuture(emptyList()), 0); + } + loaderQueue.forEach(entry -> { keys.add(entry.getKey()); queuedFutures.add(entry.getValue()); callContexts.add(entry.getCallContext()); }); + lastDispatchTime = Instant.now(); loaderQueue.clear(); } if (!batchingEnabled || keys.isEmpty()) { diff --git a/src/main/java/org/dataloader/DataLoaderOptions.java b/src/main/java/org/dataloader/DataLoaderOptions.java index 8158902..0708be1 100644 --- a/src/main/java/org/dataloader/DataLoaderOptions.java +++ b/src/main/java/org/dataloader/DataLoaderOptions.java @@ -40,6 +40,8 @@ public class DataLoaderOptions { private CacheKey cacheKeyFunction; private CacheMap cacheMap; private int maxBatchSize; + private int minBatchSize; + private int maxWaitInMillis; private Supplier statisticsCollector; private BatchLoaderContextProvider environmentProvider; @@ -51,6 +53,8 @@ public DataLoaderOptions() { cachingEnabled = true; cachingExceptionsEnabled = true; maxBatchSize = -1; + minBatchSize = 0; + maxWaitInMillis = 0; statisticsCollector = SimpleStatisticsCollector::new; environmentProvider = NULL_PROVIDER; } @@ -68,6 +72,8 @@ public DataLoaderOptions(DataLoaderOptions other) { this.cacheKeyFunction = other.cacheKeyFunction; this.cacheMap = other.cacheMap; this.maxBatchSize = other.maxBatchSize; + this.minBatchSize = other.minBatchSize; + this.maxWaitInMillis = other.maxWaitInMillis; this.statisticsCollector = other.statisticsCollector; this.environmentProvider = other.environmentProvider; } @@ -212,10 +218,62 @@ public int maxBatchSize() { * @return the data loader options for fluent coding */ public DataLoaderOptions setMaxBatchSize(int maxBatchSize) { + if(maxBatchSize != -1 && (minBatchSize > maxBatchSize)) { + throw new IllegalArgumentException("minBatchSize should not be greater than maxBatchSize"); + } this.maxBatchSize = maxBatchSize; return this; } + /** + * Gets the minimum number of keys that will be presented to the {@link BatchLoader} function. + * minimum number of keys in a batch are also controlled by another option, maxWaitInMillis. + * + * @return the minimum batch size or 0 if there is no limit + */ + public int minBatchSize() { + return minBatchSize; + } + + /** + * Sets the minimum number of keys that will be presented to the {@link BatchLoader} function. + * minimum number of keys in a batch are also controlled by another option, maxWaitInMillis. + * + * @param minBatchSize the minimum batch size + * + * @return the data loader options for fluent coding + */ + public DataLoaderOptions setMinBatchSize(int minBatchSize) { + if(maxBatchSize != -1 && (minBatchSize > maxBatchSize)) { + throw new IllegalArgumentException("minBatchSize should not be greater than maxBatchSize"); + } + this.minBatchSize = minBatchSize; + return this; + } + + /** + * Gets the max milliseconds to wait before presenting a batch of keys to the {@link BatchLoader} function. + * minimum number of keys in a batch are also controlled by another option, minBatchSize. + * + * @return the max wait time in milliseconds or 0 if there is no limit + */ + public int maxWaitInMillis() { + return maxWaitInMillis; + } + + /** + * Sets the max milliseconds to wait before presenting a batch of keys to the {@link BatchLoader} function. + * minimum number of keys in a batch are also controlled by another option, minBatchSize. + * + * @param maxWaitInMillis the max wait time in milliseconds + * + * @return the data loader options for fluent coding + */ + public DataLoaderOptions setMaxWaitInMillis(int maxWaitInMillis) { + this.maxWaitInMillis = maxWaitInMillis; + return this; + } + /** * @return the statistics collector to use with these options */ diff --git a/src/test/java/org/dataloader/DataLoaderOptionsTest.java b/src/test/java/org/dataloader/DataLoaderOptionsTest.java new file mode 100644 index 0000000..6a08280 --- /dev/null +++ b/src/test/java/org/dataloader/DataLoaderOptionsTest.java @@ -0,0 +1,35 @@ +package org.dataloader; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import org.junit.Test; + +public class DataLoaderOptionsTest { + @Test + public void should_create_a_default_data_loader_options() { + DataLoaderOptions options = new DataLoaderOptions(createDefaultDataLoaderOptions()); + assertThat(options.batchingEnabled(), equalTo(true)); + assertThat(options.cachingEnabled(), equalTo(true)); + assertThat(options.cachingExceptionsEnabled(), equalTo(true)); + assertThat(options.maxBatchSize(), equalTo(-1)); + assertThat(options.minBatchSize(), equalTo(0)); + assertThat(options.maxWaitInMillis(), equalTo(0)); + } + + @Test(expected = IllegalArgumentException.class) + public void should_fail_if_min_batch_size_is_greater_than_max() { + DataLoaderOptions options = createDefaultDataLoaderOptions(); + options.setMaxBatchSize(5).setMinBatchSize(6); + } + + @Test(expected = IllegalArgumentException.class) + public void should_fail_if_max_batch_size_is_less_than_min() { + DataLoaderOptions options = createDefaultDataLoaderOptions(); + options.setMinBatchSize(6).setMaxBatchSize(5); + } + + private DataLoaderOptions createDefaultDataLoaderOptions() { + return DataLoaderOptions.newOptions(); + } +} diff --git a/src/test/java/org/dataloader/DataLoaderTest.java b/src/test/java/org/dataloader/DataLoaderTest.java index 0718225..7207d54 100644 --- a/src/test/java/org/dataloader/DataLoaderTest.java +++ b/src/test/java/org/dataloader/DataLoaderTest.java @@ -860,6 +860,37 @@ public void batching_disabled_and_caching_disabled_should_dispatch_immediately_a } + @Test + public void min_batch_size_with_batching_disabled_and_caching_disabled_should_dispatch_immediately_and_forget() throws Exception { + List> loadCalls = new ArrayList<>(); + DataLoaderOptions options = newOptions().setMinBatchSize(5).setMaxWaitInMillis(10).setBatchingEnabled(false).setCachingEnabled(false); + DataLoader identityLoader = idLoader(options, loadCalls); + + CompletableFuture fa = identityLoader.load("A"); + CompletableFuture fb = identityLoader.load("B"); + + // caching is off + CompletableFuture fa1 = identityLoader.load("A"); + CompletableFuture fb1 = identityLoader.load("B"); + + List values = CompletableFutureKit.allOf(asList(fa, fb, fa1, fb1)).join(); + + assertThat(fa.join(), equalTo("A")); + assertThat(fb.join(), equalTo("B")); + assertThat(fa1.join(), equalTo("A")); + assertThat(fb1.join(), equalTo("B")); + + assertThat(values, equalTo(asList("A", "B", "A", "B"))); + + assertThat(loadCalls, equalTo(asList( + singletonList("A"), + singletonList("B"), + singletonList("A"), + singletonList("B") + ))); + + } + @Test public void batches_multiple_requests_with_max_batch_size() throws Exception { List> loadCalls = new ArrayList<>(); @@ -881,6 +912,70 @@ public void batches_multiple_requests_with_max_batch_size() throws Exception { } + @Test + public void batches_multiple_requests_with_min_batch_size() throws Exception { + List> loadCalls = new ArrayList<>(); + DataLoader identityLoader = idLoader(newOptions().setMinBatchSize(3).setMaxWaitInMillis(10), loadCalls); + + CompletableFuture f1 = identityLoader.load(1); + identityLoader.dispatch(); + CompletableFuture f2 = identityLoader.load(2); + identityLoader.dispatch(); + CompletableFuture f3 = identityLoader.load(3); + identityLoader.dispatch(); + + CompletableFuture.allOf(f1, f2, f3).join(); + + assertThat(f1.join(), equalTo(1)); + assertThat(f2.join(), equalTo(2)); + assertThat(f3.join(), equalTo(3)); + + assertThat(loadCalls, equalTo(singletonList(asList(1, 2, 3)))); + + } + + @Test + public void min_batch_size_with_no_wait_time_should_not_batch_requests() throws Exception { + List> loadCalls = new ArrayList<>(); + DataLoader identityLoader = idLoader(newOptions().setMinBatchSize(3), loadCalls); + + CompletableFuture f1 = identityLoader.load(1); + identityLoader.dispatch(); + CompletableFuture f2 = identityLoader.load(2); + identityLoader.dispatch(); + CompletableFuture f3 = identityLoader.load(3); + identityLoader.dispatch(); + + CompletableFuture.allOf(f1, f2, f3).join(); + + assertThat(f1.join(), equalTo(1)); + assertThat(f2.join(), equalTo(2)); + assertThat(f3.join(), equalTo(3)); + + assertThat(loadCalls, equalTo(asList(singletonList(1), singletonList(2), singletonList(3)))); + } + + @Test + public void max_wait_time_with_no_min_batch_size_should_not_batch_requests() throws Exception { + List> loadCalls = new ArrayList<>(); + DataLoader identityLoader = idLoader(newOptions().setMaxWaitInMillis(100), loadCalls); + + CompletableFuture f1 = identityLoader.load(1); + identityLoader.dispatch(); + CompletableFuture f2 = identityLoader.load(2); + identityLoader.dispatch(); + CompletableFuture f3 = identityLoader.load(3); + identityLoader.dispatch(); + + CompletableFuture.allOf(f1, f2, f3).join(); + + assertThat(f1.join(), equalTo(1)); + assertThat(f2.join(), equalTo(2)); + assertThat(f3.join(), equalTo(3)); + + assertThat(loadCalls, equalTo(asList(singletonList(1), singletonList(2), singletonList(3)))); + } + @Test public void can_split_max_batch_sizes_correctly() throws Exception { List> loadCalls = new ArrayList<>(); @@ -903,6 +998,28 @@ public void can_split_max_batch_sizes_correctly() throws Exception { } + @Test + public void can_combine_min_batch_size_and_split_max_batch_sizes_correctly() throws Exception { + List> loadCalls = new ArrayList<>(); + DataLoader identityLoader = idLoader(newOptions().setMinBatchSize(5).setMaxWaitInMillis(20).setMaxBatchSize(5), loadCalls); + + List> results = new ArrayList<>(); + for (int i = 0; i < 21; i++) { + results.add(identityLoader.load(i)); + identityLoader.dispatch(); + } + List> expectedCalls = new ArrayList<>(); + expectedCalls.add(listFrom(0, 5)); + expectedCalls.add(listFrom(5, 10)); + expectedCalls.add(listFrom(10, 15)); + expectedCalls.add(listFrom(15, 20)); + expectedCalls.add(listFrom(20, 21)); + + results.forEach(CompletableFuture::join); + + assertThat(loadCalls, equalTo(expectedCalls)); + } + @Test public void should_Batch_loads_occurring_within_futures() { List> loadCalls = new ArrayList<>();