Skip to content

Commit 8612d62

Browse files
committed
CSHARP-4248: Add custom Lookup LINQ extension methods to support $lookup.
1 parent fb137fb commit 8612d62

18 files changed

+1098
-17
lines changed

src/MongoDB.Driver/Core/Misc/Feature.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ public class Feature
7777
private static readonly Feature __legacyWireProtocol = new Feature("LegacyWireProtocol", WireVersion.Zero, WireVersion.Server51);
7878
private static readonly Feature __listDatabasesAuthorizedDatabases = new Feature("ListDatabasesAuthorizedDatabases", WireVersion.Server40);
7979
private static readonly Feature __loadBalancedMode = new Feature("LoadBalancedMode", WireVersion.Server50);
80+
private static readonly Feature __loookupConciseSyntax = new Feature("LoookupConciseSyntax", WireVersion.Server50);
81+
private static readonly Feature __loookupDocuments= new Feature("LoookupDocuments", WireVersion.Server60);
8082
private static readonly Feature __mmapV1StorageEngine = new Feature("MmapV1StorageEngine", WireVersion.Zero, WireVersion.Server42);
8183
private static readonly Feature __pickAccumulatorsNewIn52 = new Feature("PickAccumulatorsNewIn52", WireVersion.Server52);
8284
private static readonly Feature __regexMatch = new Feature("RegexMatch", WireVersion.Server42);
@@ -358,6 +360,16 @@ public class Feature
358360
/// </summary>
359361
public static Feature LoadBalancedMode => __loadBalancedMode;
360362

363+
/// <summary>
364+
/// Gets the lookup concise syntax feature.
365+
/// </summary>
366+
public static Feature LookupConciseSyntax => __loookupConciseSyntax;
367+
368+
/// <summary>
369+
/// Gets the lookup documents feature.
370+
/// </summary>
371+
public static Feature LookupDocuments => __loookupDocuments;
372+
361373
/// <summary>
362374
/// Gets the mmapv1 storage engine feature.
363375
/// </summary>

src/MongoDB.Driver/IMongoCollection.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@
2323

