Skip to content

Commit 696e53f

Browse files
christophstroblmp911de
authored andcommitted
DATAMONGO-1538 - Add support for $let to aggregation.
We now support $let in aggregation $project stage. ExpressionVariable total = newExpressionVariable("total").forExpression(ADD.of(field("price"), field("tax"))); ExpressionVariable discounted = newExpressionVariable("discounted").forExpression(Cond.when("applyDiscount").then(0.9D).otherwise(1.0D)); newAggregation(Sales.class, project() .and(define(total, discounted) .andApply(MULTIPLY.of(field("total"), field("discounted")))) .as("finalTotal")); Original pull request: #417.
1 parent f512d8c commit 696e53f

File tree

5 files changed

+388
-49
lines changed

5 files changed

+388
-49
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.ArrayList;
1919
import java.util.Arrays;
20+
import java.util.Collection;
2021
import java.util.Collections;
2122
import java.util.LinkedHashMap;
2223
import java.util.List;
@@ -26,6 +27,7 @@
2627
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond.OtherwiseBuilder;
2728
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond.ThenBuilder;
2829
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Filter.AsBuilder;
30+
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable;
2931
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField;
3032
import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference;
3133
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
@@ -1984,6 +1986,28 @@ public static Map.AsBuilder mapItemsOf(String fieldReference) {
19841986
public static Map.AsBuilder mapItemsOf(AggregationExpression expression) {
19851987
return Map.itemsOf(expression);
19861988
}
1989+
1990+
/**
1991+
* Start creating new {@link Let} that allows definition of {@link ExpressionVariable} that can be used within a
1992+
* nested {@link AggregationExpression}.
1993+
*
1994+
* @param variables must not be {@literal null}.
1995+
* @return
1996+
*/
1997+
public static Let.LetBuilder define(ExpressionVariable... variables) {
1998+
return Let.define(variables);
1999+
}
2000+
2001+
/**
2002+
* Start creating new {@link Let} that allows definition of {@link ExpressionVariable} that can be used within a
2003+
* nested {@link AggregationExpression}.
2004+
*
2005+
* @param variables must not be {@literal null}.
2006+
* @return
2007+
*/
2008+
public static Let.LetBuilder define(Collection<ExpressionVariable> variables) {
2009+
return Let.define(variables);
2010+
}
19872011
}
19882012

