Skip to content

Commit dbb7f96

Browse files
committed
CSHARP-4882: Support Skip and Take in expressions.
1 parent a3d3d4f commit dbb7f96

File tree

9 files changed

+707
-29
lines changed

9 files changed

+707
-29
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,19 @@ public static AstExpression Add(params AstExpression[] args)
8686
return new AstConstantExpression(value);
8787
}
8888

89+
if (args.Length == 2)
90+
{
91+
if (args[0].IsZero())
92+
{
93+
return args[1];
94+
}
95+
96+
if (args[1].IsZero())
97+
{
98+
return args[0];
99+
}
100+
}
101+
89102
var flattenedArgs = FlattenNaryArgs(args, AstNaryOperator.Add);
90103
return new AstNaryExpression(AstNaryOperator.Add, flattenedArgs);
91104
}
@@ -565,11 +578,31 @@ public static AstExpression Max(AstExpression array)
565578
return new AstUnaryExpression(AstUnaryOperator.Max, array);
566579
}
567580

581+
public static AstExpression Max(AstExpression arg1, AstExpression arg2)
582+
{
583+
if (AllArgsAreConstantInt32s([arg1, arg2], out var values))
584+
{
585+
return values.Max();
586+
}
587+
588+
return new AstNaryExpression(AstNaryOperator.Max, [arg1, arg2]);
589+
}
590+
568591
public static AstExpression Min(AstExpression array)
569592
{
570593
return new AstUnaryExpression(AstUnaryOperator.Min, array);
571594
}
572595

