From b59900b2d6a418c6987fa54bd832ec965b14b1ba Mon Sep 17 00:00:00 2001 From: sangyongchoi Date: Sun, 15 Jan 2023 18:24:40 +0900 Subject: [PATCH 1/3] implement 'let' and 'pipeline' in $lookup --- .../mongodb/core/aggregation/Aggregation.java | 13 +++ .../core/aggregation/LookupOperation.java | 70 +++++++++++++ .../core/aggregation/AggregationTests.java | 99 +++++++++++++------ 3 files changed, 150 insertions(+), 32 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index 2d69c799ea..60e727482f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -50,6 +50,7 @@ * @author Nikolay Bogdanov * @author Gustavo de Geus * @author Jérôme Guyon + * @author Sangyong Choi * @since 1.3 */ public class Aggregation { @@ -664,6 +665,18 @@ public static LookupOperation lookup(Field from, Field localField, Field foreign return new LookupOperation(from, localField, foreignField, as); } + public static LookupOperation lookup(String from, String localField, String foreignField, String as, AggregationPipeline pipeline) { + return lookup(field(from), field(localField), field(foreignField), field(as), null, pipeline); + } + + public static LookupOperation lookup(String from, String localField, String foreignField, String as, LookupOperation.Let let, AggregationPipeline pipeline) { + return lookup(field(from), field(localField), field(foreignField), field(as), let, pipeline); + } + + public static LookupOperation lookup(Field from, Field localField, Field foreignField, Field as, LookupOperation.Let let, AggregationPipeline pipeline) { + return new LookupOperation(from, localField, foreignField, as, let, pipeline); + } + /** * Creates a new {@link CountOperationBuilder}. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java index 0439876dbe..274dd7c151 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java @@ -21,6 +21,8 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import java.util.List; + /** * Encapsulates the aggregation framework {@code $lookup}-operation. We recommend to use the static factory method * {@link Aggregation#lookup(String, String, String, String)} instead of creating instances of this class directly. @@ -28,6 +30,7 @@ * @author Alessio Fachechi * @author Christoph Strobl * @author Mark Paluch + * @author Sangyong Choi * @since 1.9 * @see MongoDB Aggregation Framework: * $lookup @@ -39,6 +42,11 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe private final Field foreignField; private final ExposedField as; + @Nullable + private final Let let; + @Nullable + private final AggregationPipeline pipeline; + /** * Creates a new {@link LookupOperation} for the given {@link Field}s. * @@ -48,7 +56,14 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe * @param as must not be {@literal null}. */ public LookupOperation(Field from, Field localField, Field foreignField, Field as) { + this(from, localField, foreignField, as, null, null); + } + + public LookupOperation(Field from, Field localField, Field foreignField, Field as, @Nullable AggregationPipeline pipeline) { + this(from, localField, foreignField, as, null, pipeline); + } + public LookupOperation(Field from, Field localField, Field foreignField, Field as, @Nullable Let let, @Nullable AggregationPipeline pipeline) { Assert.notNull(from, "From must not be null"); Assert.notNull(localField, "LocalField must not be null"); Assert.notNull(foreignField, "ForeignField must not be null"); @@ -58,6 +73,8 @@ public LookupOperation(Field from, Field localField, Field foreignField, Field a this.localField = localField; this.foreignField = foreignField; this.as = new ExposedField(as, true); + this.let = let; + this.pipeline = pipeline; } @Override @@ -75,6 +92,14 @@ public Document toDocument(AggregationOperationContext context) { lookupObject.append("foreignField", foreignField.getTarget()); lookupObject.append("as", as.getTarget()); + if (let != null) { + lookupObject.append("let", let.toDocument(context)); + } + + if (pipeline != null) { + lookupObject.append("pipeline", pipeline.toDocuments(context)); + } + return new Document(getOperator(), lookupObject); } @@ -184,4 +209,49 @@ public ForeignFieldBuilder localField(String name) { return this; } } + + public static class Let implements AggregationExpression{ + + private final List vars; + + public Let(List vars) { + Assert.notEmpty(vars, "'let' must not be null or empty"); + this.vars = vars; + } + + @Override + public Document toDocument(AggregationOperationContext context) { + return toLet(); + } + + private Document toLet() { + Document mappedVars = new Document(); + + for (ExpressionVariable var : this.vars) { + mappedVars.putAll(getMappedVariable(var)); + } + + return mappedVars; + } + + private Document getMappedVariable(ExpressionVariable var) { + return new Document(var.variableName, prefixDollarSign(var.expression)); + } + + private String prefixDollarSign(String expression) { + return "$" + expression; + } + + public static class ExpressionVariable { + + private final String variableName; + + private final String expression; + + public ExpressionVariable(String variableName, String expression) { + this.variableName = variableName; + this.expression = expression; + } + } + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index c98b0c0b5f..90a8dbbccd 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -15,35 +15,10 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.springframework.data.domain.Sort.Direction.*; -import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; -import static org.springframework.data.mongodb.core.query.Criteria.*; -import static org.springframework.data.mongodb.test.util.Assertions.*; - +import com.mongodb.client.MongoCollection; import lombok.Builder; - -import java.io.BufferedInputStream; -import java.text.ParseException; -import java.text.SimpleDateFormat; -import java.time.Instant; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import java.time.ZoneId; -import java.time.ZonedDateTime; -import java.time.temporal.ChronoField; -import java.time.temporal.ChronoUnit; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Date; -import java.util.List; -import java.util.Scanner; -import java.util.stream.Stream; - import org.assertj.core.data.Offset; import org.bson.Document; -import org.bson.types.ObjectId; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -70,13 +45,23 @@ import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.Person; -import org.springframework.data.mongodb.test.util.EnableIfMongoServerVersion; -import org.springframework.data.mongodb.test.util.MongoTemplateExtension; -import org.springframework.data.mongodb.test.util.MongoTestTemplate; -import org.springframework.data.mongodb.test.util.MongoVersion; -import org.springframework.data.mongodb.test.util.Template; +import org.springframework.data.mongodb.test.util.*; -import com.mongodb.client.MongoCollection; +import java.io.BufferedInputStream; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.time.*; +import java.time.temporal.ChronoField; +import java.time.temporal.ChronoUnit; +import java.util.*; +import java.util.stream.Stream; + +import static org.springframework.data.domain.Sort.Direction.ASC; +import static org.springframework.data.domain.Sort.Direction.DESC; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; +import static org.springframework.data.mongodb.core.query.Criteria.where; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; +import static org.springframework.data.mongodb.test.util.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link MongoTemplate#aggregate(Aggregation, Class, Class)}. @@ -90,6 +75,7 @@ * @author Maninder Singh * @author Sergey Shcherbakov * @author Minsu Kim + * @author Sangyong Choi */ @ExtendWith(MongoTemplateExtension.class) public class AggregationTests { @@ -1518,6 +1504,55 @@ void shouldLookupPeopleCorectly() { assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1"); } + @Test + void shouldLookupPeopleCorrectlyWithPipeline() { + createUsersWithReferencedPersons(); + + TypedAggregation agg = newAggregation(User.class, // + lookup( + "person", + "_id", + "firstname", + "linkedPerson", + AggregationPipeline.of(match(where("firstname").is("u1")))), // + sort(ASC, "id")); + + AggregationResults results = mongoTemplate.aggregate(agg, User.class, Document.class); + + List mappedResults = results.getMappedResults(); + + Document firstItem = mappedResults.get(0); + + assertThat(firstItem).containsEntry("_id", "u1"); + assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1"); + } + + @Test + void shouldLookupPeopleCorrectlyWithPipelineAndLet() { + createUsersWithReferencedPersons(); + + TypedAggregation agg = newAggregation(User.class, // + lookup( + "person", + "_id", + "firstname", + "linkedPerson", + new LookupOperation.Let( + List.of(new LookupOperation.Let.ExpressionVariable("test", "test")) + ), + AggregationPipeline.of(match(where("firstname").is("u1")))), // + sort(ASC, "id")); + + AggregationResults results = mongoTemplate.aggregate(agg, User.class, Document.class); + + List mappedResults = results.getMappedResults(); + + Document firstItem = mappedResults.get(0); + + assertThat(firstItem).containsEntry("_id", "u1"); + assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1"); + } + @Test // DATAMONGO-1326 void shouldGroupByAndLookupPeopleCorectly() { From 8f83054ca6a6ae609b3dc5e4c7035c5be41fe259 Mon Sep 17 00:00:00 2001 From: sangyongchoi Date: Sun, 15 Jan 2023 18:54:06 +0900 Subject: [PATCH 2/3] modify parameter type --- .../mongodb/core/aggregation/Aggregation.java | 18 +++++++++--------- .../core/aggregation/AggregationTests.java | 9 ++++----- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index 60e727482f..b986474861 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -15,11 +15,6 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.springframework.data.mongodb.core.aggregation.Fields.*; - -import java.util.Arrays; -import java.util.List; - import org.bson.Document; import org.bson.conversions.Bson; import org.springframework.data.domain.Sort; @@ -37,6 +32,11 @@ import org.springframework.data.mongodb.core.query.SerializationUtils; import org.springframework.util.Assert; +import java.util.Arrays; +import java.util.List; + +import static org.springframework.data.mongodb.core.aggregation.Fields.field; + /** * An {@code Aggregation} is a representation of a list of aggregation steps to be performed by the MongoDB Aggregation * Framework. @@ -665,12 +665,12 @@ public static LookupOperation lookup(Field from, Field localField, Field foreign return new LookupOperation(from, localField, foreignField, as); } - public static LookupOperation lookup(String from, String localField, String foreignField, String as, AggregationPipeline pipeline) { - return lookup(field(from), field(localField), field(foreignField), field(as), null, pipeline); + public static LookupOperation lookup(String from, String localField, String foreignField, String as, List aggregationOperations) { + return lookup(field(from), field(localField), field(foreignField), field(as), null, new AggregationPipeline(aggregationOperations)); } - public static LookupOperation lookup(String from, String localField, String foreignField, String as, LookupOperation.Let let, AggregationPipeline pipeline) { - return lookup(field(from), field(localField), field(foreignField), field(as), let, pipeline); + public static LookupOperation lookup(String from, String localField, String foreignField, String as, List letExpressionVars, List aggregationOperations) { + return lookup(field(from), field(localField), field(foreignField), field(as), new LookupOperation.Let(letExpressionVars), new AggregationPipeline(aggregationOperations)); } public static LookupOperation lookup(Field from, Field localField, Field foreignField, Field as, LookupOperation.Let let, AggregationPipeline pipeline) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index 90a8dbbccd..dcd52b49fb 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -41,6 +41,7 @@ import org.springframework.data.mongodb.core.index.GeoSpatialIndexType; import org.springframework.data.mongodb.core.index.GeospatialIndex; import org.springframework.data.mongodb.core.mapping.MongoId; +import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; @@ -1514,7 +1515,7 @@ void shouldLookupPeopleCorrectlyWithPipeline() { "_id", "firstname", "linkedPerson", - AggregationPipeline.of(match(where("firstname").is("u1")))), // + List.of(match(where("firstname").is("u1")))), // sort(ASC, "id")); AggregationResults results = mongoTemplate.aggregate(agg, User.class, Document.class); @@ -1537,10 +1538,8 @@ void shouldLookupPeopleCorrectlyWithPipelineAndLet() { "_id", "firstname", "linkedPerson", - new LookupOperation.Let( - List.of(new LookupOperation.Let.ExpressionVariable("test", "test")) - ), - AggregationPipeline.of(match(where("firstname").is("u1")))), // + List.of(new LookupOperation.Let.ExpressionVariable("personFirstname", "firstname")), + List.of(match(where("firstname").is("u1")))), sort(ASC, "id")); AggregationResults results = mongoTemplate.aggregate(agg, User.class, Document.class); From bb1ab1d944154da9ab5e1e719841d937282185cf Mon Sep 17 00:00:00 2001 From: sangyongchoi Date: Sun, 15 Jan 2023 18:54:52 +0900 Subject: [PATCH 3/3] remove unused constructor --- .../data/mongodb/core/aggregation/LookupOperation.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java index 274dd7c151..77361e0214 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java @@ -59,10 +59,6 @@ public LookupOperation(Field from, Field localField, Field foreignField, Field a this(from, localField, foreignField, as, null, null); } - public LookupOperation(Field from, Field localField, Field foreignField, Field as, @Nullable AggregationPipeline pipeline) { - this(from, localField, foreignField, as, null, pipeline); - } - public LookupOperation(Field from, Field localField, Field foreignField, Field as, @Nullable Let let, @Nullable AggregationPipeline pipeline) { Assert.notNull(from, "From must not be null"); Assert.notNull(localField, "LocalField must not be null");