Skip to content

Commit 42edbd9

Browse files
DATAMONGO-1540 - Add support for $map (aggregation).
We now support $map operator in aggregation.
1 parent cc64dcc commit 42edbd9

File tree

3 files changed

+215
-14
lines changed

3 files changed

+215
-14
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java

Lines changed: 183 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
import java.util.Collections;
2121
import java.util.LinkedHashMap;
2222
import java.util.List;
23-
import java.util.Map;
2423

2524
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Filter.AsBuilder;
25+
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Map.ArrayOfBuilder;
2626
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField;
2727
import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference;
2828
import org.springframework.util.Assert;
@@ -1781,6 +1781,24 @@ private boolean usesFieldRef() {
17811781
}
17821782
}
17831783

1784+
/**
1785+
* Gateway to {@literal Date} aggregation operations.
1786+
*
1787+
* @author Christoph Strobl
1788+
*/
1789+
class VariableOperators {
1790+
1791+
/**
1792+
* Starts building new {@link Map} that applies an {@link AggregationExpression} to each item of a referenced array
1793+
* and returns an array with the applied results.
1794+
*
1795+
* @return
1796+
*/
1797+
public static ArrayOfBuilder map() {
1798+
return Map.map();
1799+
}
1800+
}
1801+
17841802
/**
17851803
* @author Christoph Strobl
17861804
*/
@@ -1809,10 +1827,10 @@ public DBObject toDbObject(Object value, AggregationOperationContext context) {
18091827
args.add(unpack(val, context));
18101828
}
18111829
valueToUse = args;
1812-
} else if (value instanceof Map) {
1830+
} else if (value instanceof java.util.Map) {
18131831

18141832
DBObject dbo = new BasicDBObject();
1815-
for (Map.Entry<String, Object> entry : ((Map<String, Object>) value).entrySet()) {
1833+
for (java.util.Map.Entry<String, Object> entry : ((java.util.Map<String, Object>) value).entrySet()) {
18161834
dbo.put(entry.getKey(), unpack(entry.getValue(), context));
18171835
}
18181836
valueToUse = dbo;
@@ -1866,10 +1884,10 @@ protected List<Object> append(Object value) {
18661884

18671885
protected Object append(String key, Object value) {
18681886

1869-
if (!(value instanceof Map)) {
1887+
if (!(value instanceof java.util.Map)) {
18701888
throw new IllegalArgumentException("o_O");
18711889
}
1872-
Map<String, Object> clone = new LinkedHashMap<String, Object>((Map<String, Object>) value);
1890+
java.util.Map<String, Object> clone = new LinkedHashMap<String, Object>((java.util.Map<String, Object>) value);
18731891
clone.put(key, value);
18741892
return clone;
18751893

@@ -2344,6 +2362,7 @@ public static Abs absoluteValueOf(AggregationExpression expression) {
23442362
Assert.notNull(expression, "Expression must not be null!");
23452363
return new Abs(expression);
23462364
}
2365+
23472366
/**
23482367
* Creates new {@link Abs}.
23492368
*
@@ -2495,7 +2514,6 @@ protected String getMongoMethod() {
24952514
return "$divide";
24962515
}
24972516

2498-
24992517
/**
25002518
* Creates new {@link Divide}.
25012519
*
@@ -4390,7 +4408,7 @@ protected String getMongoMethod() {
43904408

43914409
/**
43924410
* Creates new {@link Second}.
4393-
*
4411+
*
43944412
* @param fieldReference must not be {@literal null}.
43954413
* @return
43964414
*/
@@ -4509,9 +4527,9 @@ public DateToString toString(String format) {
45094527
};
45104528
}
45114529

4512-
private static Map<String, Object> argumentMap(Object date, String format) {
4530+
private static java.util.Map<String, Object> argumentMap(Object date, String format) {
45134531

4514-
Map<String, Object> args = new LinkedHashMap<String, Object>(2);
4532+
java.util.Map<String, Object> args = new LinkedHashMap<String, Object>(2);
45154533
args.put("format", format);
45164534
args.put("date", date);
45174535
return args;
@@ -4705,7 +4723,7 @@ protected String getMongoMethod() {
47054723

47064724
/**
47074725
* Creates new {@link Max}.
4708-
*
4726+
*
47094727
* @param fieldReference must not be {@literal null}.
47104728
* @return
47114729
*/
@@ -4730,7 +4748,7 @@ public static Max maxOf(AggregationExpression expression) {
47304748
/**
47314749
* Creates new {@link Max} with all previously added arguments appending the given one. <br />
47324750
* <strong>NOTE:</strong> Only possible in {@code $project} stage.
4733-
*
4751+
*
47344752
* @param fieldReference must not be {@literal null}.
47354753
* @return
47364754
*/
@@ -4809,7 +4827,7 @@ public static Min minOf(AggregationExpression expression) {
48094827
/**
48104828
* Creates new {@link Min} with all previously added arguments appending the given one. <br />
48114829
* <strong>NOTE:</strong> Only possible in {@code $project} stage.
4812-
*
4830+
*
48134831
* @param fieldReference must not be {@literal null}.
48144832
* @return
48154833
*/
@@ -4942,7 +4960,7 @@ protected String getMongoMethod() {
49424960

49434961
/**
49444962
* Creates new {@link StdDevSamp}.
4945-
*
4963+
*
49464964
* @param fieldReference must not be {@literal null}.
49474965
* @return
49484966
*/
@@ -5715,4 +5733,156 @@ public static Not not(AggregationExpression expression) {
57155733
}
57165734
}
57175735

5736+
/**
5737+
* {@link AggregationExpression} for {@code $map}.
5738+
*/
5739+
class Map implements AggregationExpression {
5740+
5741+
private Object sourceArray;
5742+
private String itemVariableName;
5743+
private AggregationExpression functionToApply;
5744+
5745+
private Map(Object sourceArray, String itemVariableName, AggregationExpression functionToApply) {
5746+
5747+
Assert.notNull(sourceArray, "SourceArray must not be null!");
5748+
Assert.notNull(itemVariableName, "ItemVariableName must not be null!");
5749+
Assert.notNull(functionToApply, "FunctionToApply must not be null!");
5750+
5751+
this.sourceArray = sourceArray;
5752+
this.itemVariableName = itemVariableName;
5753+
this.functionToApply = functionToApply;
5754+
}
5755+
5756+
/**
5757+
* Starts building new {@link Map} that applies an {@link AggregationExpression} to each item of a referenced array
5758+
* and returns an array with the applied results.
5759+
*
5760+
* @return
5761+
*/
5762+
static ArrayOfBuilder map() {
5763+
5764+
return new ArrayOfBuilder() {
5765+
5766+
@Override
5767+
public AsBuilder itemsOf(final String fieldReference) {
5768+
5769+
return new AsBuilder() {
5770+
5771+
@Override
5772+
public FunctionBuilder as(final String variableName) {
5773+
5774+
return new FunctionBuilder() {
5775+
5776+
@Override
5777+
public Map andApply(final AggregationExpression expression) {
5778+
return new Map(Fields.field(fieldReference), variableName, expression);
5779+
}
5780+
};
5781+
}
5782+
};
5783+
}
5784+
5785+
@Override
5786+
public AsBuilder itemsOf(final AggregationExpression source) {
5787+
5788+
return new AsBuilder() {
5789+
5790+
@Override
5791+
public FunctionBuilder as(final String variableName) {
5792+
5793+
return new FunctionBuilder() {
5794+
5795+
@Override
5796+
public Map andApply(final AggregationExpression expression) {
5797+
return new Map(source, variableName, expression);
5798+
}
5799+
};
5800+
}
5801+
};
5802+
}
5803+
};
5804+
};
5805+
5806+
@Override
5807+
public DBObject toDbObject(final AggregationOperationContext context) {
5808+
5809+
return toMap(new ExposedFieldsAggregationOperationContext(
5810+
ExposedFields.synthetic(Fields.fields(itemVariableName)), context) {
5811+
5812+
@Override
5813+
public FieldReference getReference(Field field) {
5814+
5815+
FieldReference ref = null;
5816+
try {
5817+
ref = context.getReference(field);
5818+
} catch (Exception e) {
5819+
// just ignore that one.
5820+
}
5821+
return ref != null ? ref : super.getReference(field);
5822+
}
5823+
});
5824+
}
5825+
5826+
private DBObject toMap(AggregationOperationContext context) {
5827+
5828+
BasicDBObject map = new BasicDBObject();
5829+
5830+
BasicDBObject input;
5831+
if (sourceArray instanceof Field) {
5832+
input = new BasicDBObject("input", context.getReference((Field) sourceArray).toString());
5833+
} else {
5834+
input = new BasicDBObject("input", ((AggregationExpression) sourceArray).toDbObject(context));
5835+
}
5836+
5837+
map.putAll(context.getMappedObject(input));
5838+
map.put("as", itemVariableName);
5839+
map.put("in", functionToApply.toDbObject(new NestedDelegatingExpressionAggregationOperationContext(context)));
5840+
5841+
return new BasicDBObject("$map", map);
5842+
}
5843+
5844+
interface ArrayOfBuilder {
5845+
5846+
/**
5847+
* Set the field that resolves to an array on which to apply the {@link AggregationExpression}.
5848+
*
5849+
* @param fieldReference must not be {@literal null}.
5850+
* @return
5851+
*/
5852+
AsBuilder itemsOf(String fieldReference);
5853+
5854+
/**
5855+
* Set the {@link AggregationExpression} that results in an array on which to apply the
5856+
* {@link AggregationExpression}.
5857+
*
5858+
* @param expression must not be {@literal null}.
5859+
* @return
5860+
*/
5861+
AsBuilder itemsOf(AggregationExpression expression);
5862+
}
5863+
5864+
interface AsBuilder {
5865+
5866+
/**
5867+
* Define the {@literal variableName} for addressing items within the array.
5868+
*
5869+
* @param variableName must not be {@literal null}.
5870+
* @return
5871+
*/
5872+
FunctionBuilder as(String variableName);
5873+
}
5874+
5875+
interface FunctionBuilder {
5876+
5877+
/**
5878+
* Creates new {@link Map} that applies the given {@link AggregationExpression} to each item of the referenced
5879+
* array and returns an array with the applied results.
5880+
*
5881+
* @param expression must not be {@literal null}.
5882+
* @return
5883+
*/
5884+
Map andApply(AggregationExpression expression);
5885+
}
5886+
}
5887+
57185888
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
@Deprecated
3737
public enum AggregationFunctionExpressions {
3838

39-
SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT;
39+
SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD;
4040

4141
/**
4242
* Returns an {@link AggregationExpression} build from the current {@link Enum} name and the given parameters.

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.LiteralOperators;
3737
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.SetOperators;
3838
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.StringOperators;
39+
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.VariableOperators;
3940
import org.springframework.data.mongodb.core.aggregation.ProjectionOperation.ProjectionOperationBuilder;
4041

4142
import com.mongodb.BasicDBObject;
@@ -1673,6 +1674,36 @@ public void shouldRenderNotAggregationExpression() {
16731674
assertThat(agg, is(JSON.parse("{ $project: { result: { $not: [ { $gt: [ \"$qty\", 250 ] } ] } } }")));
16741675
}
16751676

1677+
/**
1678+
* @see DATAMONGO-784
1679+
*/
1680+
@Test
1681+
public void shouldRenderMapAggregationExpression() {
1682+
1683+
DBObject agg = Aggregation.project()
1684+
.and(VariableOperators.map().itemsOf("quizzes").as("grade")
1685+
.andApply(AggregationFunctionExpressions.ADD.of(field("grade"), 2)))
1686+
.as("adjustedGrades").toDBObject(Aggregation.DEFAULT_CONTEXT);
1687+
1688+
assertThat(agg, is(JSON.parse(
1689+
"{ $project:{ adjustedGrades:{ $map: { input: \"$quizzes\", as: \"grade\",in: { $add: [ \"$$grade\", 2 ] }}}}}")));
1690+
}
1691+
1692+
/**
1693+
* @see DATAMONGO-784
1694+
*/
1695+
@Test
1696+
public void shouldRenderMapAggregationExpressionOnExpression() {
1697+
1698+
DBObject agg = Aggregation.project()
1699+
.and(VariableOperators.map().itemsOf(AggregationFunctionExpressions.SIZE.of("foo")).as("grade")
1700+
.andApply(AggregationFunctionExpressions.ADD.of(field("grade"), 2)))
1701+
.as("adjustedGrades").toDBObject(Aggregation.DEFAULT_CONTEXT);
1702+
1703+
assertThat(agg, is(JSON.parse(
1704+
"{ $project:{ adjustedGrades:{ $map: { input: { $size : [\"foo\"]}, as: \"grade\",in: { $add: [ \"$$grade\", 2 ] }}}}}")));
1705+
}
1706+
16761707
private static DBObject exctractOperation(String field, DBObject fromProjectClause) {
16771708
return (DBObject) fromProjectClause.get(field);
16781709
}

0 commit comments

Comments
 (0)