Skip to content

Commit de2d941

Browse files
authored
Improve LINQ Contains subquery parameter detection (#3274)
We should always try to detect parameters. And parameter detection shouldn't skip query transformation (Fixes failure case from #3212) Replaces #3212
1 parent 83db507 commit de2d941

File tree

3 files changed

+55
-18
lines changed

3 files changed

+55
-18
lines changed

src/NHibernate.Test/Async/Linq/WhereTests.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
using System.Linq;
1616
using System.Linq.Expressions;
1717
using log4net.Core;
18+
using NHibernate.Dialect;
1819
using NHibernate.Engine.Query;
1920
using NHibernate.Linq;
2021
using NHibernate.DomainModel.Northwind.Entities;
@@ -647,6 +648,9 @@ where sheet.Users.Contains(user)
647648
[Test]
648649
public async Task TimesheetsWithEnumerableContainsOnSelectAsync()
649650
{
651+
if (Dialect is MsSqlCeDialect)
652+
Assert.Ignore("Dialect is not supported");
653+
650654
var value = (EnumStoredAsInt32) 1000;
651655
var query = await ((from sheet in db.Timesheets
652656
where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value)
@@ -655,6 +659,24 @@ where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value)
655659
Assert.That(query.Count, Is.EqualTo(1));
656660
}
657661

662+
[Test]
663+
public async Task ContainsSubqueryWithCoalesceStringEnumSelectAsync()
664+
{
665+
if (Dialect is MsSqlCeDialect || Dialect is SQLiteDialect)
666+
Assert.Ignore("Dialect is not supported");
667+
668+
var results =
669+
await (db.Timesheets.Where(
670+
o =>
671+
o.Users
672+
.Where(u => u.Id != 0.MappedAs(NHibernateUtil.Int32))
673+
.Select(u => u.Name == u.Name ? u.Enum1 : u.NullableEnum1.Value)
674+
.Contains(EnumStoredAsString.Small))
675+
.ToListAsync());
676+
677+
Assert.That(results.Count, Is.EqualTo(1));
678+
}
679+
658680
[Test]
659681
public async Task SearchOnObjectTypeWithExtensionMethodAsync()
660682
{

src/NHibernate.Test/Linq/WhereTests.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Linq;
66
using System.Linq.Expressions;
77
using log4net.Core;
8+
using NHibernate.Dialect;
89
using NHibernate.Engine.Query;
910
using NHibernate.Linq;
1011
using NHibernate.DomainModel.Northwind.Entities;
@@ -648,6 +649,9 @@ where sheet.Users.Contains(user)
648649
[Test]
649650
public void TimesheetsWithEnumerableContainsOnSelect()
650651
{
652+
if (Dialect is MsSqlCeDialect)
653+
Assert.Ignore("Dialect is not supported");
654+
651655
var value = (EnumStoredAsInt32) 1000;
652656
var query = (from sheet in db.Timesheets
653657
where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value)
@@ -656,6 +660,24 @@ where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value)
656660
Assert.That(query.Count, Is.EqualTo(1));
657661
}
658662

663+
[Test]
664+
public void ContainsSubqueryWithCoalesceStringEnumSelect()
665+
{
666+
if (Dialect is MsSqlCeDialect || Dialect is SQLiteDialect)
667+
Assert.Ignore("Dialect is not supported");
668+
669+
var results =
670+
db.Timesheets.Where(
671+
o =>
672+
o.Users
673+
.Where(u => u.Id != 0.MappedAs(NHibernateUtil.Int32))
674+
.Select(u => u.Name == u.Name ? u.Enum1 : u.NullableEnum1.Value)
675+
.Contains(EnumStoredAsString.Small))
676+
.ToList();
677+
678+
Assert.That(results.Count, Is.EqualTo(1));
679+
}
680+
659681
[Test]
660682
public void SearchOnObjectTypeWithExtensionMethod()
661683
{

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -288,42 +288,35 @@ protected override Expression VisitConstant(ConstantExpression node)
288288

289289
protected override Expression VisitSubQuery(SubQueryExpression node)
290290
{
291-
if (!TryLinkContainsMethod(node.QueryModel))
292-
{
293-
node.QueryModel.TransformExpressions(Visit);
294-
}
291+
TryLinkContainsMethod(node.QueryModel);
292+
node.QueryModel.TransformExpressions(Visit);
295293

296294
return node;
297295
}
298296

299-
private bool TryLinkContainsMethod(QueryModel queryModel)
297+
private void TryLinkContainsMethod(QueryModel queryModel)
300298
{
301299
// ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of
302300
// ContainsResultOperator where the constant expression is dislocated from the related expression,
303301
// we have to manually link the related expressions.
304302
if (queryModel.ResultOperators.Count != 1 ||
305-
!(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator) ||
306-
!(queryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference) ||
307-
!(querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause))
303+
!(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator))
308304
{
309-
return false;
305+
return;
310306
}
311307

312-
var left = UnwrapUnary(Visit(mainFromClause.FromExpression));
308+
Expression selector =
309+
queryModel.SelectClause.Selector is QuerySourceReferenceExpression { ReferencedQuerySource: MainFromClause mainFromClause }
310+
? mainFromClause.FromExpression
311+
: queryModel.SelectClause.Selector;
312+
313+
var left = UnwrapUnary(Visit(selector));
313314
var right = UnwrapUnary(Visit(containsOperator.Item));
314-
// The constant is on the left side (e.g. db.Users.Where(o => users.Contains(o)))
315-
// The constant is on the right side (e.g. db.Customers.Where(o => o.Orders.Contains(item)))
316-
if (left.NodeType != ExpressionType.Constant && right.NodeType != ExpressionType.Constant)
317-
{
318-
return false;
319-
}
320315

321316
// Copy all found MemberExpressions to the constant expression
322317
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
323318
AddRelatedExpression(null, left, right);
324319
AddRelatedExpression(null, right, left);
325-
326-
return true;
327320
}
328321

329322
private void VisitAssign(Expression leftNode, Expression rightNode)

0 commit comments

Comments
 (0)