diff --git a/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java b/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java index 9aabc5e9af..876b8a6b7f 100644 --- a/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java +++ b/src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java @@ -44,6 +44,8 @@ public final class Constants { public static final String NAME_OF_INTERNAL_ID = "__internalNeo4jId__"; public static final String NAME_OF_ELEMENT_ID = "__elementId__"; + public static final String NAME_OF_ADDITIONAL_SORT = "__stable_uniq_sort__"; + /** * Indicates the list of dynamic labels. */ diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java index 299a376c38..abd193fed4 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java @@ -47,7 +47,6 @@ import org.springframework.data.support.PageableExecutionUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; /** * Base class for {@link RepositoryQuery} implementations for Neo4j. @@ -80,7 +79,7 @@ public QueryMethod getQueryMethod() { @Override public final Object execute(Object[] parameters) { - boolean incrementLimit = queryMethod.isSliceQuery() && !queryMethod.getQueryAnnotation().map(q -> q.countQuery()).filter(StringUtils::hasText).isPresent(); + boolean incrementLimit = queryMethod.incrementLimit(); Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor( (Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters); @@ -91,8 +90,7 @@ public final Object execute(Object[] parameters) { PropertyFilterSupport.getInputProperties(resultProcessor, factory, mappingContext), parameterAccessor, null, getMappingFunction(resultProcessor), incrementLimit ? l -> l + 1 : UnaryOperator.identity()); - Object rawResult = new Neo4jQueryExecution.DefaultQueryExecution(neo4jOperations).execute(preparedQuery, - queryMethod.isCollectionLikeQuery() || queryMethod.isPageQuery() || queryMethod.isSliceQuery()); + Object rawResult = new Neo4jQueryExecution.DefaultQueryExecution(neo4jOperations).execute(preparedQuery, queryMethod.asCollectionQuery()); Converter preparingConverter = OptionalUnwrappingConverter.INSTANCE; if (returnedType.isProjecting()) { @@ -107,6 +105,8 @@ public final Object execute(Object[] parameters) { rawResult = createPage(parameterAccessor, (List) rawResult); } else if (queryMethod.isSliceQuery()) { rawResult = createSlice(incrementLimit, parameterAccessor, (List) rawResult); + } else if (queryMethod.isScrollQuery()) { + rawResult = createWindow(resultProcessor, incrementLimit, parameterAccessor, (List) rawResult, preparedQuery.getQueryFragmentsAndParameters()); } return resultProcessor.processResult(rawResult, preparingConverter); } diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java index 08433fa309..d3849715dd 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java @@ -18,6 +18,7 @@ import java.util.Collection; import java.util.function.BiFunction; import java.util.function.Supplier; +import java.util.function.UnaryOperator; import org.neo4j.driver.types.MapAccessor; import org.neo4j.driver.types.TypeSystem; @@ -37,6 +38,8 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import reactor.core.publisher.Flux; + /** * Base class for {@link RepositoryQuery} implementations for Neo4j. * @@ -67,16 +70,17 @@ public QueryMethod getQueryMethod() { @Override public final Object execute(Object[] parameters) { + boolean incrementLimit = queryMethod.incrementLimit(); Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor((Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters); ResultProcessor resultProcessor = queryMethod.getResultProcessor().withDynamicProjection(parameterAccessor); ReturnedType returnedType = resultProcessor.getReturnedType(); PreparedQuery preparedQuery = prepareQuery(returnedType.getReturnedType(), PropertyFilterSupport.getInputProperties(resultProcessor, factory, mappingContext), parameterAccessor, - null, getMappingFunction(resultProcessor)); + null, getMappingFunction(resultProcessor), incrementLimit ? l -> l + 1 : UnaryOperator.identity()); Object rawResult = new Neo4jQueryExecution.ReactiveQueryExecution(neo4jOperations).execute(preparedQuery, - queryMethod.isCollectionLikeQuery()); + queryMethod.asCollectionQuery()); Converter preparingConverter = OptionalUnwrappingConverter.INSTANCE; if (returnedType.isProjecting()) { @@ -87,10 +91,16 @@ public final Object execute(Object[] parameters) { (EntityInstanceWithSource) OptionalUnwrappingConverter.INSTANCE.convert(source)); } + if (queryMethod.isScrollQuery()) { + rawResult = ((Flux) rawResult).collectList().map(rawResultList -> + createWindow(resultProcessor, incrementLimit, parameterAccessor, rawResultList, preparedQuery.getQueryFragmentsAndParameters())); + } + return resultProcessor.processResult(rawResult, preparingConverter); } protected abstract PreparedQuery prepareQuery(Class returnedType, Collection includedProperties, Neo4jParameterAccessor parameterAccessor, - @Nullable Neo4jQueryType queryType, @Nullable Supplier> mappingFunction); + @Nullable Neo4jQueryType queryType, @Nullable Supplier> mappingFunction, + @Nullable UnaryOperator limitModifier); } diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/CypherAdapterUtils.java b/src/main/java/org/springframework/data/neo4j/repository/query/CypherAdapterUtils.java index dcbf733e7b..1fe16a82a2 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/CypherAdapterUtils.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/CypherAdapterUtils.java @@ -18,20 +18,29 @@ import static org.neo4j.cypherdsl.core.Cypher.property; import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; import org.apiguardian.api.API; +import org.neo4j.cypherdsl.core.Condition; +import org.neo4j.cypherdsl.core.Conditions; import org.neo4j.cypherdsl.core.Cypher; import org.neo4j.cypherdsl.core.Expression; import org.neo4j.cypherdsl.core.Functions; import org.neo4j.cypherdsl.core.SortItem; import org.neo4j.cypherdsl.core.StatementBuilder; import org.neo4j.cypherdsl.core.SymbolicName; +import org.neo4j.driver.Value; +import org.springframework.data.domain.KeysetScrollPosition; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.neo4j.core.mapping.Constants; -import org.springframework.data.neo4j.core.mapping.GraphPropertyDescription; +import org.springframework.data.neo4j.core.mapping.Neo4jPersistentEntity; +import org.springframework.data.neo4j.core.mapping.Neo4jPersistentProperty; import org.springframework.data.neo4j.core.mapping.NodeDescription; /** @@ -51,6 +60,7 @@ public final class CypherAdapterUtils { */ public static Function sortAdapterFor(NodeDescription nodeDescription) { return order -> { + String domainProperty = order.getProperty(); boolean propertyIsQualified = domainProperty.contains("."); SymbolicName root; @@ -61,12 +71,21 @@ public static Function sortAdapterFor(NodeDescription n root = Cypher.name(domainProperty.substring(0, indexOfSeparator)); domainProperty = domainProperty.substring(indexOfSeparator + 1); } - String graphProperty = nodeDescription.getGraphProperty(domainProperty) - .map(GraphPropertyDescription::getPropertyName).orElseThrow(() -> new IllegalStateException( - String.format("Cannot order by the unknown graph property: '%s'", order.getProperty()))); - Expression expression = property(root, graphProperty); - if (order.isIgnoreCase()) { - expression = Functions.toLower(expression); + + var optionalGraphProperty = nodeDescription.getGraphProperty(domainProperty); + if (optionalGraphProperty.isEmpty()) { + throw new IllegalStateException(String.format("Cannot order by the unknown graph property: '%s'", order.getProperty())); + } + var graphProperty = optionalGraphProperty.get(); + Expression expression; + if (graphProperty.isInternalIdProperty()) { + // Not using the id expression here, as the root will be referring to the constructed map being returned. + expression = property(root, Constants.NAME_OF_INTERNAL_ID); + } else { + expression = property(root, graphProperty.getPropertyName()); + if (order.isIgnoreCase()) { + expression = Functions.toLower(expression); + } } SortItem sortItem = Cypher.sort(expression); @@ -78,6 +97,72 @@ public static Function sortAdapterFor(NodeDescription n }; } + public static Condition combineKeysetIntoCondition(Neo4jPersistentEntity entity, KeysetScrollPosition scrollPosition, Sort sort) { + + var incomingKeys = scrollPosition.getKeys(); + var orderedKeys = new LinkedHashMap(); + + record PropertyAndOrder(Neo4jPersistentProperty property, Sort.Order order) { + } + var propertyAndDirection = new HashMap(); + + sort.forEach(order -> { + var property = entity.getRequiredPersistentProperty(order.getProperty()); + var propertyName = property.getPropertyName(); + propertyAndDirection.put(propertyName, new PropertyAndOrder(property, order)); + + if (incomingKeys.containsKey(propertyName)) { + orderedKeys.put(propertyName, incomingKeys.get(propertyName)); + } + }); + if (incomingKeys.containsKey(Constants.NAME_OF_ADDITIONAL_SORT)) { + orderedKeys.put(Constants.NAME_OF_ADDITIONAL_SORT, incomingKeys.get(Constants.NAME_OF_ADDITIONAL_SORT)); + } + + var root = Constants.NAME_OF_TYPED_ROOT_NODE.apply(entity); + + var resultingCondition = Conditions.noCondition(); + // This is the next equality pair if previous sort key was equal + var nextEquals = Conditions.noCondition(); + // This is the condition for when all the sort orderedKeys are equal, and we must filter via id + var allEqualsWithArtificialSort = Conditions.noCondition(); + + for (Map.Entry entry : orderedKeys.entrySet()) { + + var k = entry.getKey(); + var v = entry.getValue(); + if (v == null || (v instanceof Value value && value.isNull())) { + throw new IllegalStateException("Cannot resume from KeysetScrollPosition. Offending key: '%s' is 'null'".formatted(k)); + } + var parameter = Cypher.anonParameter(v); + + Expression expression; + + var scrollDirection = scrollPosition.getDirection(); + if (Constants.NAME_OF_ADDITIONAL_SORT.equals(k)) { + expression = entity.getIdExpression(); + var comparatorFunction = getComparatorFunction(scrollDirection == KeysetScrollPosition.Direction.Forward ? Sort.Direction.ASC : Sort.Direction.DESC, scrollDirection); + allEqualsWithArtificialSort = allEqualsWithArtificialSort.and(comparatorFunction.apply(expression, parameter)); + } else { + var p = propertyAndDirection.get(k); + expression = p.property.isIdProperty() ? entity.getIdExpression() : root.property(k); + + var comparatorFunction = getComparatorFunction(p.order.getDirection(), scrollDirection); + resultingCondition = resultingCondition.or(nextEquals.and(comparatorFunction.apply(expression, parameter))); + nextEquals = expression.eq(parameter); + allEqualsWithArtificialSort = allEqualsWithArtificialSort.and(nextEquals); + } + } + return resultingCondition.or(allEqualsWithArtificialSort); + } + + private static BiFunction getComparatorFunction(Sort.Direction sortDirection, KeysetScrollPosition.Direction scrollDirection) { + if (scrollDirection == KeysetScrollPosition.Direction.Backward) { + return sortDirection.isAscending() ? Expression::lte : Expression::gte; + } + return sortDirection.isAscending() ? Expression::gt : Expression::lt; + } + /** * Converts a Spring Data sort to an equivalent list of {@link SortItem sort items}. * diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java b/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java index 57322d79fc..21ce7674b2 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java @@ -45,8 +45,11 @@ import org.neo4j.cypherdsl.core.RelationshipPattern; import org.neo4j.cypherdsl.core.SortItem; import org.neo4j.driver.types.Point; +import org.springframework.data.domain.KeysetScrollPosition; +import org.springframework.data.domain.OffsetScrollPosition; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; @@ -64,6 +67,7 @@ import org.springframework.data.neo4j.core.mapping.PropertyFilter; import org.springframework.data.neo4j.core.mapping.RelationshipDescription; import org.springframework.data.neo4j.core.schema.TargetNode; +import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.parser.AbstractQueryCreator; import org.springframework.data.repository.query.parser.Part; import org.springframework.data.repository.query.parser.PartTree; @@ -82,6 +86,7 @@ final class CypherQueryCreator extends AbstractQueryCreator { private final Neo4jMappingContext mappingContext; + private final QueryMethod queryMethod; private final Class domainType; private final NodeDescription nodeDescription; @@ -99,6 +104,8 @@ final class CypherQueryCreator extends AbstractQueryCreator propertyPathWrappers; + private final boolean keysetRequiresSort; + /** * Can be used to modify the limit of a paged or sliced query. */ private final UnaryOperator limitModifier; - CypherQueryCreator(Neo4jMappingContext mappingContext, Class domainType, Neo4jQueryType queryType, PartTree tree, + CypherQueryCreator(Neo4jMappingContext mappingContext, QueryMethod queryMethod, Class domainType, Neo4jQueryType queryType, PartTree tree, Neo4jParameterAccessor actualParameters, Collection includedProperties, BiFunction, Object> parameterConversion, UnaryOperator limitModifier) { super(tree, actualParameters); this.mappingContext = mappingContext; + this.queryMethod = queryMethod; this.domainType = domainType; this.nodeDescription = this.mappingContext.getRequiredNodeDescription(this.domainType); @@ -139,6 +149,7 @@ final class CypherQueryCreator extends AbstractQueryCreator p.nameOrIndex, p -> parameterConversion.apply(p.value, p.conversionOverride))); QueryFragments queryFragments = createQueryFragments(condition, sort); - return new QueryFragmentsAndParameters(nodeDescription, queryFragments, convertedParameters); + + var theSort = pagingParameter.getSort().and(sort); + if (keysetRequiresSort && theSort.isUnsorted()) { + throw new UnsupportedOperationException("Unsorted keyset based scrolling is not supported."); + } + return new QueryFragmentsAndParameters(nodeDescription, queryFragments, convertedParameters, theSort); } @NonNull @@ -280,15 +297,12 @@ private QueryFragments createQueryFragments(@Nullable Condition condition, Sort } } - // closing action: add the condition and path match - queryFragments.setCondition(conditionFragment); - if (!relationshipChain.isEmpty()) { queryFragments.setMatchOn(relationshipChain); } else { queryFragments.addMatchOn(startNode); } - /// end of initial filter query creation + // end of initial filter query creation if (queryType == Neo4jQueryType.COUNT) { queryFragments.setReturnExpression(Functions.count(Cypher.asterisk()), true); @@ -298,20 +312,38 @@ private QueryFragments createQueryFragments(@Nullable Condition condition, Sort queryFragments.setDeleteExpression(Constants.NAME_OF_TYPED_ROOT_NODE.apply(nodeDescription)); queryFragments.setReturnExpression(Functions.count(Constants.NAME_OF_TYPED_ROOT_NODE.apply(nodeDescription)), true); } else { + + var theSort = pagingParameter.getSort().and(sort); + + if (pagingParameter.isUnpaged() && scrollPosition == null && maxResults != null) { + queryFragments.setLimit(limitModifier.apply(maxResults.intValue())); + } else if (scrollPosition instanceof KeysetScrollPosition keysetScrollPosition) { + + Neo4jPersistentEntity entity = (Neo4jPersistentEntity) nodeDescription; + // Enforce sorting by something that is hopefully stable comparable (looking at Neo4j's id() with tears in my eyes). + theSort = theSort.and(Sort.by(entity.getRequiredIdProperty().getName()).ascending()); + + queryFragments.setLimit(limitModifier.apply(maxResults.intValue())); + if (!keysetScrollPosition.isInitial()) { + conditionFragment = conditionFragment.and(CypherAdapterUtils.combineKeysetIntoCondition(entity, keysetScrollPosition, theSort)); + } + + queryFragments.setRequiresReverseSort(keysetScrollPosition.getDirection() == KeysetScrollPosition.Direction.Backward); + } else if (scrollPosition instanceof OffsetScrollPosition offsetScrollPosition) { + queryFragments.setSkip(offsetScrollPosition.getOffset()); + queryFragments.setLimit(limitModifier.apply(pagingParameter.isUnpaged() ? maxResults.intValue() : pagingParameter.getPageSize())); + } + queryFragments.setReturnBasedOn(nodeDescription, includedProperties, isDistinct); queryFragments.setOrderBy(Stream .concat(sortItems.stream(), - pagingParameter.getSort().and(sort).stream().map(CypherAdapterUtils.sortAdapterFor(nodeDescription))) + theSort.stream().map(CypherAdapterUtils.sortAdapterFor(nodeDescription))) .collect(Collectors.toList())); - if (pagingParameter.isUnpaged()) { - queryFragments.setLimit(maxResults); - } else { - long skip = pagingParameter.getOffset(); - int pageSize = pagingParameter.getPageSize(); - queryFragments.setSkip(skip); - queryFragments.setLimit(limitModifier.apply(pageSize)); - } } + + // closing action: add the condition and path match + queryFragments.setCondition(conditionFragment); + return queryFragments; } diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQueryMethod.java b/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQueryMethod.java index 6d01cd5287..39b519e94a 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQueryMethod.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQueryMethod.java @@ -30,6 +30,7 @@ import org.springframework.data.util.TypeInformation; import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; +import org.springframework.util.StringUtils; /** * Neo4j specific implementation of {@link QueryMethod}. It contains a custom implementation of {@link Parameter} which @@ -158,4 +159,12 @@ public String getNameOrIndex() { return this.getName().orElseGet(() -> Integer.toString(this.getIndex())); } } + + boolean incrementLimit() { + return (this.isSliceQuery() && this.getQueryAnnotation().map(Query::countQuery).filter(StringUtils::hasText).isEmpty()) || this.isScrollQuery(); + } + + boolean asCollectionQuery() { + return this.isCollectionLikeQuery() || this.isPageQuery() || this.isSliceQuery() || this.isScrollQuery(); + } } diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQuerySupport.java b/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQuerySupport.java index 664a18a0a2..600748921d 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQuerySupport.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/Neo4jQuerySupport.java @@ -23,6 +23,8 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.BiFunction; @@ -30,18 +32,23 @@ import java.util.stream.Collectors; import org.apache.commons.logging.LogFactory; +import org.neo4j.driver.Value; import org.neo4j.driver.Values; import org.neo4j.driver.types.MapAccessor; import org.neo4j.driver.types.TypeSystem; import org.springframework.core.log.LogAccessor; import org.springframework.data.convert.EntityWriter; +import org.springframework.data.domain.KeysetScrollPosition; +import org.springframework.data.domain.OffsetScrollPosition; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Window; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Metrics; import org.springframework.data.neo4j.core.TemplateSupport; import org.springframework.data.neo4j.core.convert.Neo4jPersistentPropertyConverter; +import org.springframework.data.neo4j.core.mapping.Constants; import org.springframework.data.neo4j.core.mapping.CypherGenerator; import org.springframework.data.neo4j.core.mapping.EntityInstanceWithSource; import org.springframework.data.neo4j.core.mapping.Neo4jMappingContext; @@ -267,6 +274,49 @@ void logWarningsIfNecessary(QueryContext queryContext, Neo4jParameterAccessor pa } } + final Window createWindow(ResultProcessor resultProcessor, boolean incrementLimit, Neo4jParameterAccessor parameterAccessor, List rawResult, QueryFragmentsAndParameters orderBy) { + + var domainType = resultProcessor.getReturnedType().getDomainType(); + var neo4jPersistentEntity = mappingContext.getPersistentEntity(domainType); + var limit = orderBy.getQueryFragments().getLimit().intValue() - (incrementLimit ? 1 : 0); + var conversionService = mappingContext.getConversionService(); + var scrollPosition = parameterAccessor.getScrollPosition(); + + var scrollDirection = scrollPosition instanceof KeysetScrollPosition keysetScrollPosition ? keysetScrollPosition.getDirection() : KeysetScrollPosition.Direction.Forward; + if (scrollDirection == KeysetScrollPosition.Direction.Backward) { + Collections.reverse(rawResult); + } + + return Window.from(getSubList(rawResult, limit, scrollDirection), v -> { + if (scrollPosition instanceof OffsetScrollPosition offsetScrollPosition) { + return OffsetScrollPosition.of(offsetScrollPosition.getOffset() + v + limit); + } else { + var accessor = neo4jPersistentEntity.getPropertyAccessor(rawResult.get(v)); + var keys = new LinkedHashMap(); + orderBy.getSort().forEach(o -> { + // Storing the graph property name here + var persistentProperty = neo4jPersistentEntity.getRequiredPersistentProperty(o.getProperty()); + keys.put(persistentProperty.getPropertyName(), conversionService.convert(accessor.getProperty(persistentProperty), Value.class)); + }); + keys.put(Constants.NAME_OF_ADDITIONAL_SORT, conversionService.convert(accessor.getProperty(neo4jPersistentEntity.getRequiredIdProperty()), Value.class)); + return KeysetScrollPosition.of(keys); + } + }, hasMoreElements(rawResult, limit)); + } + + private static boolean hasMoreElements(List result, int limit) { + return !result.isEmpty() && result.size() > limit; + } + + private static List getSubList(List result, int limit, KeysetScrollPosition.Direction scrollDirection) { + + if (limit > 0 && result.size() > limit) { + return scrollDirection == KeysetScrollPosition.Direction.Forward ? result.subList(0, limit) : result.subList(1, limit + 1); + } + + return result; + } + private Map convertRange(Range range) { Map map = new HashMap<>(); range.getLowerBound().getValue().map(this::convertParameter).ifPresent(v -> map.put("lb", v)); diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/PartTreeNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/PartTreeNeo4jQuery.java index 296f0da9c8..d2bf380882 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/PartTreeNeo4jQuery.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/PartTreeNeo4jQuery.java @@ -65,7 +65,7 @@ protected PreparedQuery prepareQuery(Class returnedType Neo4jParameterAccessor parameterAccessor, @Nullable Neo4jQueryType queryType, @Nullable Supplier> mappingFunction, UnaryOperator limitModifier) { - CypherQueryCreator queryCreator = new CypherQueryCreator(mappingContext, getDomainType(queryMethod), + CypherQueryCreator queryCreator = new CypherQueryCreator(mappingContext, queryMethod, getDomainType(queryMethod), Optional.ofNullable(queryType).orElseGet(() -> Neo4jQueryType.fromPartTree(tree)), tree, parameterAccessor, includedProperties, this::convertParameter, limitModifier); diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragments.java b/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragments.java index 3ca24474cf..ae90c2b104 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragments.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragments.java @@ -20,6 +20,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import org.apiguardian.api.API; import org.neo4j.cypherdsl.core.Condition; @@ -56,6 +57,10 @@ public final class QueryFragments { private boolean scalarValueReturn = false; private boolean renderConstantsAsParameters = false; private Expression deleteExpression; + /** + * This flag becomes {@literal true} for backward scrolling keyset pagination. Any {@code AbstractNeo4jQuery} will in turn reverse the result list. + */ + private boolean requiresReverseSort = false; public void addMatchOn(PatternElement match) { this.matchOn.add(match); @@ -115,6 +120,14 @@ public boolean isScalarValueReturn() { return scalarValueReturn; } + public boolean requiresReverseSort() { + return requiresReverseSort; + } + + public void setRequiresReverseSort(boolean requiresReverseSort) { + this.requiresReverseSort = requiresReverseSort; + } + public void setRenderConstantsAsParameters(boolean renderConstantsAsParameters) { this.renderConstantsAsParameters = renderConstantsAsParameters; } @@ -162,7 +175,33 @@ private boolean isDistinctReturn() { } public Collection getOrderBy() { - return orderBy != null ? orderBy : Collections.emptySet(); + + if (orderBy == null) { + return List.of(); + } else if (!requiresReverseSort) { + return orderBy; + } else { + return orderBy.stream().map(QueryFragments::reverse).toList(); + } + } + + // Yeah, would be kinda nice having a simple method in Cypher-DSL ;) + private static SortItem reverse(SortItem sortItem) { + + var sortedExpression = new AtomicReference(); + var sortDirection = new AtomicReference(); + + sortItem.accept(segment -> { + if (segment instanceof SortItem.Direction direction) { + sortDirection.compareAndSet(null, direction == SortItem.Direction.UNDEFINED || direction == SortItem.Direction.ASC ? SortItem.Direction.DESC : SortItem.Direction.ASC); + } else if (segment instanceof Expression expression) { + sortedExpression.compareAndSet(null, expression); + } + }); + + // Default might not explicitly set. + sortDirection.compareAndSet(null, SortItem.Direction.DESC); + return Cypher.sort(sortedExpression.get(), sortDirection.get()); } public Number getLimit() { diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragmentsAndParameters.java b/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragmentsAndParameters.java index 78a7b32d3e..771f0165e5 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragmentsAndParameters.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/QueryFragmentsAndParameters.java @@ -53,23 +53,26 @@ public final class QueryFragmentsAndParameters { private NodeDescription nodeDescription; private final QueryFragments queryFragments; private final String cypherQuery; + private final Sort sort; public QueryFragmentsAndParameters(NodeDescription nodeDescription, QueryFragments queryFragments, - @Nullable Map parameters) { + @Nullable Map parameters, @Nullable Sort sort) { this.nodeDescription = nodeDescription; this.queryFragments = queryFragments; this.parameters = parameters; this.cypherQuery = null; + this.sort = sort == null ? Sort.unsorted() : sort; } public QueryFragmentsAndParameters(String cypherQuery) { this(cypherQuery, null); } - public QueryFragmentsAndParameters(String cypherQuery, Map parameters) { + public QueryFragmentsAndParameters(String cypherQuery, @Nullable Map parameters) { this.cypherQuery = cypherQuery; this.queryFragments = new QueryFragments(); this.parameters = parameters; + this.sort = Sort.unsorted(); } public Map getParameters() { @@ -92,6 +95,10 @@ public void setParameters(Map newParameters) { this.parameters = newParameters; } + public Sort getSort() { + return sort; + } + /* * Convenience methods that are used by the (Reactive)Neo4jTemplate */ @@ -110,7 +117,7 @@ public static QueryFragmentsAndParameters forFindById(Neo4jPersistentEntity e queryFragments.addMatchOn(container); queryFragments.setCondition(condition); queryFragments.setReturnExpressions(cypherGenerator.createReturnStatementForMatch(entityMetaData)); - return new QueryFragmentsAndParameters(entityMetaData, queryFragments, parameters); + return new QueryFragmentsAndParameters(entityMetaData, queryFragments, parameters, null); } public static QueryFragmentsAndParameters forFindByAllId(Neo4jPersistentEntity entityMetaData, Object idValues) { @@ -133,7 +140,7 @@ public static QueryFragmentsAndParameters forFindByAllId(Neo4jPersistentEntity entityMetaData) { @@ -141,7 +148,7 @@ public static QueryFragmentsAndParameters forFindAll(Neo4jPersistentEntity en queryFragments.addMatchOn(cypherGenerator.createRootNode(entityMetaData)); queryFragments.setCondition(Conditions.noCondition()); queryFragments.setReturnExpressions(cypherGenerator.createReturnStatementForMatch(entityMetaData)); - return new QueryFragmentsAndParameters(entityMetaData, queryFragments, Collections.emptyMap()); + return new QueryFragmentsAndParameters(entityMetaData, queryFragments, Collections.emptyMap(), null); } public static QueryFragmentsAndParameters forExistsById(Neo4jPersistentEntity entityMetaData, Object idValues) { @@ -159,7 +166,7 @@ public static QueryFragmentsAndParameters forExistsById(Neo4jPersistentEntity queryFragments.addMatchOn(container); queryFragments.setCondition(condition); queryFragments.setReturnExpressions(cypherGenerator.createReturnStatementForExists(entityMetaData)); - return new QueryFragmentsAndParameters(entityMetaData, queryFragments, parameters); + return new QueryFragmentsAndParameters(entityMetaData, queryFragments, parameters, null); } /* @@ -234,7 +241,7 @@ static QueryFragmentsAndParameters forCondition(Neo4jPersistentEntity entityM queryFragments.setOrderBy(sortItems); } - return new QueryFragmentsAndParameters(entityMetaData, queryFragments, Collections.emptyMap()); + return new QueryFragmentsAndParameters(entityMetaData, queryFragments, Collections.emptyMap(), null); } private static void adaptPageable( @@ -286,7 +293,7 @@ private static QueryFragmentsAndParameters getQueryFragmentsAndParameters( queryFragments.setOrderBy(CypherAdapterUtils.toSortItems(entityMetaData, sort)); } - return new QueryFragmentsAndParameters(entityMetaData, queryFragments, parameters); + return new QueryFragmentsAndParameters(entityMetaData, queryFragments, parameters, sort); } } diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveCypherdslBasedQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveCypherdslBasedQuery.java index e75eb03786..3da64d9dda 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveCypherdslBasedQuery.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveCypherdslBasedQuery.java @@ -19,6 +19,7 @@ import java.util.Map; import java.util.function.BiFunction; import java.util.function.Supplier; +import java.util.function.UnaryOperator; import org.neo4j.cypherdsl.core.Statement; import org.neo4j.driver.types.MapAccessor; @@ -28,6 +29,7 @@ import org.springframework.data.neo4j.core.mapping.Neo4jMappingContext; import org.springframework.data.neo4j.core.mapping.PropertyFilter; import org.springframework.data.projection.ProjectionFactory; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -53,8 +55,9 @@ private ReactiveCypherdslBasedQuery(ReactiveNeo4jOperations neo4jOperations, @Override protected PreparedQuery prepareQuery(Class returnedType, Collection includedProperties, - Neo4jParameterAccessor parameterAccessor, Neo4jQueryType queryType, - Supplier> mappingFunction) { + @Nullable Neo4jParameterAccessor parameterAccessor, Neo4jQueryType queryType, + @Nullable Supplier> mappingFunction, + @Nullable UnaryOperator limitModifier) { Object[] parameters = parameterAccessor.getValues(); diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/ReactivePartTreeNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/ReactivePartTreeNeo4jQuery.java index 4a40814218..074eccf084 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/ReactivePartTreeNeo4jQuery.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/ReactivePartTreeNeo4jQuery.java @@ -63,11 +63,11 @@ private ReactivePartTreeNeo4jQuery(ReactiveNeo4jOperations neo4jOperations, Neo4 @Override protected PreparedQuery prepareQuery(Class returnedType, Collection includedProperties, Neo4jParameterAccessor parameterAccessor, @Nullable Neo4jQueryType queryType, - @Nullable Supplier> mappingFunction) { + @Nullable Supplier> mappingFunction, @Nullable UnaryOperator limitModifier) { - CypherQueryCreator queryCreator = new CypherQueryCreator(mappingContext, getDomainType(queryMethod), + CypherQueryCreator queryCreator = new CypherQueryCreator(mappingContext, queryMethod, getDomainType(queryMethod), Optional.ofNullable(queryType).orElseGet(() -> Neo4jQueryType.fromPartTree(tree)), tree, parameterAccessor, - includedProperties, this::convertParameter, UnaryOperator.identity()); + includedProperties, this::convertParameter, limitModifier); QueryFragmentsAndParameters queryAndParameters = queryCreator.createQuery(); diff --git a/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveStringBasedNeo4jQuery.java b/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveStringBasedNeo4jQuery.java index 46fe704afb..3411826421 100644 --- a/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveStringBasedNeo4jQuery.java +++ b/src/main/java/org/springframework/data/neo4j/repository/query/ReactiveStringBasedNeo4jQuery.java @@ -21,6 +21,7 @@ import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Supplier; +import java.util.function.UnaryOperator; import org.neo4j.driver.types.MapAccessor; import org.neo4j.driver.types.TypeSystem; @@ -128,7 +129,7 @@ private ReactiveStringBasedNeo4jQuery(ReactiveNeo4jOperations neo4jOperations, N @Override protected PreparedQuery prepareQuery(Class returnedType, Collection includedProperties, Neo4jParameterAccessor parameterAccessor, @Nullable Neo4jQueryType queryType, - @Nullable Supplier> mappingFunction) { + @Nullable Supplier> mappingFunction, @Nullable UnaryOperator limitModifier) { Map boundParameters = bindParameters(parameterAccessor); QueryContext queryContext = new QueryContext( diff --git a/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java b/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java index 0f2ef1e409..90421321ca 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java +++ b/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java @@ -51,6 +51,9 @@ import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.neo4j.driver.Driver; import org.neo4j.driver.Record; import org.neo4j.driver.Session; @@ -68,21 +71,23 @@ import org.springframework.data.domain.Example; import org.springframework.data.domain.ExampleMatcher; import org.springframework.data.domain.ExampleMatcher.StringMatcher; +import org.springframework.data.domain.KeysetScrollPosition; +import org.springframework.data.domain.OffsetScrollPosition; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; +import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.WindowIterator; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Polygon; import org.springframework.data.mapping.MappingException; -import org.springframework.data.neo4j.core.mapping.IdentitySupport; -import org.springframework.data.neo4j.test.Neo4jImperativeTestConfiguration; import org.springframework.data.neo4j.core.DatabaseSelection; import org.springframework.data.neo4j.core.DatabaseSelectionProvider; import org.springframework.data.neo4j.core.Neo4jClient; @@ -90,6 +95,7 @@ import org.springframework.data.neo4j.core.UserSelection; import org.springframework.data.neo4j.core.UserSelectionProvider; import org.springframework.data.neo4j.core.convert.Neo4jConversions; +import org.springframework.data.neo4j.core.mapping.IdentitySupport; import org.springframework.data.neo4j.core.mapping.Neo4jMappingContext; import org.springframework.data.neo4j.core.transaction.Neo4jBookmarkManager; import org.springframework.data.neo4j.core.transaction.Neo4jTransactionManager; @@ -150,6 +156,7 @@ import org.springframework.data.neo4j.repository.query.Query; import org.springframework.data.neo4j.test.BookmarkCapture; import org.springframework.data.neo4j.test.Neo4jExtension; +import org.springframework.data.neo4j.test.Neo4jImperativeTestConfiguration; import org.springframework.data.neo4j.test.ServerVersion; import org.springframework.data.neo4j.types.CartesianPoint2d; import org.springframework.data.neo4j.types.GeographicPoint2d; @@ -258,6 +265,27 @@ void findAll(@Autowired PersonRepository repository) { assertThat(people).extracting("name").containsExactlyInAnyOrder(TEST_PERSON1_NAME, TEST_PERSON2_NAME); } + static Stream basicScrollSupportFor(@Autowired PersonRepository repository) { + return Stream.of(Arguments.of(repository, KeysetScrollPosition.initial()), Arguments.of(repository, OffsetScrollPosition.initial())); + } + + @ParameterizedTest(name = "basicScrollSupportFor {1}") + @MethodSource + void basicScrollSupportFor(PersonRepository repository, ScrollPosition initialPosition) { + + var it = WindowIterator.of(repository::findTop1ByOrderByName) + .startingAt(initialPosition); + var content = new ArrayList(); + while (it.hasNext()) { + var next = it.next(); + content.add(next); + } + assertThat(content) + .hasSize(2) + .extracting(PersonWithAllConstructor::getName) + .containsExactly("Test", "Test2"); + } + @Test void findAllWithoutResultDoesNotThrowAnException(@Autowired PersonRepository repository) { diff --git a/src/test/java/org/springframework/data/neo4j/integration/imperative/ScrollingIT.java b/src/test/java/org/springframework/data/neo4j/integration/imperative/ScrollingIT.java new file mode 100644 index 0000000000..7d8edcbfd6 --- /dev/null +++ b/src/test/java/org/springframework/data/neo4j/integration/imperative/ScrollingIT.java @@ -0,0 +1,192 @@ +/* + * Copyright 2011-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.integration.imperative; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.Map; +import java.util.function.Function; + +import org.assertj.core.data.Index; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Values; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.domain.KeysetScrollPosition; +import org.springframework.data.domain.WindowIterator; +import org.springframework.data.neo4j.core.DatabaseSelectionProvider; +import org.springframework.data.neo4j.core.mapping.Constants; +import org.springframework.data.neo4j.core.transaction.Neo4jBookmarkManager; +import org.springframework.data.neo4j.core.transaction.Neo4jTransactionManager; +import org.springframework.data.neo4j.integration.imperative.repositories.ScrollingRepository; +import org.springframework.data.neo4j.integration.shared.common.ScrollingEntity; +import org.springframework.data.neo4j.repository.config.EnableNeo4jRepositories; +import org.springframework.data.neo4j.test.BookmarkCapture; +import org.springframework.data.neo4j.test.Neo4jExtension; +import org.springframework.data.neo4j.test.Neo4jImperativeTestConfiguration; +import org.springframework.data.neo4j.test.Neo4jIntegrationTest; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.EnableTransactionManagement; + +/** + * @author Michael J. Simons + */ +@Neo4jIntegrationTest +class ScrollingIT { + + @SuppressWarnings("unused") + private static Neo4jExtension.Neo4jConnectionSupport neo4jConnectionSupport; + + @BeforeAll + static void setupTestData(@Autowired Driver driver, @Autowired BookmarkCapture bookmarkCapture) { + try ( + var session = driver.session(bookmarkCapture.createSessionConfig()); + var transaction = session.beginTransaction() + ) { + ScrollingEntity.createTestData(transaction); + transaction.commit(); + bookmarkCapture.seedWith(session.lastBookmarks()); + } + } + + @Test + void oneColumnSortNoScroll(@Autowired ScrollingRepository repository) { + + var topN = repository.findTop4ByOrderByB(); + assertThat(topN) + .hasSize(4) + .extracting(ScrollingEntity::getA) + .containsExactly("A0", "B0", "C0", "D0"); + } + + @Test + void forwardWithDuplicatesManualIteration(@Autowired ScrollingRepository repository) { + + var duplicates = repository.findAllByAOrderById("D0"); + assertThat(duplicates).hasSize(2); + + var window = repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, KeysetScrollPosition.initial()); + assertThat(window.hasNext()).isTrue(); + assertThat(window) + .hasSize(4) + .extracting(Function.identity()) + .satisfies(e -> assertThat(e.getId()).isEqualTo(duplicates.get(0).getId()), Index.atIndex(3)) + .extracting(ScrollingEntity::getA) + .containsExactly("A0", "B0", "C0", "D0"); + + window = repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, window.positionAt(window.size() - 1)); + assertThat(window.hasNext()).isTrue(); + assertThat(window) + .hasSize(4) + .extracting(Function.identity()) + .satisfies(e -> assertThat(e.getId()).isEqualTo(duplicates.get(1).getId()), Index.atIndex(0)) + .extracting(ScrollingEntity::getA) + .containsExactly("D0", "E0", "F0", "G0"); + + window = repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, window.positionAt(window.size() - 1)); + assertThat(window.isLast()).isTrue(); + assertThat(window).extracting(ScrollingEntity::getA) + .containsExactly("H0", "I0"); + } + + @Test + void forwardWithDuplicatesIteratorIteration(@Autowired ScrollingRepository repository) { + + var it = WindowIterator.of(pos -> repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, pos)) + .startingAt(KeysetScrollPosition.initial()); + var content = new ArrayList(); + while (it.hasNext()) { + var next = it.next(); + content.add(next); + } + + assertThat(content).hasSize(10); + assertThat(content.stream().map(ScrollingEntity::getId) + .distinct().toList()).hasSize(10); + } + + @Test + void backwardWithDuplicatesManualIteration(@Autowired ScrollingRepository repository) { + + // Recreate the last position + var last = repository.findFirstByA("I0"); + var keys = Map.of( + "foobar", Values.value(last.getA()), + "b", Values.value(last.getB()), + Constants.NAME_OF_ADDITIONAL_SORT, Values.value(last.getId().toString()) + ); + + var duplicates = repository.findAllByAOrderById("D0"); + assertThat(duplicates).hasSize(2); + + var window = repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, KeysetScrollPosition.of(keys, KeysetScrollPosition.Direction.Backward)); + assertThat(window.hasNext()).isTrue(); + assertThat(window) + .hasSize(4) + .extracting(ScrollingEntity::getA) + .containsExactly("F0", "G0", "H0", "I0"); + + var pos = ((KeysetScrollPosition) window.positionAt(0)); + pos = KeysetScrollPosition.of(pos.getKeys(), KeysetScrollPosition.Direction.Backward); + window = repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, pos); + assertThat(window.hasNext()).isTrue(); + assertThat(window) + .hasSize(4) + .extracting(Function.identity()) + .extracting(ScrollingEntity::getA) + .containsExactly("C0", "D0", "D0", "E0"); + + pos = ((KeysetScrollPosition) window.positionAt(0)); + pos = KeysetScrollPosition.of(pos.getKeys(), KeysetScrollPosition.Direction.Backward); + window = repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, pos); + assertThat(window.isLast()).isTrue(); + assertThat(window).extracting(ScrollingEntity::getA) + .containsExactly("A0", "B0"); + } + + @Configuration + @EnableNeo4jRepositories + @EnableTransactionManagement + static class Config extends Neo4jImperativeTestConfiguration { + + @Bean + public Driver driver() { + return neo4jConnectionSupport.getDriver(); + } + + @Bean + public BookmarkCapture bookmarkCapture() { + return new BookmarkCapture(); + } + + @Override + public PlatformTransactionManager transactionManager(Driver driver, DatabaseSelectionProvider databaseNameProvider) { + + BookmarkCapture bookmarkCapture = bookmarkCapture(); + return new Neo4jTransactionManager(driver, databaseNameProvider, Neo4jBookmarkManager.create(bookmarkCapture)); + } + + @Override + public boolean isCypher5Compatible() { + return neo4jConnectionSupport.isCypher5SyntaxCompatible(); + } + + } +} diff --git a/src/test/java/org/springframework/data/neo4j/integration/imperative/repositories/PersonRepository.java b/src/test/java/org/springframework/data/neo4j/integration/imperative/repositories/PersonRepository.java index b083c3c3c5..1ec611cd23 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/imperative/repositories/PersonRepository.java +++ b/src/test/java/org/springframework/data/neo4j/integration/imperative/repositories/PersonRepository.java @@ -28,8 +28,10 @@ import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Window; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; @@ -113,6 +115,8 @@ Optional getOptionalPersonViaNamedQuery(@Param("part1" Slice findSliceByNameOrName(String aName, String anotherName, Pageable pageable); + Window findTop1ByOrderByName(ScrollPosition scrollPosition); + @Query("MATCH (n:PersonWithAllConstructor) WHERE n.name = $aName OR n.name = $anotherName RETURN n ORDER BY n.name DESC SKIP $skip LIMIT $limit") Slice findSliceByCustomQueryWithoutCount(@Param("aName") String aName, @Param("anotherName") String anotherName, Pageable pageable); diff --git a/src/test/java/org/springframework/data/neo4j/integration/imperative/repositories/ScrollingRepository.java b/src/test/java/org/springframework/data/neo4j/integration/imperative/repositories/ScrollingRepository.java new file mode 100644 index 0000000000..707a6c8f4a --- /dev/null +++ b/src/test/java/org/springframework/data/neo4j/integration/imperative/repositories/ScrollingRepository.java @@ -0,0 +1,39 @@ +/* + * Copyright 2011-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.integration.imperative.repositories; + +import java.util.List; +import java.util.UUID; + +import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Window; +import org.springframework.data.neo4j.integration.shared.common.ScrollingEntity; +import org.springframework.data.neo4j.repository.Neo4jRepository; + +/** + * @author Michael J. Simons + */ +public interface ScrollingRepository extends Neo4jRepository { + + List findTop4ByOrderByB(); + + Window findTop4By(Sort sort, ScrollPosition position); + + ScrollingEntity findFirstByA(String a); + + List findAllByAOrderById(String a); +} diff --git a/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveScrollingIT.java b/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveScrollingIT.java new file mode 100644 index 0000000000..b4d506ef01 --- /dev/null +++ b/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveScrollingIT.java @@ -0,0 +1,210 @@ +/* + * Copyright 2011-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.integration.reactive; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import org.assertj.core.data.Index; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Values; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.domain.KeysetScrollPosition; +import org.springframework.data.domain.Window; +import org.springframework.data.neo4j.core.ReactiveDatabaseSelectionProvider; +import org.springframework.data.neo4j.core.mapping.Constants; +import org.springframework.data.neo4j.core.transaction.Neo4jBookmarkManager; +import org.springframework.data.neo4j.core.transaction.ReactiveNeo4jTransactionManager; +import org.springframework.data.neo4j.integration.reactive.repositories.ReactiveScrollingRepository; +import org.springframework.data.neo4j.integration.shared.common.ScrollingEntity; +import org.springframework.data.neo4j.repository.config.EnableNeo4jRepositories; +import org.springframework.data.neo4j.repository.config.EnableReactiveNeo4jRepositories; +import org.springframework.data.neo4j.test.BookmarkCapture; +import org.springframework.data.neo4j.test.Neo4jExtension; +import org.springframework.data.neo4j.test.Neo4jIntegrationTest; +import org.springframework.data.neo4j.test.Neo4jReactiveTestConfiguration; +import org.springframework.transaction.ReactiveTransactionManager; + +import reactor.test.StepVerifier; + +/** + * @author Michael J. Simons + */ +@Neo4jIntegrationTest +class ReactiveScrollingIT { + + @SuppressWarnings("unused") + private static Neo4jExtension.Neo4jConnectionSupport neo4jConnectionSupport; + + @BeforeAll + static void setupTestData(@Autowired Driver driver, @Autowired BookmarkCapture bookmarkCapture) { + try ( + var session = driver.session(bookmarkCapture.createSessionConfig()); + var transaction = session.beginTransaction() + ) { + ScrollingEntity.createTestData(transaction); + transaction.commit(); + bookmarkCapture.seedWith(session.lastBookmarks()); + } + } + + @Test + void oneColumnSortNoScroll(@Autowired ReactiveScrollingRepository repository) { + + repository.findTop4ByOrderByB() + .map(ScrollingEntity::getA) + .as(StepVerifier::create) + .expectNext("A0", "B0", "C0", "D0"); + } + + @Test + void forwardWithDuplicatesManualIteration(@Autowired ReactiveScrollingRepository repository) { + + var duplicates = new ArrayList(); + repository.findAllByAOrderById("D0").as(StepVerifier::create) + .recordWith(() -> duplicates) + .expectNextCount(2) + .verifyComplete(); + + var windowContainer = new AtomicReference>(); + repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, KeysetScrollPosition.initial()) + .as(StepVerifier::create) + .consumeNextWith(windowContainer::set) + .verifyComplete(); + var window = windowContainer.get(); + assertThat(window.hasNext()).isTrue(); + assertThat(window) + .hasSize(4) + .extracting(Function.identity()) + .satisfies(e -> assertThat(e.getId()).isEqualTo(duplicates.get(0).getId()), Index.atIndex(3)) + .extracting(ScrollingEntity::getA) + .containsExactly("A0", "B0", "C0", "D0"); + + repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, window.positionAt(window.size() - 1)) + .as(StepVerifier::create) + .consumeNextWith(windowContainer::set) + .verifyComplete(); + window = windowContainer.get(); + assertThat(window.hasNext()).isTrue(); + assertThat(window) + .hasSize(4) + .extracting(Function.identity()) + .satisfies(e -> assertThat(e.getId()).isEqualTo(duplicates.get(1).getId()), Index.atIndex(0)) + .extracting(ScrollingEntity::getA) + .containsExactly("D0", "E0", "F0", "G0"); + + repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, window.positionAt(window.size() - 1)) + .as(StepVerifier::create) + .consumeNextWith(windowContainer::set) + .verifyComplete(); + window = windowContainer.get(); + assertThat(window.isLast()).isTrue(); + assertThat(window).extracting(ScrollingEntity::getA) + .containsExactly("H0", "I0"); + } + + @Test + void backwardWithDuplicatesManualIteration(@Autowired ReactiveScrollingRepository repository) { + + // Recreate the last position + var last = repository.findFirstByA("I0").block(); + var keys = Map.of( + "foobar", Values.value(last.getA()), + "b", Values.value(last.getB()), + Constants.NAME_OF_ADDITIONAL_SORT, Values.value(last.getId().toString()) + ); + + var duplicates = new ArrayList(); + repository.findAllByAOrderById("D0").as(StepVerifier::create) + .recordWith(() -> duplicates) + .expectNextCount(2) + .verifyComplete(); + + var windowContainer = new AtomicReference>(); + repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, KeysetScrollPosition.of(keys, KeysetScrollPosition.Direction.Backward)) + .as(StepVerifier::create) + .consumeNextWith(windowContainer::set) + .verifyComplete(); + var window = windowContainer.get(); + assertThat(window.hasNext()).isTrue(); + assertThat(window) + .hasSize(4) + .extracting(ScrollingEntity::getA) + .containsExactly("F0", "G0", "H0", "I0"); + + var pos = ((KeysetScrollPosition) window.positionAt(0)); + pos = KeysetScrollPosition.of(pos.getKeys(), KeysetScrollPosition.Direction.Backward); + repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, pos) + .as(StepVerifier::create) + .consumeNextWith(windowContainer::set) + .verifyComplete(); + window = windowContainer.get(); + assertThat(window.hasNext()).isTrue(); + assertThat(window) + .hasSize(4) + .extracting(Function.identity()) + .extracting(ScrollingEntity::getA) + .containsExactly("C0", "D0", "D0", "E0"); + + pos = ((KeysetScrollPosition) window.positionAt(0)); + pos = KeysetScrollPosition.of(pos.getKeys(), KeysetScrollPosition.Direction.Backward); + repository.findTop4By(ScrollingEntity.SORT_BY_B_AND_A, pos) + .as(StepVerifier::create) + .consumeNextWith(windowContainer::set) + .verifyComplete(); + window = windowContainer.get(); + assertThat(window.isLast()).isTrue(); + assertThat(window).extracting(ScrollingEntity::getA) + .containsExactly("A0", "B0"); + } + + @Configuration + @EnableNeo4jRepositories + @EnableReactiveNeo4jRepositories + static class Config extends Neo4jReactiveTestConfiguration { + + @Bean + public Driver driver() { + return neo4jConnectionSupport.getDriver(); + } + + @Bean + public BookmarkCapture bookmarkCapture() { + return new BookmarkCapture(); + } + + @Override + public ReactiveTransactionManager reactiveTransactionManager(Driver driver, ReactiveDatabaseSelectionProvider databaseSelectionProvider) { + + BookmarkCapture bookmarkCapture = bookmarkCapture(); + return new ReactiveNeo4jTransactionManager(driver, databaseSelectionProvider, Neo4jBookmarkManager.create(bookmarkCapture)); + } + + @Override + public boolean isCypher5Compatible() { + return neo4jConnectionSupport.isCypher5SyntaxCompatible(); + } + + } +} diff --git a/src/test/java/org/springframework/data/neo4j/integration/reactive/repositories/ReactiveScrollingRepository.java b/src/test/java/org/springframework/data/neo4j/integration/reactive/repositories/ReactiveScrollingRepository.java new file mode 100644 index 0000000000..d87e59d46a --- /dev/null +++ b/src/test/java/org/springframework/data/neo4j/integration/reactive/repositories/ReactiveScrollingRepository.java @@ -0,0 +1,41 @@ +/* + * Copyright 2011-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.integration.reactive.repositories; + +import java.util.UUID; + +import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Window; +import org.springframework.data.neo4j.integration.shared.common.ScrollingEntity; +import org.springframework.data.neo4j.repository.ReactiveNeo4jRepository; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * @author Michael J. Simons + */ +public interface ReactiveScrollingRepository extends ReactiveNeo4jRepository { + + Flux findTop4ByOrderByB(); + + Mono> findTop4By(Sort sort, ScrollPosition position); + + Mono findFirstByA(String a); + + Flux findAllByAOrderById(String a); +} diff --git a/src/test/java/org/springframework/data/neo4j/integration/shared/common/ScrollingEntity.java b/src/test/java/org/springframework/data/neo4j/integration/shared/common/ScrollingEntity.java new file mode 100644 index 0000000000..4d991f82c9 --- /dev/null +++ b/src/test/java/org/springframework/data/neo4j/integration/shared/common/ScrollingEntity.java @@ -0,0 +1,102 @@ +/* + * Copyright 2011-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.integration.shared.common; + +import java.time.LocalDateTime; +import java.util.UUID; + +import org.neo4j.driver.QueryRunner; +import org.springframework.data.domain.Sort; +import org.springframework.data.neo4j.core.schema.GeneratedValue; +import org.springframework.data.neo4j.core.schema.Id; +import org.springframework.data.neo4j.core.schema.Node; +import org.springframework.data.neo4j.core.schema.Property; + +/** + * An entity that is specifically designed to test the keyset based pagination. + * + * @author Michael J. Simons + */ +@Node +public class ScrollingEntity { + + /** + * Sorting by b and a will not be unique for 3 and D0, so this will trigger the additional condition based on the id + */ + public static final Sort SORT_BY_B_AND_A = Sort.by(Sort.Order.asc("b"), Sort.Order.desc("a")); + + public static void createTestData(QueryRunner queryRunner) { + queryRunner.run("MATCH (n) DETACH DELETE n"); + queryRunner.run(""" + UNWIND (range(0, 8) + [3]) AS i WITH i, 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' AS letters + CREATE (n:ScrollingEntity { + id: randomUUID(), + foobar: (substring(letters, (toInteger(i) % 26), 1) + (i / 26)), + b: i, + c: (localdatetime() + duration({ days: i }) + duration({ seconds: i * toInteger(rand()*10) })) + }) + RETURN n + """); + } + + @Id + @GeneratedValue + private UUID id; + + @Property("foobar") + private String a; + + private Integer b; + + private LocalDateTime c; + + public UUID getId() { + return id; + } + + public String getA() { + return a; + } + + public void setA(String a) { + this.a = a; + } + + public Integer getB() { + return b; + } + + public void setB(Integer b) { + this.b = b; + } + + public LocalDateTime getC() { + return c; + } + + public void setC(LocalDateTime c) { + this.c = c; + } + + @Override + public String toString() { + return "ScrollingEntity{" + + "a='" + a + '\'' + + ", b=" + b + + ", c=" + c + + '}'; + } +} diff --git a/src/test/java/org/springframework/data/neo4j/repository/query/CypherAdapterUtilsTest.java b/src/test/java/org/springframework/data/neo4j/repository/query/CypherAdapterUtilsTest.java new file mode 100644 index 0000000000..01e3434091 --- /dev/null +++ b/src/test/java/org/springframework/data/neo4j/repository/query/CypherAdapterUtilsTest.java @@ -0,0 +1,65 @@ +/* + * Copyright 2011-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.neo4j.repository.query; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.LocalDateTime; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.neo4j.cypherdsl.core.Cypher; +import org.neo4j.cypherdsl.core.renderer.Configuration; +import org.neo4j.cypherdsl.core.renderer.Renderer; +import org.springframework.data.domain.KeysetScrollPosition; +import org.springframework.data.domain.Sort; +import org.springframework.data.neo4j.core.mapping.Constants; +import org.springframework.data.neo4j.core.mapping.Neo4jMappingContext; +import org.springframework.data.neo4j.integration.shared.common.ScrollingEntity; + +/** + * @author Michael J. Simons + */ +class CypherAdapterUtilsTest { + + @Test + void shouldCombineSortKeysetProper() { + + var mappingContext = new Neo4jMappingContext(); + var entity = mappingContext.getPersistentEntity(ScrollingEntity.class); + var n = Constants.NAME_OF_TYPED_ROOT_NODE.apply(entity); + + var condition = CypherAdapterUtils.combineKeysetIntoCondition(entity, + KeysetScrollPosition.of(Map.of("foobar", "D0", "b", 3, "c", LocalDateTime.of(2023, 3, 19, 14, 21, 8, 716))), + Sort.by(Sort.Order.asc("b"), Sort.Order.desc("a"), Sort.Order.asc("c")) + ); + + var expected = """ + MATCH (scrollingEntity) + WHERE (((scrollingEntity.b > $pcdsl01 + OR (scrollingEntity.b = $pcdsl01 + AND scrollingEntity.foobar < $pcdsl02)) + OR (scrollingEntity.foobar = $pcdsl02 + AND scrollingEntity.c > $pcdsl03)) + OR (scrollingEntity.b = $pcdsl01 + AND scrollingEntity.foobar = $pcdsl02 + AND scrollingEntity.c = $pcdsl03)) + RETURN scrollingEntity"""; + + assertThat(Renderer.getRenderer(Configuration.prettyPrinting()).render(Cypher.match(Cypher.anyNode(n)).where(condition).returning(n).build())) + .isEqualTo(expected); + } +} diff --git a/src/test/java/org/springframework/data/neo4j/repository/query/ReactiveRepositoryQueryTest.java b/src/test/java/org/springframework/data/neo4j/repository/query/ReactiveRepositoryQueryTest.java index 37ca92b58e..f1b7792f45 100644 --- a/src/test/java/org/springframework/data/neo4j/repository/query/ReactiveRepositoryQueryTest.java +++ b/src/test/java/org/springframework/data/neo4j/repository/query/ReactiveRepositoryQueryTest.java @@ -21,13 +21,11 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - import java.lang.reflect.Method; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.UnaryOperator; import java.util.regex.Pattern; import org.junit.jupiter.api.Nested; @@ -61,6 +59,9 @@ import org.springframework.data.repository.reactive.ReactiveCrudRepository; import org.springframework.util.ReflectionUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + /** * Unit tests for *
    @@ -145,7 +146,8 @@ void shouldWarnWhenUsingSortedAndCustomQuery(LogbackCapture logbackCapture) { Collections.emptySet(), parameterAccessor, Neo4jQueryType.DEFAULT, - () -> (typeSystem, mapAccessor) -> new TestEntity() + () -> (typeSystem, mapAccessor) -> new TestEntity(), + UnaryOperator.identity() ); assertThat(logbackCapture.getFormattedMessages()) .anyMatch(s -> s.matches( @@ -179,7 +181,8 @@ void orderBySpelShouldWork(LogbackCapture logbackCapture) { Collections.emptySet(), parameterAccessor, Neo4jQueryType.DEFAULT, - () -> (typeSystem, mapAccessor) -> new TestEntity() + () -> (typeSystem, mapAccessor) -> new TestEntity(), + UnaryOperator.identity() ); assertThat(pq.getQueryFragmentsAndParameters().getCypherQuery()) .isEqualTo("MATCH (n:Test) RETURN n ORDER BY name ASC SKIP $skip LIMIT $limit"); @@ -215,7 +218,8 @@ void literalReplacementsShouldWork() { Collections.emptySet(), parameterAccessor, Neo4jQueryType.DEFAULT, - () -> (typeSystem, mapAccessor) -> new TestEntity() + () -> (typeSystem, mapAccessor) -> new TestEntity(), + UnaryOperator.identity() ); return pq.getQueryFragmentsAndParameters().getCypherQuery(); }).block();