Skip to content

Commit 2b2abe8

Browse files
committed
Fix parameter detection for Contains method for Linq provider
1 parent b89c16b commit 2b2abe8

File tree

5 files changed

+194
-29
lines changed

5 files changed

+194
-29
lines changed

src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ public class AnotherEntityRequired
2222

2323
public virtual ISet<AnotherEntity> RelatedItems { get; set; } = new HashSet<AnotherEntity>();
2424

25+
public virtual ISet<AnotherEntityRequired> RequiredRelatedItems { get; set; } = new HashSet<AnotherEntityRequired>();
26+
2527
public virtual bool? NullableBool { get; set; }
2628
}
2729

src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,9 @@
1919
<key column="Id"/>
2020
<one-to-many class="AnotherEntity"/>
2121
</set>
22+
<set name="RequiredRelatedItems" lazy="true" inverse="true">
23+
<key column="Id"/>
24+
<one-to-many class="AnotherEntityRequired"/>
25+
</set>
2226
</class>
2327
</hibernate-mapping>

src/NHibernate.Test/Async/Linq/ParameterTests.cs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,80 @@ public async Task UsingEntityParameterTwiceAsync()
8181
1));
8282
}
8383

84+
[Test]
85+
public async Task UsingEntityParameterForCollectionAsync()
86+
{
87+
var item = await (db.OrderLines.FirstAsync());
88+
await (AssertTotalParametersAsync(
89+
db.Orders.Where(o => o.OrderLines.Contains(item)),
90+
1));
91+
}
92+
93+
[Test]
94+
public async Task UsingProxyParameterForCollectionAsync()
95+
{
96+
var item = await (session.LoadAsync<Order>(10248));
97+
Assert.That(NHibernateUtil.IsInitialized(item), Is.False);
98+
await (AssertTotalParametersAsync(
99+
db.Customers.Where(o => o.Orders.Contains(item)),
100+
1));
101+
}
102+
103+
[Test]
104+
public async Task UsingFieldProxyParameterForCollectionAsync()
105+
{
106+
var item = await (session.Query<AnotherEntityRequired>().FirstAsync());
107+
await (AssertTotalParametersAsync(
108+
session.Query<AnotherEntityRequired>().Where(o => o.RequiredRelatedItems.Contains(item)),
109+
1));
110+
}
111+
112+
[Test]
113+
public async Task UsingEntityParameterInSubQueryAsync()
114+
{
115+
var item = await (db.Customers.FirstAsync());
116+
var subQuery = db.Orders.Select(o => o.Customer).Where(o => o == item);
117+
await (AssertTotalParametersAsync(
118+
db.Orders.Where(o => subQuery.Contains(o.Customer)),
119+
1));
120+
}
121+
122+
[Test]
123+
public async Task UsingEntityParameterForCollectionSelectionAsync()
124+
{
125+
var item = await (db.OrderLines.FirstAsync());
126+
await (AssertTotalParametersAsync(
127+
db.Orders.SelectMany(o => o.OrderLines).Where(o => o == item),
128+
1));
129+
}
130+
131+
[Test]
132+
public async Task UsingFieldProxyParameterForCollectionSelectionAsync()
133+
{
134+
var item = await (session.Query<AnotherEntityRequired>().FirstAsync());
135+
await (AssertTotalParametersAsync(
136+
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => o == item),
137+
1));
138+
}
139+
140+
[Test]
141+
public async Task UsingEntityListParameterForCollectionSelectionAsync()
142+
{
143+
var items = new[] {await (db.OrderLines.FirstAsync())};
144+
await (AssertTotalParametersAsync(
145+
db.Orders.SelectMany(o => o.OrderLines).Where(o => items.Contains(o)),
146+
1));
147+
}
148+
149+
[Test]
150+
public async Task UsingFieldProxyListParameterForCollectionSelectionAsync()
151+
{
152+
var items = new[] {await (session.Query<AnotherEntityRequired>().FirstAsync())};
153+
await (AssertTotalParametersAsync(
154+
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => items.Contains(o)),
155+
1));
156+
}
157+
84158
[Test]
85159
public async Task UsingTwoEntityParametersAsync()
86160
{

src/NHibernate.Test/Linq/ParameterTests.cs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,80 @@ public void UsingEntityParameterTwice()
6969
1);
7070
}
7171