596+
public static AstExpression Min(AstExpression arg1, AstExpression arg2)
597+
{
598+
if (AllArgsAreConstantInt32s([arg1, arg2], out var values))
599+
{
600+
return values.Min();
601+
}
602+
603+
return new AstNaryExpression(AstNaryOperator.Min, [arg1, arg2]);
604+
}
605+
573606
public static AstExpression Mod(AstExpression arg1, AstExpression arg2)
574607
{
575608
return new AstBinaryExpression(AstBinaryOperator.Mod, arg1, arg2);
@@ -806,6 +839,17 @@ public static AstExpression SubstrCP(AstExpression arg, AstExpression index, Ast
806839

807840
public static AstExpression Subtract(AstExpression arg1, AstExpression arg2)
808841
{
842+
if (AllArgsAreConstantInt32s([arg1, arg2], out var values))
843+
{
844+
var value = values[0] - values[1];
845+
return new AstConstantExpression(value);
846+
}
847+
848+
if (arg2.IsZero())
849+
{
850+
return arg1;
851+
}
852+
809853
return new AstBinaryExpression(AstBinaryOperator.Subtract, arg1, arg2);
810854
}
811855

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using MongoDB.Bson;
17+
18+
namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions
19+
{
20+
internal static class AstExpressionExtensions
21+
{
22+
public static bool IsInt32Constant(this AstExpression expression, out int value)
23+
{
24+
if (expression is AstConstantExpression constantExpression &&
25+
constantExpression.Value is BsonInt32 bsonInt32)
26+
{
27+
value = bsonInt32.Value;
28+
return true;
29+
}
30+
31+
value = default;
32+
return false;
33+
}
34+
35+
public static bool IsMaxInt32(this AstExpression expression)
36+
=> expression.IsInt32Constant(out var value) && value == int.MaxValue;
37+
38+
public static bool IsZero(this AstExpression expression)
39+
=> expression is AstConstantExpression constantExpression && constantExpression.Value == 0;
40+
}
41+
}

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstSliceExpression.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public override AstNode Accept(AstNodeVisitor visitor)
4949
public override BsonValue Render()
5050
{
5151
var args =
52-
_position == null ?
52+
(_position == null || _position.IsZero()) ?
5353
new BsonArray { _array.Render(), _n.Render() } :
5454
new BsonArray { _array.Render(), _position.Render(), _n.Render() };
5555

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System;
17+
using System.Linq;
1618
using MongoDB.Bson;
1719
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
1820
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters;
@@ -399,6 +401,54 @@ public override AstNode VisitNotFilterOperation(AstNotFilterOperation node)
399401
return base.VisitNotFilterOperation(node);
400402
}
401403

404+
public override AstNode VisitSliceExpression(AstSliceExpression node)
405+
{
406+
node = (AstSliceExpression)base.VisitSliceExpression(node);
407+
var array = node.Array;
408+
var position = node.Position ?? 0; // map null to zero
409+
var n = node.N;
410+
411+
if (position.IsZero() && n.IsMaxInt32())
412+
{
413+
// { $slice : [array, 0, maxint] } => array
414+
return array;
415+
}
416+
417+
if (array is AstConstantExpression arrayConstant &&
418+
arrayConstant.Value is BsonArray bsonArrayConstant &&
419+
position.IsInt32Constant(out var positionValue) && positionValue >= 0 &&
420+
n.IsInt32Constant(out var nValue) && nValue >= 0)
421+
{
422+
// { slice : [array, position, n] } => array.Skip(position).Take(n) when all arguments are non-negative constants
423+
return AstExpression.Constant(new BsonArray(bsonArrayConstant.Skip(positionValue).Take(nValue)));
424+
}
425+
426+
if (array is AstSliceExpression inner &&
427+
(inner.Position ?? 0).IsInt32Constant(out var innerPosition) && innerPosition >= 0 &&
428+
inner.N.IsInt32Constant(out var innerN) && innerN >= 0 &&
429+
position.IsInt32Constant(out var outerPosition) && outerPosition >= 0 &&
430+
n.IsInt32Constant(out var outerN) && outerN >= 0)
431+
{
432+
// the following simplifcations are only valid when all position and n values are known to be non-negative (so they have to be constants)
433+
// { $slice : [{ $slice : [inner.Array, innerPosition, maxint] }, outerPosition, maxint] } => { $slice : [inner.Array, innerPosition + outerPosition, maxint] }
434+
// { $slice : [{ $slice : [inner.Array, innerPosition, maxint] }, outerPosition, outerN] } => { $slice : [inner.Array, innerPosition + outerPosition, outerN] }
435+
// { $slice : [{ $slice : [inner.Array, innerPosition, innerN] }, outerPosition, maxint] } => { $slice : [inner.Array, innerPosition + outerPosition, max(innerN - outerPosition, 0)] }
436+
// { $slice : [{ $slice : [inner.Array, innerPosition, innerN] }, outerPosition, outerN] } => { $slice : [inner.Array, innerPosition + outerPosition, min(max(innerN - outerPosition, 0), outerN)] }
437+
var combinedPosition = AstExpression.Add(innerPosition, outerPosition);
438+
var combinedN = (innerN, outerN) switch
439+
{
440+
(int.MaxValue, int.MaxValue) => int.MaxValue, // check whether both are int.MaxValue before checking one at a time
441+
(int.MaxValue, _) => outerN,
442+
(_, int.MaxValue) => Math.Max(innerN - outerPosition, 0),
443+
_ => Math.Min(Math.Max(innerN - outerPosition, 0), outerN)
444+
};
445+
446+
return AstExpression.Slice(inner.Array, combinedPosition, combinedN);
447+
}
448+
449+
return node;
450+
}
451+
402452
public override AstNode VisitUnaryExpression(AstUnaryExpression node)
403453
{
404454
// { $first : <arg> } => { $arrayElemAt : [<arg>, 0] } (or -1 for $last)

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/PartialEvaluator.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
using System.Linq;
1919
using System.Linq.Expressions;
2020
using System.Reflection;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
2122

2223
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
2324
{
@@ -229,13 +230,14 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
229230
var result = base.VisitMethodCall(node);
230231

231232
var method = node.Method;
232-
if (IsCustomLinqExtensionMethod(method))
233+
if (IsCustomLinqExtensionMethod(method) ||
234+
method.Is(QueryableMethod.AsQueryable))
233235
{
234236
_cannotBeEvaluated = true;
235237
}
236238

237239
return result;
238-
}
240+
}
239241

240242
private bool IsCustomLinqExtensionMethod(MethodInfo method)
241243
{

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ public static AggregationExpression Translate(TranslationContext context, Method
8181
case "StrLenBytes": return StrLenBytesMethodToAggregationExpressionTranslator.Translate(context, expression);
8282
case "Subtract": return SubtractMethodToAggregationExpressionTranslator.Translate(context, expression);
8383
case "Sum": return SumMethodToAggregationExpressionTranslator.Translate(context, expression);
84-
case "Take": return TakeMethodToAggregationExpressionTranslator.Translate(context, expression);
8584
case "ToArray": return ToArrayMethodToAggregationExpressionTranslator.Translate(context, expression);
8685
case "ToList": return ToListMethodToAggregationExpressionTranslator.Translate(context, expression);
8786
case "ToString": return ToStringMethodToAggregationExpressionTranslator.Translate(context, expression);
@@ -173,6 +172,10 @@ public static AggregationExpression Translate(TranslationContext context, Method
173172
case "ThenByDescending":
174173
return OrderByMethodToAggregationExpressionTranslator.Translate(context, expression);
175174

175+
case "Skip":
176+
case "Take":
177+
return SkipOrTakeMethodToAggregationExpressionTranslator.Translate(context, expression);
178+
176179
case "StandardDeviationPopulation":
177180
case "StandardDeviationSample":
178181
return StandardDeviationMethodsToAggregationExpressionTranslator.Translate(context, expression);
Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@
2222

2323
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
2424
{
25-
internal static class TakeMethodToAggregationExpressionTranslator
25+
internal static class SkipOrTakeMethodToAggregationExpressionTranslator
2626
{
27-
private static MethodInfo[] __takeMethods =
27+
private static MethodInfo[] __skipOrTakeMethods =
2828
{
29+
EnumerableMethod.Skip,
2930
EnumerableMethod.Take,
30-
QueryableMethod.Take
31+
QueryableMethod.Skip,
32+
QueryableMethod.Take,
3133
};
3234

3335
private static MethodInfo[] __skipMethods =
@@ -36,40 +38,37 @@ internal static class TakeMethodToAggregationExpressionTranslator
3638
QueryableMethod.Skip
3739
};
3840

41+
private static MethodInfo[] __takeMethods =
42+
{
43+
EnumerableMethod.Take,
44+
QueryableMethod.Take
45+
};
46+
3947
public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
4048
{
4149
var method = expression.Method;
4250
var arguments = expression.Arguments;
4351

44-
if (method.IsOneOf(__takeMethods))
52+
if (method.IsOneOf(__skipOrTakeMethods))
4553
{
4654
var sourceExpression = arguments[0];
47-
var countExpression = arguments[1];
48-
Expression skipExpression = null;
49-
if (sourceExpression is MethodCallExpression sourceSkipExpression && sourceSkipExpression.Method.IsOneOf(__skipMethods))
50-
{
51-
sourceExpression = sourceSkipExpression.Arguments[0];
52-
skipExpression = sourceSkipExpression.Arguments[1];
53-
}
54-
5555
var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
5656
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
57+
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
58+
var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer);
5759

60+
var countExpression = arguments[1];
5861
var countTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, countExpression);
59-
AstExpression ast;
60-
if (skipExpression == null)
61-
{
62-
ast = AstExpression.Slice(sourceTranslation.Ast, countTranslation.Ast);
63-
}
64-
else
62+
var countAst = AstExpression.Max(countTranslation.Ast, 0); // map negative numbers to 0
63+
64+
var ast = method switch
6565
{
66-
var skipTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, skipExpression);
67-
ast = AstExpression.Slice(sourceTranslation.Ast, skipTranslation.Ast, countTranslation.Ast);
68-
}
69-
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
70-
var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer);
66+
_ when method.IsOneOf(__skipMethods) => AstExpression.Slice(sourceTranslation.Ast, countAst, int.MaxValue),
67+
_ when method.IsOneOf(__takeMethods) => AstExpression.Slice(sourceTranslation.Ast, countAst),
68+
_ => throw new ExpressionNotSupportedException(expression)
69+
};
7170

72-
return new AggregationExpression(expression, ast, serializer);
71+
return new AggregationExpression(expression, ast, resultSerializer);
7372
}
7473

7574
throw new ExpressionNotSupportedException(expression);

0 commit comments

Comments
 (0)