diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index 2d69c799ea..b986474861 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -15,11 +15,6 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.springframework.data.mongodb.core.aggregation.Fields.*; - -import java.util.Arrays; -import java.util.List; - import org.bson.Document; import org.bson.conversions.Bson; import org.springframework.data.domain.Sort; @@ -37,6 +32,11 @@ import org.springframework.data.mongodb.core.query.SerializationUtils; import org.springframework.util.Assert; +import java.util.Arrays; +import java.util.List; + +import static org.springframework.data.mongodb.core.aggregation.Fields.field; + /** * An {@code Aggregation} is a representation of a list of aggregation steps to be performed by the MongoDB Aggregation * Framework. @@ -50,6 +50,7 @@ * @author Nikolay Bogdanov * @author Gustavo de Geus * @author Jérôme Guyon + * @author Sangyong Choi * @since 1.3 */ public class Aggregation { @@ -664,6 +665,18 @@ public static LookupOperation lookup(Field from, Field localField, Field foreign return new LookupOperation(from, localField, foreignField, as); } + public static LookupOperation lookup(String from, String localField, String foreignField, String as, List aggregationOperations) { + return lookup(field(from), field(localField), field(foreignField), field(as), null, new AggregationPipeline(aggregationOperations)); + } + + public static LookupOperation lookup(String from, String localField, String foreignField, String as, List letExpressionVars, List aggregationOperations) { + return lookup(field(from), field(localField), field(foreignField), field(as), new LookupOperation.Let(letExpressionVars), new AggregationPipeline(aggregationOperations)); + } + + public static LookupOperation lookup(Field from, Field localField, Field foreignField, Field as, LookupOperation.Let let, AggregationPipeline pipeline) { + return new LookupOperation(from, localField, foreignField, as, let, pipeline); + } + /** * Creates a new {@link CountOperationBuilder}. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java index 0439876dbe..77361e0214 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java @@ -21,6 +21,8 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import java.util.List; + /** * Encapsulates the aggregation framework {@code $lookup}-operation. We recommend to use the static factory method * {@link Aggregation#lookup(String, String, String, String)} instead of creating instances of this class directly. @@ -28,6 +30,7 @@ * @author Alessio Fachechi * @author Christoph Strobl * @author Mark Paluch + * @author Sangyong Choi * @since 1.9 * @see MongoDB Aggregation Framework: * $lookup @@ -39,6 +42,11 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe private final Field foreignField; private final ExposedField as; + @Nullable + private final Let let; + @Nullable + private final AggregationPipeline pipeline; + /** * Creates a new {@link LookupOperation} for the given {@link Field}s. * @@ -48,7 +56,10 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe * @param as must not be {@literal null}. */ public LookupOperation(Field from, Field localField, Field foreignField, Field as) { + this(from, localField, foreignField, as, null, null); + } + public LookupOperation(Field from, Field localField, Field foreignField, Field as, @Nullable Let let, @Nullable AggregationPipeline pipeline) { Assert.notNull(from, "From must not be null"); Assert.notNull(localField, "LocalField must not be null"); Assert.notNull(foreignField, "ForeignField must not be null"); @@ -58,6 +69,8 @@ public LookupOperation(Field from, Field localField, Field foreignField, Field a this.localField = localField; this.foreignField = foreignField; this.as = new ExposedField(as, true); + this.let = let; + this.pipeline = pipeline; } @Override @@ -75,6 +88,14 @@ public Document toDocument(AggregationOperationContext context) { lookupObject.append("foreignField", foreignField.getTarget()); lookupObject.append("as", as.getTarget()); + if (let != null) { + lookupObject.append("let", let.toDocument(context)); + } + + if (pipeline != null) { + lookupObject.append("pipeline", pipeline.toDocuments(context)); + } + return new Document(getOperator(), lookupObject); } @@ -184,4 +205,49 @@ public ForeignFieldBuilder localField(String name) { return this; } } + + public static class Let implements AggregationExpression{ + + private final List vars; + + public Let(List vars) { + Assert.notEmpty(vars, "'let' must not be null or empty"); + this.vars = vars; + } + + @Override + public Document toDocument(AggregationOperationContext context) { + return toLet(); + } + + private Document toLet() { + Document mappedVars = new Document(); + + for (ExpressionVariable var : this.vars) { + mappedVars.putAll(getMappedVariable(var)); + } + + return mappedVars; + } + + private Document getMappedVariable(ExpressionVariable var) { + return new Document(var.variableName, prefixDollarSign(var.expression)); + } + + private String prefixDollarSign(String expression) { + return "$" + expression; + } + + public static class ExpressionVariable { + + private final String variableName; + + private final String expression; + + public ExpressionVariable(String variableName, String expression) { + this.variableName = variableName; + this.expression = expression; + } + } + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index c98b0c0b5f..dcd52b49fb 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -15,35 +15,10 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.springframework.data.domain.Sort.Direction.*; -import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; -import static org.springframework.data.mongodb.core.query.Criteria.*; -import static org.springframework.data.mongodb.test.util.Assertions.*; - +import com.mongodb.client.MongoCollection; import lombok.Builder; - -import java.io.BufferedInputStream; -import java.text.ParseException; -import java.text.SimpleDateFormat; -import java.time.Instant; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import java.time.ZoneId; -import java.time.ZonedDateTime; -import java.time.temporal.ChronoField; -import java.time.temporal.ChronoUnit; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Date; -import java.util.List; -import java.util.Scanner; -import java.util.stream.Stream; - import org.assertj.core.data.Offset; import org.bson.Document; -import org.bson.types.ObjectId; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -66,17 +41,28 @@ import org.springframework.data.mongodb.core.index.GeoSpatialIndexType; import org.springframework.data.mongodb.core.index.GeospatialIndex; import org.springframework.data.mongodb.core.mapping.MongoId; +import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.Person; -import org.springframework.data.mongodb.test.util.EnableIfMongoServerVersion; -import org.springframework.data.mongodb.test.util.MongoTemplateExtension; -import org.springframework.data.mongodb.test.util.MongoTestTemplate; -import org.springframework.data.mongodb.test.util.MongoVersion; -import org.springframework.data.mongodb.test.util.Template; +import org.springframework.data.mongodb.test.util.*; -import com.mongodb.client.MongoCollection; +import java.io.BufferedInputStream; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.time.*; +import java.time.temporal.ChronoField; +import java.time.temporal.ChronoUnit; +import java.util.*; +import java.util.stream.Stream; + +import static org.springframework.data.domain.Sort.Direction.ASC; +import static org.springframework.data.domain.Sort.Direction.DESC; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; +import static org.springframework.data.mongodb.core.query.Criteria.where; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; +import static org.springframework.data.mongodb.test.util.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link MongoTemplate#aggregate(Aggregation, Class, Class)}. @@ -90,6 +76,7 @@ * @author Maninder Singh * @author Sergey Shcherbakov * @author Minsu Kim + * @author Sangyong Choi */ @ExtendWith(MongoTemplateExtension.class) public class AggregationTests { @@ -1518,6 +1505,53 @@ void shouldLookupPeopleCorectly() { assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1"); } + @Test + void shouldLookupPeopleCorrectlyWithPipeline() { + createUsersWithReferencedPersons(); + + TypedAggregation agg = newAggregation(User.class, // + lookup( + "person", + "_id", + "firstname", + "linkedPerson", + List.of(match(where("firstname").is("u1")))), // + sort(ASC, "id")); + + AggregationResults results = mongoTemplate.aggregate(agg, User.class, Document.class); + + List mappedResults = results.getMappedResults(); + + Document firstItem = mappedResults.get(0); + + assertThat(firstItem).containsEntry("_id", "u1"); + assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1"); + } + + @Test + void shouldLookupPeopleCorrectlyWithPipelineAndLet() { + createUsersWithReferencedPersons(); + + TypedAggregation agg = newAggregation(User.class, // + lookup( + "person", + "_id", + "firstname", + "linkedPerson", + List.of(new LookupOperation.Let.ExpressionVariable("personFirstname", "firstname")), + List.of(match(where("firstname").is("u1")))), + sort(ASC, "id")); + + AggregationResults results = mongoTemplate.aggregate(agg, User.class, Document.class); + + List mappedResults = results.getMappedResults(); + + Document firstItem = mappedResults.get(0); + + assertThat(firstItem).containsEntry("_id", "u1"); + assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1"); + } + @Test // DATAMONGO-1326 void shouldGroupByAndLookupPeopleCorectly() {