Skip to content

Commit 030d1ad

Browse files
authored
CSHARP-5436: Optimize special case of Any with constant array and fie… (#1585)
1 parent 858abaa commit 030d1ad

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,53 @@ static bool OperatorMapsNullToNull(AstUnaryOperator @operator)
105105
}
106106
}
107107

108+
public override AstNode VisitExprFilter(AstExprFilter node)
109+
{
110+
var optimizedNode = base.VisitExprFilter(node);
111+
112+
if (optimizedNode is AstExprFilter exprFilter &&
113+
exprFilter.Expression is AstUnaryExpression unaryExpression &&
114+
unaryExpression.Operator == AstUnaryOperator.AnyElementTrue &&
115+
unaryExpression.Arg is AstMapExpression mapExpression &&
116+
mapExpression.Input is AstConstantExpression inputConstant &&
117+
inputConstant.Value is BsonArray inputArrayValue &&
118+
mapExpression.In is AstBinaryExpression inBinaryExpression &&
119+
inBinaryExpression.Operator == AstBinaryOperator.Eq &&
120+
TryGetBinaryExpressionArguments(inBinaryExpression, out AstFieldPathExpression fieldPathExpression, out AstVarExpression varExpression) &&
121+
fieldPathExpression.Path.Length > 1 && fieldPathExpression.Path[0] == '$' && fieldPathExpression.Path[1] != '$' &&
122+
varExpression == mapExpression.As)
123+
{
124+
// { $expr : { $anyElementTrue : { $map : { input : <constantArray>, as : "<var>", in : { $eq : ["$<dottedFieldName>", "$$<var>"] } } } } }
125+
// => { "<dottedFieldName>" : { $in : <constantArray> } }
126+
return AstFilter.In(AstFilter.Field(fieldPathExpression.Path.Substring(1)), inputArrayValue);
127+
}
128+
129+
return optimizedNode;
130+
131+
static bool TryGetBinaryExpressionArguments<T1, T2>(AstBinaryExpression binaryExpression, out T1 arg1, out T2 arg2)
132+
where T1 : AstNode
133+
where T2 : AstNode
134+
{
135+
if (binaryExpression.Arg1 is T1 arg1AsT1 && binaryExpression.Arg2 is T2 arg2AsT2)
136+
{
137+
arg1 = arg1AsT1;
138+
arg2 = arg2AsT2;
139+
return true;
140+
}
141+
142+
if (binaryExpression.Arg1 is T2 arg1AsT2 && binaryExpression.Arg1 is T1 arg2AsT1)
143+
{
144+
arg1 = arg2AsT1;
145+
arg2 = arg1AsT2;
146+
return true;
147+
}
148+
149+
arg1 = null;
150+
arg2 = null;
151+
return false;
152+
}
153+
}
154+
108155
public override AstNode VisitFieldOperationFilter(AstFieldOperationFilter node)
109156
{
110157
node = (AstFieldOperationFilter)base.VisitFieldOperationFilter(node);

tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AnyMethodToAggregationExpressionTranslatorTests.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,22 @@ public void Any_with_predicate_should_work(
5959
results.Should().Equal(false, false, false, true);
6060
}
6161

62+
[Fact]
63+
public void Any_on_constant_array_should_be_optimized()
64+
{
65+
var collection = CreateCollection();
66+
67+
var obj = new[] { 1, 2, 3 };
68+
var queryable = collection.AsQueryable()
69+
.Where(x => obj.Any(y => x.Id == y));
70+
71+
var stages = Translate(collection, queryable);
72+
AssertStages(stages, "{ $match : { _id : { $in : [1, 2, 3] } } }");
73+
74+
var results = queryable.ToList();
75+
results.Select(x => x.Id).Should().Equal(1, 2, 3);
76+
}
77+
6278
private IMongoCollection<C> CreateCollection()
6379
{
6480
var collection = GetCollection<C>("test");

0 commit comments

Comments
 (0)