Skip to content

Use relaxed type mapping for aggregations by default. #3545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>3.2.0-SNAPSHOT</version>
<version>3.2.0-GH-3542-SNAPSHOT</version>
<packaging>pom</packaging>

<name>Spring Data MongoDB</name>
Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-benchmarks/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>3.2.0-SNAPSHOT</version>
<version>3.2.0-GH-3542-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-distribution/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>3.2.0-SNAPSHOT</version>
<version>3.2.0-GH-3542-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>3.2.0-SNAPSHOT</version>
<version>3.2.0-GH-3542-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
import org.springframework.data.mongodb.core.aggregation.AggregationOptions.DomainTypeMapping;
import org.springframework.data.mongodb.core.aggregation.CountOperation;
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
Expand All @@ -36,6 +37,7 @@
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.util.Lazy;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
Expand All @@ -52,41 +54,44 @@ class AggregationUtil {

QueryMapper queryMapper;
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;
Lazy<AggregationOperationContext> untypedMappingContext;

AggregationUtil(QueryMapper queryMapper,
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext) {

this.queryMapper = queryMapper;
this.mappingContext = mappingContext;
this.untypedMappingContext = Lazy
.of(() -> new RelaxedTypeBasedAggregationOperationContext(Object.class, mappingContext, queryMapper));
}

/**
* Prepare the {@link AggregationOperationContext} for a given aggregation by either returning the context itself it
* is not {@literal null}, create a {@link TypeBasedAggregationOperationContext} if the aggregation contains type
* information (is a {@link TypedAggregation}) or use the {@link Aggregation#DEFAULT_CONTEXT}.
*
* @param aggregation must not be {@literal null}.
* @param context can be {@literal null}.
* @return the root {@link AggregationOperationContext} to use.
*/
AggregationOperationContext prepareAggregationContext(Aggregation aggregation,
@Nullable AggregationOperationContext context) {
AggregationOperationContext createAggregationContext(Aggregation aggregation, @Nullable Class<?> inputType) {

if (context != null) {
return context;
if (aggregation.getOptions().getDomainTypeMapping() == DomainTypeMapping.NONE) {
return Aggregation.DEFAULT_CONTEXT;
}

if (!(aggregation instanceof TypedAggregation)) {
return new RelaxedTypeBasedAggregationOperationContext(Object.class, mappingContext, queryMapper);
}

Class<?> inputType = ((TypedAggregation) aggregation).getInputType();
if(inputType == null) {
return untypedMappingContext.get();
}

if (aggregation.getOptions().getDomainTypeMapping() == DomainTypeMapping.STRICT
&& !aggregation.getPipeline().containsUnionWith()) {
return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
}

if (aggregation.getPipeline().containsUnionWith()) {
return new RelaxedTypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
}

return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
inputType = ((TypedAggregation) aggregation).getInputType();
if (aggregation.getOptions().getDomainTypeMapping() == DomainTypeMapping.STRICT
&& !aggregation.getPipeline().containsUnionWith()) {
return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
}

return new RelaxedTypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.springframework.data.mongodb.core.BulkOperations.BulkMode;
import org.springframework.data.mongodb.core.DefaultBulkOperations.BulkOperationContext;
import org.springframework.data.mongodb.core.EntityOperations.AdaptibleEntity;
import org.springframework.data.mongodb.core.QueryOperations.AggregateContext;
import org.springframework.data.mongodb.core.QueryOperations.CountContext;
import org.springframework.data.mongodb.core.QueryOperations.DeleteContext;
import org.springframework.data.mongodb.core.QueryOperations.DistinctQueryContext;
Expand Down Expand Up @@ -1988,7 +1989,7 @@ public <O> AggregationResults<O> aggregate(TypedAggregation<?> aggregation, Stri
public <O> AggregationResults<O> aggregate(Aggregation aggregation, Class<?> inputType, Class<O> outputType) {

return aggregate(aggregation, getCollectionName(inputType), outputType,
new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper));
queryOperations.createAggregationContext(aggregation, inputType).getAggregationOperationContext());
}

/* (non-Javadoc)
Expand Down Expand Up @@ -2095,9 +2096,12 @@ protected <O> AggregationResults<O> aggregate(Aggregation aggregation, String co
Assert.notNull(aggregation, "Aggregation pipeline must not be null!");
Assert.notNull(outputType, "Output type must not be null!");

AggregationOperationContext contextToUse = new AggregationUtil(queryMapper, mappingContext)
.prepareAggregationContext(aggregation, context);
return doAggregate(aggregation, collectionName, outputType, contextToUse);
return doAggregate(aggregation, collectionName, outputType, queryOperations.createAggregationContext(aggregation, context));
}

private <O> AggregationResults<O> doAggregate(Aggregation aggregation, String collectionName, Class<O> outputType,
AggregateContext context) {
return doAggregate(aggregation, collectionName, outputType, context.getAggregationOperationContext());
}

@SuppressWarnings("ConstantConditions")
Expand Down Expand Up @@ -2185,11 +2189,10 @@ protected <O> CloseableIterator<O> aggregateStream(Aggregation aggregation, Stri
Assert.notNull(outputType, "Output type must not be null!");
Assert.isTrue(!aggregation.getOptions().isExplain(), "Can't use explain option with streaming!");

AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext);
AggregationOperationContext rootContext = aggregationUtil.prepareAggregationContext(aggregation, context);
AggregateContext aggregateContext = queryOperations.createAggregationContext(aggregation, context);

AggregationOptions options = aggregation.getOptions();
List<Document> pipeline = aggregationUtil.createPipeline(aggregation, rootContext);
List<Document> pipeline = aggregateContext.getAggregationPipeline();

if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Streaming aggregation: {} in collection {}", serializeToJsonSafely(pipeline), collectionName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@
import org.springframework.data.mongodb.core.MappedDocument.MappedUpdate;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.convert.UpdateMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
Expand All @@ -48,6 +52,7 @@
import org.springframework.data.mongodb.core.query.UpdateDefinition.ArrayFilter;
import org.springframework.data.mongodb.util.BsonUtils;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.util.Lazy;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
Expand Down Expand Up @@ -194,6 +199,31 @@ DeleteContext deleteSingleContext(Query query) {
return new DeleteContext(query, false);
}

/**
* Create a new {@link AggregateContext} for the given {@link Aggregation}.
*
* @param aggregation must not be {@literal null}.
* @param inputType fallback mapping type in case of untyped aggregation. Can be {@literal null}.
* @return new instance of {@link AggregateContext}.
* @since 3.2
*/
AggregateContext createAggregationContext(Aggregation aggregation, @Nullable Class<?> inputType) {
return new AggregateContext(aggregation, inputType);
}

/**
* Create a new {@link AggregateContext} for the given {@link Aggregation}.
*
* @param aggregation must not be {@literal null}.
* @param aggregationOperationContext the {@link AggregationOperationContext} to use. Can be {@literal null}.
* @return new instance of {@link AggregateContext}.
* @since 3.2
*/
AggregateContext createAggregationContext(Aggregation aggregation,
@Nullable AggregationOperationContext aggregationOperationContext) {
return new AggregateContext(aggregation, aggregationOperationContext);
}

/**
* {@link QueryContext} encapsulates common tasks required to convert a {@link Query} into its MongoDB document
* representation, mapping fieldnames, as well as determinging and applying {@link Collation collations}.
Expand Down Expand Up @@ -341,7 +371,8 @@ private DistinctQueryContext(@Nullable Object query, String fieldName) {
}

@Override
Document getMappedFields(@Nullable MongoPersistentEntity<?> entity, Class<?> targetType, ProjectionFactory projectionFactory) {
Document getMappedFields(@Nullable MongoPersistentEntity<?> entity, Class<?> targetType,
ProjectionFactory projectionFactory) {
return getMappedFields(entity);
}

Expand Down Expand Up @@ -709,7 +740,8 @@ List<Document> getUpdatePipeline(@Nullable Class<?> domainType) {

Class<?> type = domainType != null ? domainType : Object.class;

AggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext(type, mappingContext, queryMapper);
AggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext(type, mappingContext,
queryMapper);
return aggregationUtil.createPipeline((AggregationUpdate) update, context);
}

Expand Down Expand Up @@ -759,4 +791,105 @@ boolean isMulti() {
return multi;
}
}

/**
* A context class that encapsulates common tasks required when running {@literal aggregations}.
*
* @since 3.2
*/
class AggregateContext {

private Aggregation aggregation;
private Lazy<AggregationOperationContext> aggregationOperationContext;
private Lazy<List<Document>> pipeline;
private @Nullable Class<?> inputType;

/**
* Creates new instance of {@link AggregateContext} extracting the input type from either the
* {@link org.springframework.data.mongodb.core.aggregation.Aggregation} in case of a {@link TypedAggregation} or
* the given {@literal aggregationOperationContext} if present. <br />
* Creates a new {@link AggregationOperationContext} if none given, based on the {@link Aggregation} input type and
* the desired {@link AggregationOptions#getDomainTypeMapping() domain type mapping}. <br />
* Pipelines are mapped on first access of {@link #getAggregationPipeline()} and cached for reuse.
*
* @param aggregation the source aggregation.
* @param aggregationOperationContext can be {@literal null}.
*/
AggregateContext(Aggregation aggregation, @Nullable AggregationOperationContext aggregationOperationContext) {

this.aggregation = aggregation;
if (aggregation instanceof TypedAggregation) {
this.inputType = ((TypedAggregation) aggregation).getInputType();
} else if (aggregationOperationContext instanceof TypeBasedAggregationOperationContext) {
this.inputType = ((TypeBasedAggregationOperationContext) aggregationOperationContext).getType();
}
this.aggregationOperationContext = Lazy.of(() -> aggregationOperationContext != null ? aggregationOperationContext
: aggregationUtil.createAggregationContext(aggregation, getInputType()));
this.pipeline = Lazy.of(() -> aggregationUtil.createPipeline(this.aggregation, getAggregationOperationContext()));
}

/**
* Creates new instance of {@link AggregateContext} extracting the input type from either the
* {@link org.springframework.data.mongodb.core.aggregation.Aggregation} in case of a {@link TypedAggregation} or
* the given {@literal aggregationOperationContext} if present. <br />
* Creates a new {@link AggregationOperationContext} based on the {@link Aggregation} input type and the desired
* {@link AggregationOptions#getDomainTypeMapping() domain type mapping}. <br />
* Pipelines are mapped on first access of {@link #getAggregationPipeline()} and cached for reuse.
*
* @param aggregation the source aggregation.
* @param inputType can be {@literal null}.
*/
AggregateContext(Aggregation aggregation, @Nullable Class<?> inputType) {

this.aggregation = aggregation;

if (aggregation instanceof TypedAggregation) {
this.inputType = ((TypedAggregation) aggregation).getInputType();
} else {
this.inputType = inputType;
}

this.aggregationOperationContext = Lazy
.of(() -> aggregationUtil.createAggregationContext(aggregation, getInputType()));
this.pipeline = Lazy.of(() -> aggregationUtil.createPipeline(this.aggregation, getAggregationOperationContext()));
}

/**
* Obtain the already mapped pipeline.
*
* @return never {@literal null}.
*/
List<Document> getAggregationPipeline() {
return pipeline.get();
}

/**
* @return {@literal true} if the last aggregation stage is either {@literal $out} or {@literal $merge}.
* @see AggregationPipeline#isOutOrMerge()
*/
boolean isOutOrMerge() {
return aggregation.getPipeline().isOutOrMerge();
}

/**
* Obtain the {@link AggregationOperationContext} used for mapping the pipeline.
*
* @return never {@literal null}.
*/
AggregationOperationContext getAggregationOperationContext() {
return aggregationOperationContext.get();
}

/**
* @return the input type to map the pipeline against. Can be {@literal null}.
*/
@Nullable
Class<?> getInputType() {
return inputType;
}

Document getAggregationCommand(String collectionName) {
return aggregationUtil.createCommand(collectionName, aggregation, getAggregationOperationContext());
}
}
}
Loading