Skip to content

Commit 8d4fa59

Browse files
DATAMONGO-1538 - Add support for $let to aggregation.
We now support $let in aggregation $project stage. ExpressionVariable total = newExpressionVariable("total").forExpression(ADD.of(field("price"), field("tax"))); ExpressionVariable discounted = newExpressionVariable("discounted).forExpression(ConditionalOperator.newBuilder().when(field("applyDiscount")).then(0.9D).otherwise(1.0D)); newAggregation(Sales.class, project() .and(define(total, discounted) .andApply(MULTIPLY.of(field("total"), field("discounted")))) .as("finalTotal"));
1 parent 5d45936 commit 8d4fa59

File tree

5 files changed

+396
-9
lines changed

5 files changed

+396
-9
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java

Lines changed: 187 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.ArrayList;
1919
import java.util.Arrays;
20+
import java.util.Collection;
2021
import java.util.Collections;
2122
import java.util.LinkedHashMap;
2223
import java.util.List;
@@ -2344,6 +2345,7 @@ public static Abs absoluteValueOf(AggregationExpression expression) {
23442345
Assert.notNull(expression, "Expression must not be null!");
23452346
return new Abs(expression);
23462347
}
2348+
23472349
/**
23482350
* Creates new {@link Abs}.
23492351
*
@@ -2495,7 +2497,6 @@ protected String getMongoMethod() {
24952497
return "$divide";
24962498
}
24972499

2498-
24992500
/**
25002501
* Creates new {@link Divide}.
25012502
*
@@ -4390,7 +4391,7 @@ protected String getMongoMethod() {
43904391

43914392
/**
43924393
* Creates new {@link Second}.
4393-
*
4394+
*
43944395
* @param fieldReference must not be {@literal null}.
43954396
* @return
43964397
*/
@@ -4705,7 +4706,7 @@ protected String getMongoMethod() {
47054706

47064707
/**
47074708
* Creates new {@link Max}.
4708-
*
4709+
*
47094710
* @param fieldReference must not be {@literal null}.
47104711
* @return
47114712
*/
@@ -4730,7 +4731,7 @@ public static Max maxOf(AggregationExpression expression) {
47304731
/**
47314732
* Creates new {@link Max} with all previously added arguments appending the given one. <br />
47324733
* <strong>NOTE:</strong> Only possible in {@code $project} stage.
4733-
*
4734+
*
47344735
* @param fieldReference must not be {@literal null}.
47354736
* @return
47364737
*/
@@ -4809,7 +4810,7 @@ public static Min minOf(AggregationExpression expression) {
48094810
/**
48104811
* Creates new {@link Min} with all previously added arguments appending the given one. <br />
48114812
* <strong>NOTE:</strong> Only possible in {@code $project} stage.
4812-
*
4813+
*
48134814
* @param fieldReference must not be {@literal null}.
48144815
* @return
48154816
*/
@@ -4942,7 +4943,7 @@ protected String getMongoMethod() {
49424943

49434944
/**
49444945
* Creates new {@link StdDevSamp}.
4945-
*
4946+
*
49464947
* @param fieldReference must not be {@literal null}.
49474948
* @return
49484949
*/
@@ -5715,4 +5716,184 @@ public static Not not(AggregationExpression expression) {
57155716
}
57165717
}
57175718

5719+
/**
5720+
* {@link AggregationExpression} for {@code $let} that binds {@link AggregationExpression} to variables for use in the
5721+
* specified {@code in} expression, and returns the result of the expression.
5722+
*
5723+
* @author Christoph Strobl
5724+
* @since 1.10
5725+
*/
5726+
class Let implements AggregationExpression {
5727+
5728+
private final List<ExpressionVariable> vars;
5729+
private final AggregationExpression expression;
5730+
5731+
private Let(List<ExpressionVariable> vars, AggregationExpression expression) {
5732+
5733+
this.vars = vars;
5734+
this.expression = expression;
5735+
}
5736+
5737+
/**
5738+
* Start creating new {@link Let} by defining the variables for {@code $vars}.
5739+
*
5740+
* @param variables must not be {@literal null}.
5741+
* @return
5742+
*/
5743+
public static LetBuilder define(final Collection<ExpressionVariable> variables) {
5744+
5745+
Assert.notNull(variables, "Variables must not be null!");
5746+
5747+
return new LetBuilder() {
5748+
@Override
5749+
public Let andApply(final AggregationExpression expression) {
5750+
5751+
Assert.notNull(expression, "Expression must not be null!");
5752+
return new Let(new ArrayList<ExpressionVariable>(variables), expression);
5753+
}
5754+
};
5755+
}
5756+
5757+
/**
5758+
* Start creating new {@link Let} by defining the variables for {@code $vars}.
5759+
*
5760+
* @param variables must not be {@literal null}.
5761+
* @return
5762+
*/
5763+
public static LetBuilder define(final ExpressionVariable... variables) {
5764+
5765+
Assert.notNull(variables, "Variables must not be null!");
5766+
5767+
return new LetBuilder() {
5768+
@Override
5769+
public Let andApply(final AggregationExpression expression) {
5770+
5771+
Assert.notNull(expression, "Expression must not be null!");
5772+
return new Let(Arrays.asList(variables), expression);
5773+
}
5774+
};
5775+
}
5776+
5777+
public interface LetBuilder {
5778+
5779+
/**
5780+
* Define the {@link AggregationExpression} to evaluate.
5781+
*
5782+
* @param expression must not be {@literal null}.
5783+
* @return
5784+
*/
5785+
Let andApply(AggregationExpression expression);
5786+
}
5787+
5788+
@Override
5789+
public DBObject toDbObject(final AggregationOperationContext context) {
5790+
5791+
return toLet(new ExposedFieldsAggregationOperationContext(
5792+
ExposedFields.synthetic(Fields.fields(getVariableNames())), context) {
5793+
5794+
@Override
5795+
public FieldReference getReference(Field field) {
5796+
5797+
FieldReference ref = null;
5798+
try {
5799+
ref = context.getReference(field);
5800+
} catch (Exception e) {
5801+
// just ignore that one.
5802+
}
5803+
return ref != null ? ref : super.getReference(field);
5804+
}
5805+
});
5806+
}
5807+
5808+
private String[] getVariableNames() {
5809+
5810+
String[] varNames = new String[this.vars.size()];
5811+
for (int i = 0; i < this.vars.size(); i++) {
5812+
varNames[i] = this.vars.get(i).variableName;
5813+
}
5814+
return varNames;
5815+
}
5816+
5817+
private DBObject toLet(AggregationOperationContext context) {
5818+
5819+
DBObject letExpression = new BasicDBObject();
5820+
5821+
DBObject mappedVars = new BasicDBObject();
5822+
for (ExpressionVariable var : this.vars) {
5823+
mappedVars.putAll(getMappedVariable(var, context));
5824+
}
5825+
5826+
letExpression.put("vars", mappedVars);
5827+
letExpression.put("in", getMappedIn(context));
5828+
5829+
return new BasicDBObject("$let", letExpression);
5830+
}
5831+
5832+
private DBObject getMappedVariable(ExpressionVariable var, AggregationOperationContext context) {
5833+
5834+
return new BasicDBObject(var.variableName, var.expression instanceof AggregationExpression
5835+
? ((AggregationExpression) var.expression).toDbObject(context) : var.expression);
5836+
}
5837+
5838+
private Object getMappedIn(AggregationOperationContext context) {
5839+
return expression.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(context));
5840+
}
5841+
5842+
/**
5843+
* @author Christoph Strobl
5844+
*/
5845+
public static class ExpressionVariable {
5846+
5847+
private final String variableName;
5848+
private final Object expression;
5849+
5850+
/**
5851+
* Creates new {@link ExpressionVariable}.
5852+
*
5853+
* @param variableName can be {@literal null}.
5854+
* @param expression can be {@literal null}.
5855+
*/
5856+
private ExpressionVariable(String variableName, Object expression) {
5857+
5858+
this.variableName = variableName;
5859+
this.expression = expression;
5860+
}
5861+
5862+
/**
5863+
* Create a new {@link ExpressionVariable} with given name.
5864+
*
5865+
* @param variableName must not be {@literal null}.
5866+
* @return never {@literal null}.
5867+
*/
5868+
public static ExpressionVariable newVariable(String variableName) {
5869+
5870+
Assert.notNull(variableName, "VariableName must not be null!");
5871+
return new ExpressionVariable(variableName, null);
5872+
}
5873+
5874+
/**
5875+
* Create a new {@link ExpressionVariable} with current name and given {@literal expression}.
5876+
*
5877+
* @param expression must not be {@literal null}.
5878+
* @return never {@literal null}.
5879+
*/
5880+
public ExpressionVariable forExpression(AggregationExpression expression) {
5881+
5882+
Assert.notNull(expression, "Expression must not be null!");
5883+
return new ExpressionVariable(variableName, expression);
5884+
}
5885+
5886+
/**
5887+
* Create a new {@link ExpressionVariable} with current name and given {@literal expressionObject}.
5888+
*
5889+
* @param expressionObject must not be {@literal null}.
5890+
* @return never {@literal null}.
5891+
*/
5892+
public ExpressionVariable forExpression(DBObject expressionObject) {
5893+
5894+
Assert.notNull(expressionObject, "Expression must not be null!");
5895+
return new ExpressionVariable(variableName, expressionObject);
5896+
}
5897+
}
5898+
}
57185899
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
@Deprecated
3737
public enum AggregationFunctionExpressions {
3838

39-
SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT;
39+
SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD, MULTIPLY;
4040

4141
/**
4242
* Returns an {@link AggregationExpression} build from the current {@link Enum} name and the given parameters.

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
import java.util.ArrayList;
1919
import java.util.Arrays;
20+
import java.util.Collection;
2021
import java.util.Collections;
2122
import java.util.List;
2223

24+
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Let.ExpressionVariable;
2325
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField;
2426
import org.springframework.data.mongodb.core.aggregation.Fields.AggregationField;
2527
import org.springframework.data.mongodb.core.aggregation.ProjectionOperation.ProjectionOperationBuilder.FieldProjection;
@@ -850,7 +852,7 @@ public ProjectionOperationBuilder differenceToArray(String array) {
850852
/**
851853
* Generates a {@code $setIsSubset} expression that takes array of the previously mentioned field and returns
852854
* {@literal true} if it is a subset of the given {@literal array}.
853-
*
855+
*
854856
* @param array must not be {@literal null}.
855857
* @return never {@literal null}.
856858
* @since 1.10
@@ -1193,7 +1195,35 @@ public ProjectionOperationBuilder dateAsFormattedString(String format) {
11931195
return this.operation.and(AggregationExpressions.DateToString.dateOf(name).toString(format));
11941196
}
11951197

1196-
/*
1198+
/**
1199+
* Generates a {@code $let} expression that binds variables for use in the specified expression, and returns the
1200+
* result of the expression.
1201+
*
1202+
* @param valueExpression The {@link AggregationExpression} bound to {@literal variableName}.
1203+
* @param variableName The variable name to be used in the {@literal in} {@link AggregationExpression}.
1204+
* @param in The {@link AggregationExpression} to evaluate.
1205+
* @return never {@literal null}.
1206+
* @since 1.10
1207+
*/
1208+
public ProjectionOperationBuilder let(AggregationExpression valueExpression, String variableName,
1209+
AggregationExpression in) {
1210+
return this.operation.and(AggregationExpressions.Let.define(ExpressionVariable.newVariable(variableName).forExpression(valueExpression)).andApply(in));
1211+
}
1212+
1213+
/**
1214+
* Generates a {@code $let} expression that binds variables for use in the specified expression, and returns the
1215+
* result of the expression.
1216+
*
1217+
* @param variables The bound {@link ExpressionVariable}s.
1218+
* @param in The {@link AggregationExpression} to evaluate.
1219+
* @return never {@literal null}.
1220+
* @since 1.10
1221+
*/
1222+
public ProjectionOperationBuilder let(Collection<ExpressionVariable> variables, AggregationExpression in) {
1223+
return this.operation.and(AggregationExpressions.Let.define(variables).andApply(in));
1224+
}
1225+
1226+
/*
11971227
* (non-Javadoc)
11981228
* @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDBObject(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext)
11991229
*/

0 commit comments

Comments
 (0)