19892013
/**
@@ -6694,4 +6718,185 @@ public Cond otherwiseValueOf(AggregationExpression expression) {
66946718
}
66956719
}
66966720
}
6721+
6722+
/**
6723+
* {@link AggregationExpression} for {@code $let} that binds {@link AggregationExpression} to variables for use in the
6724+
* specified {@code in} expression, and returns the result of the expression.
6725+
*
6726+
* @author Christoph Strobl
6727+
* @since 1.10
6728+
*/
6729+
class Let implements AggregationExpression {
6730+
6731+
private final List<ExpressionVariable> vars;
6732+
private final AggregationExpression expression;
6733+
6734+
private Let(List<ExpressionVariable> vars, AggregationExpression expression) {
6735+
6736+
this.vars = vars;
6737+
this.expression = expression;
6738+
}
6739+
6740+
/**
6741+
* Start creating new {@link Let} by defining the variables for {@code $vars}.
6742+
*
6743+
* @param variables must not be {@literal null}.
6744+
* @return
6745+
*/
6746+
public static LetBuilder define(final Collection<ExpressionVariable> variables) {
6747+
6748+
Assert.notNull(variables, "Variables must not be null!");
6749+
6750+
return new LetBuilder() {
6751+
@Override
6752+
public Let andApply(final AggregationExpression expression) {
6753+
6754+
Assert.notNull(expression, "Expression must not be null!");
6755+
return new Let(new ArrayList<ExpressionVariable>(variables), expression);
6756+
}
6757+
};
6758+
}
6759+
6760+
/**
6761+
* Start creating new {@link Let} by defining the variables for {@code $vars}.
6762+
*
6763+
* @param variables must not be {@literal null}.
6764+
* @return
6765+
*/
6766+
public static LetBuilder define(final ExpressionVariable... variables) {
6767+
6768+
Assert.notNull(variables, "Variables must not be null!");
6769+
6770+
return new LetBuilder() {
6771+
@Override
6772+
public Let andApply(final AggregationExpression expression) {
6773+
6774+
Assert.notNull(expression, "Expression must not be null!");
6775+
return new Let(Arrays.asList(variables), expression);
6776+
}
6777+
};
6778+
}
6779+
6780+
public interface LetBuilder {
6781+
6782+
/**
6783+
* Define the {@link AggregationExpression} to evaluate.
6784+
*
6785+
* @param expression must not be {@literal null}.
6786+
* @return
6787+
*/
6788+
Let andApply(AggregationExpression expression);
6789+
}
6790+
6791+
@Override
6792+
public Document toDocument(final AggregationOperationContext context) {
6793+
6794+
return toLet(new ExposedFieldsAggregationOperationContext(
6795+
ExposedFields.synthetic(Fields.fields(getVariableNames())), context) {
6796+
6797+
@Override
6798+
public FieldReference getReference(Field field) {
6799+
6800+
FieldReference ref = null;
6801+
try {
6802+
ref = context.getReference(field);
6803+
} catch (Exception e) {
6804+
// just ignore that one.
6805+
}
6806+
return ref != null ? ref : super.getReference(field);
6807+
}
6808+
});
6809+
}
6810+
6811+
private String[] getVariableNames() {
6812+
6813+
String[] varNames = new String[this.vars.size()];
6814+
for (int i = 0; i < this.vars.size(); i++) {
6815+
varNames[i] = this.vars.get(i).variableName;
6816+
}
6817+
return varNames;
6818+
}
6819+
6820+
private Document toLet(AggregationOperationContext context) {
6821+
6822+
Document letExpression = new Document();
6823+
6824+
Document mappedVars = new Document();
6825+
for (ExpressionVariable var : this.vars) {
6826+
mappedVars.putAll(getMappedVariable(var, context));
6827+
}
6828+
6829+
letExpression.put("vars", mappedVars);
6830+
letExpression.put("in", getMappedIn(context));
6831+
6832+
return new Document("$let", letExpression);
6833+
}
6834+
6835+
private Document getMappedVariable(ExpressionVariable var, AggregationOperationContext context) {
6836+
6837+
return new Document(var.variableName, var.expression instanceof AggregationExpression
6838+
? ((AggregationExpression) var.expression).toDocument(context) : var.expression);
6839+
}
6840+
6841+
private Object getMappedIn(AggregationOperationContext context) {
6842+
return expression.toDocument(new NestedDelegatingExpressionAggregationOperationContext(context));
6843+
}
6844+
6845+
/**
6846+
* @author Christoph Strobl
6847+
*/
6848+
public static class ExpressionVariable {
6849+
6850+
private final String variableName;
6851+
private final Object expression;
6852+
6853+
/**
6854+
* Creates new {@link ExpressionVariable}.
6855+
*
6856+
* @param variableName can be {@literal null}.
6857+
* @param expression can be {@literal null}.
6858+
*/
6859+
private ExpressionVariable(String variableName, Object expression) {
6860+
6861+
this.variableName = variableName;
6862+
this.expression = expression;
6863+
}
6864+
6865+
/**
6866+
* Create a new {@link ExpressionVariable} with given name.
6867+
*
6868+
* @param variableName must not be {@literal null}.
6869+
* @return never {@literal null}.
6870+
*/
6871+
public static ExpressionVariable newVariable(String variableName) {
6872+
6873+
Assert.notNull(variableName, "VariableName must not be null!");
6874+
return new ExpressionVariable(variableName, null);
6875+
}
6876+
6877+
/**
6878+
* Create a new {@link ExpressionVariable} with current name and given {@literal expression}.
6879+
*
6880+
* @param expression must not be {@literal null}.
6881+
* @return never {@literal null}.
6882+
*/
6883+
public ExpressionVariable forExpression(AggregationExpression expression) {
6884+
6885+
Assert.notNull(expression, "Expression must not be null!");
6886+
return new ExpressionVariable(variableName, expression);
6887+
}
6888+
6889+
/**
6890+
* Create a new {@link ExpressionVariable} with current name and given {@literal expressionObject}.
6891+
*
6892+
* @param expressionObject must not be {@literal null}.
6893+
* @return never {@literal null}.
6894+
*/
6895+
public ExpressionVariable forExpression(Document expressionObject) {
6896+
6897+
Assert.notNull(expressionObject, "Expression must not be null!");
6898+
return new ExpressionVariable(variableName, expressionObject);
6899+
}
6900+
}
6901+
}
66976902
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
@Deprecated
3636
public enum AggregationFunctionExpressions {
3737

38-
SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD;
38+
SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD, MULTIPLY;
3939

4040
/**
4141
* Returns an {@link AggregationExpression} build from the current {@link Enum} name and the given parameters.

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
import java.util.ArrayList;
1919
import java.util.Arrays;
20+
import java.util.Collection;
2021
import java.util.Collections;
2122
import java.util.List;
2223

2324
import org.bson.Document;
25+
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable;
2426
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond;
2527
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.IfNull;
2628
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField;
@@ -850,7 +852,7 @@ public ProjectionOperationBuilder differenceToArray(String array) {
850852
/**
851853
* Generates a {@code $setIsSubset} expression that takes array of the previously mentioned field and returns
852854
* {@literal true} if it is a subset of the given {@literal array}.
853-
*
855+
*
854856
* @param array must not be {@literal null}.
855857
* @return never {@literal null}.
856858
* @since 1.10
@@ -1193,7 +1195,35 @@ public ProjectionOperationBuilder dateAsFormattedString(String format) {
11931195
return this.operation.and(AggregationExpressions.DateToString.dateOf(name).toString(format));
11941196
}
11951197

1196-
/*
1198+
/**
1199+
* Generates a {@code $let} expression that binds variables for use in the specified expression, and returns the
1200+
* result of the expression.
1201+
*
1202+
* @param valueExpression The {@link AggregationExpression} bound to {@literal variableName}.
1203+
* @param variableName The variable name to be used in the {@literal in} {@link AggregationExpression}.
1204+
* @param in The {@link AggregationExpression} to evaluate.
1205+
* @return never {@literal null}.
1206+
* @since 1.10
1207+
*/
1208+
public ProjectionOperationBuilder let(AggregationExpression valueExpression, String variableName,
1209+
AggregationExpression in) {
1210+
return this.operation.and(AggregationExpressions.Let.define(ExpressionVariable.newVariable(variableName).forExpression(valueExpression)).andApply(in));
1211+
}
1212+
1213+
/**
1214+
* Generates a {@code $let} expression that binds variables for use in the specified expression, and returns the
1215+
* result of the expression.
1216+
*
1217+
* @param variables The bound {@link ExpressionVariable}s.
1218+
* @param in The {@link AggregationExpression} to evaluate.
1219+
* @return never {@literal null}.
1220+
* @since 1.10
1221+
*/
1222+
public ProjectionOperationBuilder let(Collection<ExpressionVariable> variables, AggregationExpression in) {
1223+
return this.operation.and(AggregationExpressions.Let.define(variables).andApply(in));
1224+
}
1225+
1226+
/*
11971227
* (non-Javadoc)
11981228
* @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDocument(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext)
11991229
*/

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
import org.springframework.data.mongodb.core.Venue;
6161
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Cond;
6262
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.ConditionalOperators;
63+
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let;
64+
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable;
6365
import org.springframework.data.mongodb.core.aggregation.AggregationTests.CarDescriptor.Entry;
6466
import org.springframework.data.mongodb.core.index.GeospatialIndex;
6567
import org.springframework.data.mongodb.core.query.Criteria;
@@ -647,9 +649,9 @@ public void aggregationUsingIfNullProjection() {
647649
mongoTemplate.insert(new LineItem("idonly", null, 0));
648650

649651
TypedAggregation<LineItem> aggregation = newAggregation(LineItem.class, //
650-
project("id") //
651-
.and("caption")//
652-
.applyCondition(ConditionalOperators.ifNull("caption").then("unknown")),
652+
project("id") //
653+
.and("caption")//
654+
.applyCondition(ConditionalOperators.ifNull("caption").then("unknown")),
653655
sort(ASC, "id"));
654656

655657
assertThat(aggregation.toString(), is(notNullValue()));
@@ -1541,6 +1543,36 @@ public void filterShouldBeAppliedCorrectly() {
15411543
Sales.builder().id("2").items(Collections.<Item> emptyList()).build()));
15421544
}
15431545

