Skip to content

Implement let, pipeline in $lookup #4272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@
*/
package org.springframework.data.mongodb.core.aggregation;

import static org.springframework.data.mongodb.core.aggregation.Fields.*;

import java.util.Arrays;
import java.util.List;

import org.bson.Document;
import org.bson.conversions.Bson;
import org.springframework.data.domain.Sort;
Expand All @@ -37,6 +32,11 @@
import org.springframework.data.mongodb.core.query.SerializationUtils;
import org.springframework.util.Assert;

import java.util.Arrays;
import java.util.List;

import static org.springframework.data.mongodb.core.aggregation.Fields.field;

/**
* An {@code Aggregation} is a representation of a list of aggregation steps to be performed by the MongoDB Aggregation
* Framework.
Expand All @@ -50,6 +50,7 @@
* @author Nikolay Bogdanov
* @author Gustavo de Geus
* @author Jérôme Guyon
* @author Sangyong Choi
* @since 1.3
*/
public class Aggregation {
Expand Down Expand Up @@ -664,6 +665,18 @@ public static LookupOperation lookup(Field from, Field localField, Field foreign
return new LookupOperation(from, localField, foreignField, as);
}

public static LookupOperation lookup(String from, String localField, String foreignField, String as, List<AggregationOperation> aggregationOperations) {
return lookup(field(from), field(localField), field(foreignField), field(as), null, new AggregationPipeline(aggregationOperations));
}

public static LookupOperation lookup(String from, String localField, String foreignField, String as, List<LookupOperation.Let.ExpressionVariable> letExpressionVars, List<AggregationOperation> aggregationOperations) {
return lookup(field(from), field(localField), field(foreignField), field(as), new LookupOperation.Let(letExpressionVars), new AggregationPipeline(aggregationOperations));
}

public static LookupOperation lookup(Field from, Field localField, Field foreignField, Field as, LookupOperation.Let let, AggregationPipeline pipeline) {
return new LookupOperation(from, localField, foreignField, as, let, pipeline);
}

/**
* Creates a new {@link CountOperationBuilder}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

import java.util.List;

/**
* Encapsulates the aggregation framework {@code $lookup}-operation. We recommend to use the static factory method
* {@link Aggregation#lookup(String, String, String, String)} instead of creating instances of this class directly.
*
* @author Alessio Fachechi
* @author Christoph Strobl
* @author Mark Paluch
* @author Sangyong Choi
* @since 1.9
* @see <a href="https://docs.mongodb.com/manual/reference/operator/aggregation/lookup/">MongoDB Aggregation Framework:
* $lookup</a>
Expand All @@ -39,6 +42,11 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe
private final Field foreignField;
private final ExposedField as;

@Nullable
private final Let let;
@Nullable
private final AggregationPipeline pipeline;

/**
* Creates a new {@link LookupOperation} for the given {@link Field}s.
*
Expand All @@ -48,7 +56,10 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe
* @param as must not be {@literal null}.
*/
public LookupOperation(Field from, Field localField, Field foreignField, Field as) {
this(from, localField, foreignField, as, null, null);
}

public LookupOperation(Field from, Field localField, Field foreignField, Field as, @Nullable Let let, @Nullable AggregationPipeline pipeline) {
Assert.notNull(from, "From must not be null");
Assert.notNull(localField, "LocalField must not be null");
Assert.notNull(foreignField, "ForeignField must not be null");
Expand All @@ -58,6 +69,8 @@ public LookupOperation(Field from, Field localField, Field foreignField, Field a
this.localField = localField;
this.foreignField = foreignField;
this.as = new ExposedField(as, true);
this.let = let;
this.pipeline = pipeline;
}

@Override
Expand All @@ -75,6 +88,14 @@ public Document toDocument(AggregationOperationContext context) {
lookupObject.append("foreignField", foreignField.getTarget());
lookupObject.append("as", as.getTarget());

if (let != null) {
lookupObject.append("let", let.toDocument(context));
}

if (pipeline != null) {
lookupObject.append("pipeline", pipeline.toDocuments(context));
}

return new Document(getOperator(), lookupObject);
}

Expand Down Expand Up @@ -184,4 +205,49 @@ public ForeignFieldBuilder localField(String name) {
return this;
}
}

public static class Let implements AggregationExpression{

private final List<ExpressionVariable> vars;

public Let(List<ExpressionVariable> vars) {
Assert.notEmpty(vars, "'let' must not be null or empty");
this.vars = vars;
}

@Override
public Document toDocument(AggregationOperationContext context) {
return toLet();
}

private Document toLet() {
Document mappedVars = new Document();

for (ExpressionVariable var : this.vars) {
mappedVars.putAll(getMappedVariable(var));
}

return mappedVars;
}

private Document getMappedVariable(ExpressionVariable var) {
return new Document(var.variableName, prefixDollarSign(var.expression));
}

private String prefixDollarSign(String expression) {
return "$" + expression;
}

public static class ExpressionVariable {

private final String variableName;

private final String expression;

public ExpressionVariable(String variableName, String expression) {
this.variableName = variableName;
this.expression = expression;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,10 @@
*/
package org.springframework.data.mongodb.core.aggregation;

import static org.springframework.data.domain.Sort.Direction.*;
import static org.springframework.data.mongodb.core.aggregation.Aggregation.*;
import static org.springframework.data.mongodb.core.query.Criteria.*;
import static org.springframework.data.mongodb.test.util.Assertions.*;

import com.mongodb.client.MongoCollection;
import lombok.Builder;

import java.io.BufferedInputStream;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.temporal.ChronoField;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Scanner;
import java.util.stream.Stream;

import org.assertj.core.data.Offset;
import org.bson.Document;
import org.bson.types.ObjectId;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -66,17 +41,28 @@
import org.springframework.data.mongodb.core.index.GeoSpatialIndexType;
import org.springframework.data.mongodb.core.index.GeospatialIndex;
import org.springframework.data.mongodb.core.mapping.MongoId;
import org.springframework.data.mongodb.core.query.BasicQuery;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.NearQuery;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.repository.Person;
import org.springframework.data.mongodb.test.util.EnableIfMongoServerVersion;
import org.springframework.data.mongodb.test.util.MongoTemplateExtension;
import org.springframework.data.mongodb.test.util.MongoTestTemplate;
import org.springframework.data.mongodb.test.util.MongoVersion;
import org.springframework.data.mongodb.test.util.Template;
import org.springframework.data.mongodb.test.util.*;

import com.mongodb.client.MongoCollection;
import java.io.BufferedInputStream;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.*;
import java.time.temporal.ChronoField;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.stream.Stream;

import static org.springframework.data.domain.Sort.Direction.ASC;
import static org.springframework.data.domain.Sort.Direction.DESC;
import static org.springframework.data.mongodb.core.aggregation.Aggregation.*;
import static org.springframework.data.mongodb.core.query.Criteria.where;
import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
import static org.springframework.data.mongodb.test.util.Assertions.assertThatIllegalArgumentException;

/**
* Tests for {@link MongoTemplate#aggregate(Aggregation, Class, Class)}.
Expand All @@ -90,6 +76,7 @@
* @author Maninder Singh
* @author Sergey Shcherbakov
* @author Minsu Kim
* @author Sangyong Choi
*/
@ExtendWith(MongoTemplateExtension.class)
public class AggregationTests {
Expand Down Expand Up @@ -1518,6 +1505,53 @@ void shouldLookupPeopleCorectly() {
assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1");
}

@Test
void shouldLookupPeopleCorrectlyWithPipeline() {
createUsersWithReferencedPersons();

TypedAggregation<User> agg = newAggregation(User.class, //
lookup(
"person",
"_id",
"firstname",
"linkedPerson",
List.of(match(where("firstname").is("u1")))), //
sort(ASC, "id"));

AggregationResults<Document> results = mongoTemplate.aggregate(agg, User.class, Document.class);

List<Document> mappedResults = results.getMappedResults();

Document firstItem = mappedResults.get(0);

assertThat(firstItem).containsEntry("_id", "u1");
assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1");
}

@Test
void shouldLookupPeopleCorrectlyWithPipelineAndLet() {
createUsersWithReferencedPersons();

TypedAggregation<User> agg = newAggregation(User.class, //
lookup(
"person",
"_id",
"firstname",
"linkedPerson",
List.of(new LookupOperation.Let.ExpressionVariable("personFirstname", "firstname")),
List.of(match(where("firstname").is("u1")))),
sort(ASC, "id"));

AggregationResults<Document> results = mongoTemplate.aggregate(agg, User.class, Document.class);

List<Document> mappedResults = results.getMappedResults();

Document firstItem = mappedResults.get(0);

assertThat(firstItem).containsEntry("_id", "u1");
assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1");
}

@Test // DATAMONGO-1326
void shouldGroupByAndLookupPeopleCorectly() {

Expand Down