2424
namespace MongoDB.Driver
2525
{
26+
internal interface IMongoCollection
27+
{
28+
CollectionNamespace CollectionNamespace { get; }
29+
30+
IBsonSerializer DocumentSerializer { get; }
31+
}
32+
2633
/// <summary>
2734
/// Represents a typed collection in MongoDB.
2835
/// </summary>
@@ -31,7 +38,7 @@ namespace MongoDB.Driver
3138
/// <see cref="MongoCollectionBase{TDocument}"/>.
3239
/// </remarks>
3340
/// <typeparam name="TDocument">The type of the documents stored in the collection.</typeparam>
34-
public interface IMongoCollection<TDocument>
41+
public interface IMongoCollection<TDocument> // TODO: derive from IMongoCollection in 4.0
3542
{
3643
/// <summary>
3744
/// Gets the namespace of the collection.

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Stages/AstLookupWithMatchingFieldsAndPipelineStage.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public AstLookupWithMatchingFieldsAndPipelineStage(
3939
AstPipeline pipeline,
4040
string @as)
4141
{
42-
_from = Ensure.IsNotNull(from, nameof(from));
42+
_from = from; // null when using $documents in the pipeline
4343
_localField = Ensure.IsNotNull(localField, nameof(localField));
4444
_foreignField = Ensure.IsNotNull(foreignField, nameof(foreignField));
4545
_let = let?.AsReadOnlyList(); // can be null for an uncorrelated subquery
@@ -66,7 +66,7 @@ public override BsonValue Render()
6666
{
6767
{ "$lookup", new BsonDocument()
6868
{
69-
{ "from", _from },
69+
{ "from", _from, _from != null },
7070
{ "localField", _localField },
7171
{ "foreignField", _foreignField },
7272
{ "let", () => new BsonDocument(_let.Select(l => l.RenderAsElement())), _let?.Count > 0 },

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Stages/AstLookupWithPipelineStage.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public AstLookupWithPipelineStage(
3535
AstPipeline pipeline,
3636
string @as)
3737
{
38-
_from = Ensure.IsNotNull(from, nameof(from));
38+
_from = from; // null when using $documents in the pipeline
3939
_let = let?.AsReadOnlyList(); // can be null for an uncorrelated subquery
4040
_pipeline = Ensure.IsNotNull(pipeline, nameof(pipeline));
4141
_as = Ensure.IsNotNull(@as, nameof(@as));
@@ -58,7 +58,7 @@ public override BsonValue Render()
5858
{
5959
{ "$lookup", new BsonDocument()
6060
{
61-
{ "from", _from },
61+
{ "from", _from, _from != null },
6262
{ "let", () => new BsonDocument(_let.Select(l => l.RenderAsElement())), _let?.Count > 0 },
6363
{ "pipeline", _pipeline.Render() },
6464
{ "as", _as }

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Stages/AstStage.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,24 @@ public static AstStage Lookup(
179179
return new AstLookupWithMatchingFieldsAndPipelineStage(from, localField, foreignField, let, pipeline, @as);
180180
}
181181

182+
public static AstStage Lookup(
183+
IEnumerable<AstComputedField> let,
184+
AstPipeline pipeline,
185+
string @as)
186+
{
187+
return new AstLookupWithPipelineStage(from: null, let, pipeline, @as);
188+
}
189+
190+
public static AstStage Lookup(
191+
string localField,
192+
string foreignField,
193+
IEnumerable<AstComputedField> let,
194+
AstPipeline pipeline,
195+
string @as)
196+
{
197+
return new AstLookupWithMatchingFieldsAndPipelineStage(from: null, localField, foreignField, let, pipeline, @as);
198+
}
199+
182200
public static AstStage Match(AstFilter filter)
183201
{
184202
return new AstMatchStage(filter);

src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,18 @@ public static object Evaluate(this Expression expression)
3636
}
3737
}
3838

39-
public static (string CollectionName, IBsonSerializer DocumentSerializer) GetCollectionInfo(this Expression innerExpression, Expression containerExpression)
39+
public static (string CollectionName, IBsonSerializer DocumentSerializer) GetCollectionInfoFromQueryable(this Expression queryableExpression, Expression containerExpression)
4040
{
41-
if (innerExpression is ConstantExpression constantExpression &&
41+
if (queryableExpression is ConstantExpression constantExpression &&
4242
constantExpression.Value is IQueryable queryable &&
4343
queryable.Provider is IMongoQueryProviderInternal mongoQueryProvider &&
4444
mongoQueryProvider.CollectionNamespace != null)
4545
{
4646
return (mongoQueryProvider.CollectionNamespace.CollectionName, mongoQueryProvider.PipelineInputSerializer);
4747
}
4848

49-
var message = $"inner expression must be a MongoDB IQueryable against a collection";
50-
throw new ExpressionNotSupportedException(innerExpression, containerExpression, because: message);
49+
var message = "expression must be a MongoDB IQueryable against a collection";
50+
throw new ExpressionNotSupportedException(queryableExpression, containerExpression, because: message);
5151
}
5252

5353
public static TValue GetConstantValue<TValue>(this Expression expression, Expression containingExpression)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq.Expressions;
17+
using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor;
18+
19+
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
20+
{
21+
internal class ExpressionIsReferencedVisitor : ExpressionVisitor
22+
{
23+
private readonly Expression _expression;
24+
private bool _expressionIsReferenced;
25+
26+
public ExpressionIsReferencedVisitor(Expression expression)
27+
{
28+
_expression = expression;
29+
}
30+
31+
public bool ExpressionIsReferenced => _expressionIsReferenced;
32+
33+
public override Expression Visit(Expression node)
34+
{
35+
// once we know the expression is referenced we can short circuit any further visiting
36+
if (_expressionIsReferenced)
37+
{
38+
return node;
39+
}
40+
else if (node == _expression)
41+
{
42+
_expressionIsReferenced = true;
43+
return node;
44+
}
45+
else
46+
{
47+
return base.Visit(node);
48+
}
49+
}
50+
}
51+
}

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LambdaExpressionExtensions.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
2424
{
2525
internal static class LambdaExpressionExtensions
2626
{
27+
public static bool LambdaBodyReferencesParameter(this LambdaExpression lambda, ParameterExpression parameter)
28+
{
29+
var visitor = new ExpressionIsReferencedVisitor(parameter);
30+
visitor.Visit(lambda.Body);
31+
return visitor.ExpressionIsReferenced;
32+
}
33+
2734
public static string TranslateToDottedFieldName(this LambdaExpression fieldSelectorLambda, TranslationContext context, IBsonSerializer parameterSerializer)
2835
{
2936
var parameterExpression = fieldSelectorLambda.Parameters.Single();

src/MongoDB.Driver/Linq/Linq3Implementation/MongoQueryProvider.cs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ internal abstract class MongoQueryProvider : IMongoQueryProviderInternal
3434

3535
// constructors
3636
protected MongoQueryProvider(
37-
IClientSessionHandle session,
37+
IClientSessionHandle session,
3838
AggregateOptions options)
3939
{
4040
_session = session;
@@ -63,6 +63,7 @@ internal sealed class MongoQueryProvider<TDocument> : MongoQueryProvider
6363
private readonly IMongoCollection<TDocument> _collection;
6464
private readonly IMongoDatabase _database;
6565
private ExecutableQuery<TDocument> _executedQuery;
66+
private readonly IBsonSerializer _pipelineInputSerializer;
6667

6768
// constructors
6869
public MongoQueryProvider(
@@ -72,6 +73,7 @@ public MongoQueryProvider(
7273
: base(session, options)
7374
{
7475
_collection = Ensure.IsNotNull(collection, nameof(collection));
76+
_pipelineInputSerializer = collection.DocumentSerializer;
7577
}
7678

7779
public MongoQueryProvider(
@@ -81,14 +83,24 @@ public MongoQueryProvider(
8183
: base(session, options)
8284
{
8385
_database = Ensure.IsNotNull(database, nameof(database));
86+
_pipelineInputSerializer = NoPipelineInputSerializer.Instance;
87+
}
88+
89+
internal MongoQueryProvider(
90+
IBsonSerializer pipelineInputSerializer,
91+
IClientSessionHandle session,
92+
AggregateOptions options)
93+
: base(session, options)
94+
{
95+
_pipelineInputSerializer = Ensure.IsNotNull(pipelineInputSerializer, nameof(pipelineInputSerializer));
8496
}
8597

8698
// public properties
8799
public IMongoCollection<TDocument> Collection => _collection;
88100
public override CollectionNamespace CollectionNamespace => _collection == null ? null : _collection.CollectionNamespace;
89101
public IMongoDatabase Database => _database;
90102
public override BsonDocument[] LoggedStages => _executedQuery?.LoggedStages;
91-
public override IBsonSerializer PipelineInputSerializer => _collection == null ? NoPipelineInputSerializer.Instance : _collection.DocumentSerializer;
103+
public override IBsonSerializer PipelineInputSerializer => _pipelineInputSerializer;
92104

93105
// public methods
94106
public override IQueryable CreateQuery(Expression expression)
@@ -137,8 +149,8 @@ public Task<TResult> ExecuteAsync<TResult>(ExecutableQuery<TDocument, TResult> e
137149
public override ExpressionTranslationOptions GetTranslationOptions()
138150
{
139151
var translationOptions = _options?.TranslationOptions;
140-
var database = _database ?? _collection.Database;
141-
return translationOptions.AddMissingOptionsFrom(database.Client.Settings.TranslationOptions);
152+
var database = _database ?? _collection?.Database;
153+
return translationOptions.AddMissingOptionsFrom(database?.Client.Settings.TranslationOptions);
142154
}
143155
}
144156
}

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoQueryableMethod.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ internal static class MongoQueryableMethod
6161
private static readonly MethodInfo __firstWithPredicateAsync;
6262
private static readonly MethodInfo __longCountAsync;
6363
private static readonly MethodInfo __longCountWithPredicateAsync;
64+
private static readonly MethodInfo __lookupWithDocumentsAndLocalFieldAndForeignField;
65+
private static readonly MethodInfo __lookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline;
66+
private static readonly MethodInfo __lookupWithDocumentsAndPipeline;
67+
private static readonly MethodInfo __lookupWithFromAndLocalFieldAndForeignField;
68+
private static readonly MethodInfo __lookupWithFromAndLocalFieldAndForeignFieldAndPipeline;
69+
private static readonly MethodInfo __lookupWithFromAndPipeline;
6470
private static readonly MethodInfo __maxAsync;
6571
private static readonly MethodInfo __maxWithSelectorAsync;
6672
private static readonly MethodInfo __minAsync;
@@ -211,6 +217,12 @@ static MongoQueryableMethod()
211217
__firstWithPredicateAsync = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate, CancellationToken cancellationToken) => source.FirstAsync(predicate, cancellationToken));
212218
__longCountAsync = ReflectionInfo.Method((IQueryable<object> source, CancellationToken cancellationToken) => source.LongCountAsync(cancellationToken));
213219
__longCountWithPredicateAsync = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate, CancellationToken cancellationToken) => source.LongCountAsync(predicate, cancellationToken));
220+
__lookupWithDocumentsAndLocalFieldAndForeignField = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, IEnumerable<object>>> documents, Expression<Func<object, object>> localField, Expression<Func<object, object>> foreignField) => source.Lookup(documents, localField, foreignField));
221+
__lookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, IEnumerable<object>>> documents, Expression<Func<object, object>> localField, Expression<Func<object, object>> foreignField, Expression<Func<object, IQueryable<object>, IQueryable<object>>> pipeline) => source.Lookup(documents, localField, foreignField, pipeline));
222+
__lookupWithDocumentsAndPipeline = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, IEnumerable<object>>> documents, Expression<Func<object, IQueryable<object>, IQueryable<object>>> pipeline) => source.Lookup(documents, pipeline));
223+
__lookupWithFromAndLocalFieldAndForeignField = ReflectionInfo.Method((IQueryable<object> source, IMongoCollection<object> from, Expression<Func<object, object>> localField, Expression<Func<object, object>> foreignField) => source.Lookup(from, localField, foreignField));
224+
__lookupWithFromAndLocalFieldAndForeignFieldAndPipeline = ReflectionInfo.Method((IQueryable<object> source, IMongoCollection<object> from, Expression<Func<object, object>> localField, Expression<Func<object, object>> foreignField, Expression<Func<object, IQueryable<object>, IQueryable<object>>> pipeline) => source.Lookup(from, localField, foreignField, pipeline));
225+
__lookupWithFromAndPipeline = ReflectionInfo.Method((IQueryable<object> source, IMongoCollection<object> from, Expression<Func<object, IQueryable<object>, IQueryable<object>>> pipeline) => source.Lookup(from, pipeline));
214226
__maxAsync = ReflectionInfo.Method((IQueryable<object> source, CancellationToken cancellationToken) => source.MaxAsync(cancellationToken));
215227
__maxWithSelectorAsync = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, object>> selector, CancellationToken cancellationToken) => source.MaxAsync(selector, cancellationToken));
216228
__minAsync = ReflectionInfo.Method((IQueryable<object> source, CancellationToken cancellationToken) => source.MinAsync(cancellationToken));
@@ -360,6 +372,12 @@ static MongoQueryableMethod()
360372
public static MethodInfo FirstWithPredicateAsync => __firstWithPredicateAsync;
361373
public static MethodInfo LongCountAsync => __longCountAsync;
362374
public static MethodInfo LongCountWithPredicateAsync => __longCountWithPredicateAsync;
375+
public static MethodInfo LookupWithDocumentsAndLocalFieldAndForeignField => __lookupWithDocumentsAndLocalFieldAndForeignField;
376+
public static MethodInfo LookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline => __lookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline;
377+
public static MethodInfo LookupWithDocumentsAndPipeline => __lookupWithDocumentsAndPipeline;
378+
public static MethodInfo LookupWithFromAndLocalFieldAndForeignField => __lookupWithFromAndLocalFieldAndForeignField;
379+
public static MethodInfo LookupWithFromAndLocalFieldAndForeignFieldAndPipeline => __lookupWithFromAndLocalFieldAndForeignFieldAndPipeline;
380+
public static MethodInfo LookupWithFromAndPipeline => __lookupWithFromAndPipeline;
363381
public static MethodInfo MaxAsync => __maxAsync;
364382
public static MethodInfo MaxWithSelectorAsync => __maxWithSelectorAsync;
365383
public static MethodInfo MinAsync => __minAsync;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/ExpressionToPipelineTranslator.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ public static AstPipeline Translate(TranslationContext context, Expression expre
5252
return GroupJoinMethodToPipelineTranslator.Translate(context, methodCallExpression);
5353
case "Join":
5454
return JoinMethodToPipelineTranslator.Translate(context, methodCallExpression);
55+
case "Lookup":
56+
return LookupMethodToPipelineTranslator.Translate(context, methodCallExpression);
5557
case "OfType":
5658
return OfTypeMethodToPipelineTranslator.Translate(context, methodCallExpression);
5759
case "OrderBy":

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupJoinMethodToPipelineTranslator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public static AstPipeline Translate(TranslationContext context, MethodCallExpres
5757
var wrappedOuterSerializer = WrappedValueSerializer.Create("_outer", outerSerializer);
5858

5959
var innerExpression = arguments[1];
60-
var (innerCollectionName, innerSerializer) = innerExpression.GetCollectionInfo(containerExpression: expression);
60+
var (innerCollectionName, innerSerializer) = innerExpression.GetCollectionInfoFromQueryable(containerExpression: expression);
6161

6262
var outerKeySelectorLambda = ExpressionHelper.UnquoteLambda(arguments[2]);
6363
var localField = outerKeySelectorLambda.TranslateToDottedFieldName(context, wrappedOuterSerializer);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/JoinMethodToPipelineTranslator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public static AstPipeline Translate(TranslationContext context, MethodCallExpres
6363
AstProject.Exclude("_id"));
6464
var wrappedOuterSerializer = WrappedValueSerializer.Create("_outer", outerSerializer);
6565

66-
var (innerCollectionName, innerSerializer) = innerExpression.GetCollectionInfo(containerExpression: expression);
66+
var (innerCollectionName, innerSerializer) = innerExpression.GetCollectionInfoFromQueryable(containerExpression: expression);
6767
var localField = outerKeySelectorLambda.TranslateToDottedFieldName(context, wrappedOuterSerializer);
6868
var foreignField = innerKeySelectorLambda.TranslateToDottedFieldName(context, innerSerializer);
6969

0 commit comments

Comments
 (0)