diff --git a/src/NHibernate.Test/Async/Linq/WhereTests.cs b/src/NHibernate.Test/Async/Linq/WhereTests.cs index 65539b61085..bec384c9b31 100644 --- a/src/NHibernate.Test/Async/Linq/WhereTests.cs +++ b/src/NHibernate.Test/Async/Linq/WhereTests.cs @@ -844,6 +844,18 @@ public async Task SelectOnCollectionReturnsResultAsync() Assert.That(result.Children, Is.Not.Empty); } + [Test(Description = "GH-1556")] + public async Task ContainsOnPersistedCollectionAsync() + { + var animal = await (session.Query().SingleAsync(a => a.SerialNumber == "123")); + + var result = await (session.Query() + .Where(e => animal.Children.Contains(e.Father)) + .OrderBy(e => e.Id) + .FirstOrDefaultAsync()); + Assert.That(result, Is.Not.Null); + Assert.That(result.SerialNumber, Is.EqualTo("1121")); + } private static List CanUseCompareInQueryDataSource() { diff --git a/src/NHibernate.Test/Linq/WhereTests.cs b/src/NHibernate.Test/Linq/WhereTests.cs index d16311f70a2..e75e232939a 100644 --- a/src/NHibernate.Test/Linq/WhereTests.cs +++ b/src/NHibernate.Test/Linq/WhereTests.cs @@ -812,6 +812,18 @@ public void SelectOnCollectionReturnsResult() Assert.That(result.Children, Is.Not.Empty); } + [Test(Description = "GH-1556")] + public void ContainsOnPersistedCollection() + { + var animal = session.Query().Single(a => a.SerialNumber == "123"); + + var result = session.Query() + .Where(e => animal.Children.Contains(e.Father)) + .OrderBy(e => e.Id) + .FirstOrDefault(); + Assert.That(result, Is.Not.Null); + Assert.That(result.SerialNumber, Is.EqualTo("1121")); + } private static List CanUseCompareInQueryDataSource() { diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs index 64b473a2ec3..ae520e1904d 100644 --- a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs @@ -1,6 +1,7 @@ using System; using System.Linq; using System.Linq.Expressions; +using NHibernate.Collection; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Parsing; using Remotion.Linq.Parsing.ExpressionVisitors; @@ -12,13 +13,12 @@ internal class NhPartialEvaluatingExpressionVisitor : RelinqExpressionVisitor, I { protected override Expression VisitConstant(ConstantExpression expression) { - var value = expression.Value as Expression; - if (value == null) + if (expression.Value is Expression value) { - return base.VisitConstant(expression); + return EvaluateIndependentSubtrees(value); } - return EvaluateIndependentSubtrees(value); + return base.VisitConstant(expression); } public static Expression EvaluateIndependentSubtrees(Expression expression) @@ -37,6 +37,16 @@ public Expression VisitPartialEvaluationException(PartialEvaluationExceptionExpr internal class NhEvaluatableExpressionFilter : EvaluatableExpressionFilterBase { + public override bool IsEvaluatableConstant(ConstantExpression node) + { + if (node.Value is IPersistentCollection && node.Value is IQueryable) + { + return false; + } + + return base.IsEvaluatableConstant(node); + } + public override bool IsEvaluatableMethodCall(MethodCallExpression node) { if (node == null)