72+
[Test]
73+
public void UsingEntityParameterForCollection()
74+
{
75+
var item = db.OrderLines.First();
76+
AssertTotalParameters(
77+
db.Orders.Where(o => o.OrderLines.Contains(item)),
78+
1);
79+
}
80+
81+
[Test]
82+
public void UsingProxyParameterForCollection()
83+
{
84+
var item = session.Load<Order>(10248);
85+
Assert.That(NHibernateUtil.IsInitialized(item), Is.False);
86+
AssertTotalParameters(
87+
db.Customers.Where(o => o.Orders.Contains(item)),
88+
1);
89+
}
90+
91+
[Test]
92+
public void UsingFieldProxyParameterForCollection()
93+
{
94+
var item = session.Query<AnotherEntityRequired>().First();
95+
AssertTotalParameters(
96+
session.Query<AnotherEntityRequired>().Where(o => o.RequiredRelatedItems.Contains(item)),
97+
1);
98+
}
99+
100+
[Test]
101+
public void UsingEntityParameterInSubQuery()
102+
{
103+
var item = db.Customers.First();
104+
var subQuery = db.Orders.Select(o => o.Customer).Where(o => o == item);
105+
AssertTotalParameters(
106+
db.Orders.Where(o => subQuery.Contains(o.Customer)),
107+
1);
108+
}
109+
110+
[Test]
111+
public void UsingEntityParameterForCollectionSelection()
112+
{
113+
var item = db.OrderLines.First();
114+
AssertTotalParameters(
115+
db.Orders.SelectMany(o => o.OrderLines).Where(o => o == item),
116+
1);
117+
}
118+
119+
[Test]
120+
public void UsingFieldProxyParameterForCollectionSelection()
121+
{
122+
var item = session.Query<AnotherEntityRequired>().First();
123+
AssertTotalParameters(
124+
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => o == item),
125+
1);
126+
}
127+
128+
[Test]
129+
public void UsingEntityListParameterForCollectionSelection()
130+
{
131+
var items = new[] {db.OrderLines.First()};
132+
AssertTotalParameters(
133+
db.Orders.SelectMany(o => o.OrderLines).Where(o => items.Contains(o)),
134+
1);
135+
}
136+
137+
[Test]
138+
public void UsingFieldProxyListParameterForCollectionSelection()
139+
{
140+
var items = new[] {session.Query<AnotherEntityRequired>().First()};
141+
AssertTotalParameters(
142+
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => items.Contains(o)),
143+
1);
144+
}
145+
72146
[Test]
73147
public void UsingTwoEntityParameters()
74148
{

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ private static IType GetCandidateType(
117117
if (!ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var mappedType, out _, out _, out _))
118118
continue;
119119

120-
if (mappedType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression))
120+
if (mappedType.IsCollectionType)
121121
{
122122
var collection = (IQueryableCollection) ((IAssociationType) mappedType).GetAssociatedJoinable(sessionFactory);
123123
mappedType = collection.ElementType;
@@ -199,7 +199,6 @@ private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
199199
new Dictionary<NamedParameter, HashSet<ConstantExpression>>();
200200
public readonly Dictionary<Expression, HashSet<Expression>> RelatedExpressions =
201201
new Dictionary<Expression, HashSet<Expression>>();
202-
public readonly HashSet<Expression> SequenceSelectorExpressions = new HashSet<Expression>();
203202

204203
public ConstantTypeLocatorVisitor(
205204
bool removeMappedAsCalls,
@@ -305,41 +304,53 @@ protected override Expression VisitConstant(ConstantExpression node)
305304
}
306305

307306
protected override Expression VisitSubQuery(SubQueryExpression node)
307+
{
308+
if (!TryLinkContainsMethod(node.QueryModel))
309+
{
310+
node.QueryModel.TransformExpressions(Visit);
311+
}
312+
313+
return node;
314+
}
315+
316+
private bool TryLinkContainsMethod(QueryModel queryModel)
308317
{
309318
// ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of
310319
// ContainsResultOperator where the constant expression is dislocated from the related expression,
311320
// we have to manually link the related expressions.
312-
if (node.QueryModel.ResultOperators.Count == 1 &&
313-
node.QueryModel.ResultOperators[0] is ContainsResultOperator containsOperator &&
314-
node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference &&
315-
querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause &&
316-
mainFromClause.FromExpression is ConstantExpression constantExpression)
321+
if (queryModel.ResultOperators.Count != 1 ||
322+
!(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator) ||
323+
!(queryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference) ||
324+
!(querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause))
317325
{
318-
VisitConstant(constantExpression);
319-
AddRelatedExpression(constantExpression, UnwrapUnary(Visit(containsOperator.Item)));
320-
// Copy all found MemberExpressions to the constant expression
321-
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
322-
if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set))
323-
{
324-
foreach (var nestedMemberExpression in set)
325-
{
326-
AddRelatedExpression(constantExpression, nestedMemberExpression);
327-
}
328-
}
326+
return false;
327+
}
328+
329+
var left = UnwrapUnary(mainFromClause.FromExpression);
330+
var right = UnwrapUnary(containsOperator.Item);
331+
if (left.NodeType == ExpressionType.Constant)
332+
{
333+
// The constant is on the left side (e.g. db.Users.Where(o => users.Contains(o)))
334+
VisitConstant((ConstantExpression) left);
335+
right = UnwrapUnary(Visit(containsOperator.Item));
336+
}
337+
else if (right.NodeType == ExpressionType.Constant)
338+
{
339+
// The constant is on the right side (e.g. db.Customers.Where(o => o.Orders.Contains(item)))
340+
VisitConstant((ConstantExpression) right);
341+
left = UnwrapUnary(Visit(mainFromClause.FromExpression));
329342
}
330343
else
331344
{
332-
// In case a parameter is related to a sequence selector we will have to get the underlying item type
333-
// (e.g. q.Where(o => o.Users.Any(u => u == user)))
334-
if (node.QueryModel.ResultOperators.Any(o => o is ValueFromSequenceResultOperatorBase))
335-
{
336-
SequenceSelectorExpressions.Add(node.QueryModel.SelectClause.Selector);
337-
}
338-
339-
node.QueryModel.TransformExpressions(Visit);
345+
return false;
340346
}
341347

342-
return node;
348+
// Copy all found MemberExpressions to the constant expression
349+
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
350+
AddRelatedExpression(null, left, right);
351+
AddRelatedExpression(null, right, left);
352+
353+
return true;
343354
}
344355

345356
private void VisitAssign(Expression leftNode, Expression rightNode)
@@ -369,7 +380,7 @@ private void AddRelatedExpression(Expression node, Expression left, Expression r
369380
left is QuerySourceReferenceExpression)
370381
{
371382
AddRelatedExpression(right, left);
372-
if (NonVoidOperators.Contains(node.NodeType))
383+
if (node != null && NonVoidOperators.Contains(node.NodeType))
373384
{
374385
AddRelatedExpression(node, left);
375386
}
@@ -382,7 +393,7 @@ private void AddRelatedExpression(Expression node, Expression left, Expression r
382393
foreach (var nestedMemberExpression in set)
383394
{
384395
AddRelatedExpression(right, nestedMemberExpression);
385-
if (NonVoidOperators.Contains(node.NodeType))
396+
if (node != null && NonVoidOperators.Contains(node.NodeType))
386397
{
387398
AddRelatedExpression(node, nestedMemberExpression);
388399
}

0 commit comments

Comments
 (0)