Skip to content

Commit 3906bdc

Browse files
committed
CSHARP-4872: Add support for Append and Prepend in aggregate expressions.
1 parent 720cdf5 commit 3906bdc

File tree

5 files changed

+211
-0
lines changed

5 files changed

+211
-0
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ internal static class EnumerableMethod
3232
private static readonly MethodInfo __all;
3333
private static readonly MethodInfo __any;
3434
private static readonly MethodInfo __anyWithPredicate;
35+
private static readonly MethodInfo __append;
3536
private static readonly MethodInfo __averageDecimal;
3637
private static readonly MethodInfo __averageDecimalWithSelector;
3738
private static readonly MethodInfo __averageDouble;
@@ -138,6 +139,7 @@ internal static class EnumerableMethod
138139
private static readonly MethodInfo __ofType;
139140
private static readonly MethodInfo __orderBy;
140141
private static readonly MethodInfo __orderByDescending;
142+
private static readonly MethodInfo __prepend;
141143
private static readonly MethodInfo __range;
142144
private static readonly MethodInfo __repeat;
143145
private static readonly MethodInfo __reverse;
@@ -195,6 +197,7 @@ static EnumerableMethod()
195197
__all = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.All(predicate));
196198
__any = ReflectionInfo.Method((IEnumerable<object> source) => source.Any());
197199
__anyWithPredicate = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.Any(predicate));
200+
__append = ReflectionInfo.Method((IEnumerable<object> source, object element) => source.Append(element));
198201
__averageDecimal = ReflectionInfo.Method((IEnumerable<decimal> source) => source.Average());
199202
__averageDecimalWithSelector = ReflectionInfo.Method((IEnumerable<object> source, Func<object, decimal> selector) => source.Average(selector));
200203
__averageDouble = ReflectionInfo.Method((IEnumerable<double> source) => source.Average());
@@ -301,6 +304,7 @@ static EnumerableMethod()
301304
__ofType = ReflectionInfo.Method((IEnumerable source) => source.OfType<object>());
302305
__orderBy = ReflectionInfo.Method((IEnumerable<object> source, Func<object, object> keySelector) => source.OrderBy(keySelector));
303306
__orderByDescending = ReflectionInfo.Method((IEnumerable<object> source, Func<object, object> keySelector) => source.OrderByDescending(keySelector));
307+
__prepend = ReflectionInfo.Method((IEnumerable<object> source, object element) => source.Prepend(element));
304308
__range = ReflectionInfo.Method((int start, int count) => Enumerable.Range(start, count));
305309
__repeat = ReflectionInfo.Method((object element, int count) => Enumerable.Repeat(element, count));
306310
__reverse = ReflectionInfo.Method((IEnumerable<object> source) => source.Reverse());
@@ -357,6 +361,7 @@ static EnumerableMethod()
357361
public static MethodInfo All => __all;
358362
public static MethodInfo Any => __any;
359363
public static MethodInfo AnyWithPredicate => __anyWithPredicate;
364+
public static MethodInfo Append => __append;
360365
public static MethodInfo AverageDecimal => __averageDecimal;
361366
public static MethodInfo AverageDecimalWithSelector => __averageDecimalWithSelector;
362367
public static MethodInfo AverageDouble => __averageDouble;
@@ -463,6 +468,7 @@ static EnumerableMethod()
463468
public static MethodInfo OfType => __ofType;
464469
public static MethodInfo OrderBy => __orderBy;
465470
public static MethodInfo OrderByDescending => __orderByDescending;
471+
public static MethodInfo Prepend => __prepend;
466472
public static MethodInfo Range => __range;
467473
public static MethodInfo Repeat => __repeat;
468474
public static MethodInfo Reverse => __reverse;

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ internal static class QueryableMethod
3030
private static readonly MethodInfo __all;
3131
private static readonly MethodInfo __any;
3232
private static readonly MethodInfo __anyWithPredicate;
33+
private static readonly MethodInfo __append;
3334
private static readonly MethodInfo __asQueryable;
3435
private static readonly MethodInfo __averageDecimal;
3536
private static readonly MethodInfo __averageDecimalWithSelector;
@@ -86,6 +87,7 @@ internal static class QueryableMethod
8687
private static readonly MethodInfo __ofType;
8788
private static readonly MethodInfo __orderBy;
8889
private static readonly MethodInfo __orderByDescending;
90+
private static readonly MethodInfo __prepend;
8991
private static readonly MethodInfo __reverse;
9092
private static readonly MethodInfo __select;
9193
private static readonly MethodInfo __selectMany;
@@ -136,6 +138,7 @@ static QueryableMethod()
136138
__all = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.All(predicate));
137139
__any = ReflectionInfo.Method((IQueryable<object> source) => source.Any());
138140
__anyWithPredicate = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.Any(predicate));
141+
__append = ReflectionInfo.Method((IQueryable<object> source, object element) => source.Append(element));
139142
__asQueryable = ReflectionInfo.Method((IEnumerable<object> source) => source.AsQueryable());
140143
__averageDecimal = ReflectionInfo.Method((IQueryable<decimal> source) => source.Average());
141144
__averageDecimalWithSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, decimal>> selector) => source.Average(selector));
@@ -192,6 +195,7 @@ static QueryableMethod()
192195
__ofType = ReflectionInfo.Method((IQueryable source) => source.OfType<object>());
193196
__orderBy = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, object>> keySelector) => source.OrderBy(keySelector));
194197
__orderByDescending = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, object>> keySelector) => source.OrderByDescending(keySelector));
198+
__prepend = ReflectionInfo.Method((IQueryable<object> source, object element) => source.Prepend(element));
195199
__reverse = ReflectionInfo.Method((IQueryable<object> source) => source.Reverse());
196200
__select = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, object>> selector) => source.Select(selector));
197201
__selectMany = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, IEnumerable<object>>> selector) => source.SelectMany(selector));
@@ -241,6 +245,7 @@ static QueryableMethod()
241245
public static MethodInfo All => __all;
242246
public static MethodInfo Any => __any;
243247
public static MethodInfo AnyWithPredicate => __anyWithPredicate;
248+
public static MethodInfo Append => __append;
244249
public static MethodInfo AsQueryable => __asQueryable;
245250
public static MethodInfo AverageDecimal => __averageDecimal;
246251
public static MethodInfo AverageDecimalWithSelector => __averageDecimalWithSelector;
@@ -297,6 +302,7 @@ static QueryableMethod()
297302
public static MethodInfo OfType => __ofType;
298303
public static MethodInfo OrderBy => __orderBy;
299304
public static MethodInfo OrderByDescending => __orderByDescending;
305+
public static MethodInfo Prepend => __prepend;
300306
public static MethodInfo Reverse => __reverse;
301307
public static MethodInfo Select => __select;
302308
public static MethodInfo SelectMany => __selectMany;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ public static AggregationExpression Translate(TranslationContext context, Method
120120
case "AddYears":
121121
return DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.Translate(context, expression);
122122

123+
case "Append":
124+
case "Prepend":
125+
return AppendOrPrependMethodToAggregationExpressionTranslator.Translate(context, expression);
126+
123127
case "Bottom":
124128
case "BottomN":
125129
case "FirstN":
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq.Expressions;
17+
using System.Reflection;
18+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
19+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
20+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
22+
23+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
24+
{
25+
internal static class AppendOrPrependMethodToAggregationExpressionTranslator
26+
{
27+
private static readonly MethodInfo[] __appendOrPrependMethods =
28+
{
29+
EnumerableMethod.Append,
30+
EnumerableMethod.Prepend,
31+
QueryableMethod.Append,
32+
QueryableMethod.Prepend
33+
};
34+
35+
private static readonly MethodInfo[] __appendMethods =
36+
{
37+
EnumerableMethod.Append,
38+
QueryableMethod.Append
39+
};
40+
41+
public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
42+
{
43+
var method = expression.Method;
44+
var arguments = expression.Arguments;
45+
46+
if (method.IsOneOf(__appendOrPrependMethods))
47+
{
48+
var sourceExpression = arguments[0];
49+
var elementExpression = arguments[1];
50+
51+
var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
52+
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
53+
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
54+
55+
AggregationExpression elementTranslation;
56+
if (elementExpression is ConstantExpression elementConstantExpression)
57+
{
58+
var value = elementConstantExpression.Value;
59+
var serializedValue = SerializationHelper.SerializeValue(itemSerializer, value);
60+
elementTranslation = new AggregationExpression(elementExpression, AstExpression.Constant(serializedValue), itemSerializer);
61+
}
62+
else
63+
{
64+
elementTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, elementExpression);
65+
if (!elementTranslation.Serializer.Equals(itemSerializer))
66+
{
67+
throw new ExpressionNotSupportedException(expression, because: "argument serializers are not compatible");
68+
}
69+
}
70+
71+
var ast = method.IsOneOf(__appendMethods) ?
72+
AstExpression.ConcatArrays(sourceTranslation.Ast, AstExpression.ComputedArray(elementTranslation.Ast)) :
73+
AstExpression.ConcatArrays(AstExpression.ComputedArray(elementTranslation.Ast), sourceTranslation.Ast);
74+
var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer);
75+
76+
return new AggregationExpression(expression, ast, serializer);
77+
}
78+
79+
throw new ExpressionNotSupportedException(expression);
80+
}
81+
}
82+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using FluentAssertions;
18+
using MongoDB.TestHelpers.XunitExtensions;
19+
using Xunit;
20+
21+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira
22+
{
23+
public class CSharp4872Tests : Linq3IntegrationTest
24+
{
25+
[Theory]
26+
[ParameterAttributeData]
27+
public void Append_constant_should_work(
28+
[Values(false, true)] bool withNestedAsQueryable)
29+
{
30+
var collection = GetCollection();
31+
32+
var queryable = withNestedAsQueryable ?
33+
collection.AsQueryable().Select(x => x.A.AsQueryable().Append(4).ToList()) :
34+
collection.AsQueryable().Select(x => x.A.Append(4).ToList());
35+
36+
var stages = Translate(collection, queryable);
37+
AssertStages(stages, "{ $project : { _v : { $concatArrays : ['$A', [4]] }, _id : 0 } }");
38+
39+
var result = queryable.Single();
40+
result.Should().Equal(1, 2, 3, 4);
41+
}
42+
43+
[Theory]
44+
[ParameterAttributeData]
45+
public void Append_expression_should_work(
46+
[Values(false, true)] bool withNestedAsQueryable)
47+
{
48+
var collection = GetCollection();
49+
50+
var queryable = withNestedAsQueryable ?
51+
collection.AsQueryable().Select(x => x.A.AsQueryable().Append(x.B).ToList()) :
52+
collection.AsQueryable().Select(x => x.A.Append(x.B).ToList());
53+
54+
var stages = Translate(collection, queryable);
55+
AssertStages(stages, "{ $project : { _v : { $concatArrays : ['$A', ['$B']] }, _id : 0 } }");
56+
57+
var result = queryable.Single();
58+
result.Should().Equal(1, 2, 3, 4);
59+
}
60+
61+
[Theory]
62+
[ParameterAttributeData]
63+
public void Prepend_constant_should_work(
64+
[Values(false, true)] bool withNestedAsQueryable)
65+
{
66+
var collection = GetCollection();
67+
68+
var queryable = withNestedAsQueryable ?
69+
collection.AsQueryable().Select(x => x.A.AsQueryable().Prepend(4).ToList()) :
70+
collection.AsQueryable().Select(x => x.A.Prepend(4).ToList());
71+
72+
var stages = Translate(collection, queryable);
73+
AssertStages(stages, "{ $project : { _v : { $concatArrays : [[4], '$A'] }, _id : 0 } }");
74+
75+
var result = queryable.Single();
76+
result.Should().Equal(4, 1, 2, 3);
77+
}
78+
79+
[Theory]
80+
[ParameterAttributeData]
81+
public void Prepend_expression_should_work(
82+
[Values(false, true)] bool withNestedAsQueryable)
83+
{
84+
var collection = GetCollection();
85+
86+
var queryable = withNestedAsQueryable ?
87+
collection.AsQueryable().Select(x => x.A.AsQueryable().Prepend(x.B).ToList()) :
88+
collection.AsQueryable().Select(x => x.A.Prepend(x.B).ToList());
89+
90+
var stages = Translate(collection, queryable);
91+
AssertStages(stages, "{ $project : { _v : { $concatArrays : [['$B'], '$A'] }, _id : 0 } }");
92+
93+
var result = queryable.Single();
94+
result.Should().Equal(4, 1, 2, 3);
95+
}
96+
97+
private IMongoCollection<C> GetCollection()
98+
{
99+
var collection = GetCollection<C>("test");
100+
CreateCollection(
101+
collection,
102+
new C { Id = 1, A = [1, 2, 3], B = 4 });
103+
return collection;
104+
}
105+
106+
private class C
107+
{
108+
public int Id { get; set; }
109+
public int[] A { get; set; }
110+
public int B { get; set; }
111+
}
112+
}
113+
}

0 commit comments

Comments
 (0)