1546+
/**
1547+
* @see DATAMONGO-1538
1548+
*/
1549+
@Test
1550+
public void letShouldBeAppliedCorrectly() {
1551+
1552+
assumeTrue(mongoVersion.isGreaterThanOrEqualTo(THREE_DOT_TWO));
1553+
1554+
Sales2 sales1 = Sales2.builder().id("1").price(10).tax(0.5F).applyDiscount(true).build();
1555+
Sales2 sales2 = Sales2.builder().id("2").price(10).tax(0.25F).applyDiscount(false).build();
1556+
1557+
mongoTemplate.insert(Arrays.asList(sales1, sales2), Sales2.class);
1558+
1559+
ExpressionVariable total = ExpressionVariable.newVariable("total")
1560+
.forExpression(AggregationFunctionExpressions.ADD.of(Fields.field("price"), Fields.field("tax")));
1561+
ExpressionVariable discounted = ExpressionVariable.newVariable("discounted")
1562+
.forExpression(Cond.when("applyDiscount").then(0.9D).otherwise(1.0D));
1563+
1564+
TypedAggregation<Sales2> agg = Aggregation.newAggregation(Sales2.class,
1565+
Aggregation.project()
1566+
.and(Let.define(total, discounted).andApply(
1567+
AggregationFunctionExpressions.MULTIPLY.of(Fields.field("total"), Fields.field("discounted"))))
1568+
.as("finalTotal"));
1569+
1570+
AggregationResults<Document> result = mongoTemplate.aggregate(agg, Document.class);
1571+
assertThat(result.getMappedResults(),
1572+
contains(new Document("_id", "1").append("finalTotal", 9.450000000000001D),
1573+
new Document("_id", "2").append("finalTotal", 10.25D)));
1574+
}
1575+
15441576
private void createUsersWithReferencedPersons() {
15451577

15461578
mongoTemplate.dropCollection(User.class);
@@ -1782,6 +1814,9 @@ public InventoryItem(int id, String item, String description, int qty) {
17821814
}
17831815
}
17841816

1817+
/**
1818+
* @DATAMONGO-1491
1819+
*/
17851820
@lombok.Data
17861821
@Builder
17871822
static class Sales {
@@ -1790,6 +1825,9 @@ static class Sales {
17901825
List<Item> items;
17911826
}
17921827

1828+
/**
1829+
* @DATAMONGO-1491
1830+
*/
17931831
@lombok.Data
17941832
@Builder
17951833
static class Item {
@@ -1799,4 +1837,17 @@ static class Item {
17991837
Integer quantity;
18001838
Long price;
18011839
}
1840+
1841+
/**
1842+
* @DATAMONGO-1538
1843+
*/
1844+
@lombok.Data
1845+
@Builder
1846+
static class Sales2 {
1847+
1848+
String id;
1849+
Integer price;
1850+
Float tax;
1851+
boolean applyDiscount;
1852+
}
18021853
}

0 commit comments

Comments
 (0)