diff --git a/src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs b/src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs index 9b6492d4cb0..80bd33cb9a1 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs @@ -22,6 +22,8 @@ public class AnotherEntityRequired public virtual ISet RelatedItems { get; set; } = new HashSet(); + public virtual ISet RequiredRelatedItems { get; set; } = new HashSet(); + public virtual bool? NullableBool { get; set; } } diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml index 0d9efe4136f..755f83c7c05 100644 --- a/src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml +++ b/src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml @@ -19,5 +19,9 @@ + + + + diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 4bb8ca0f7f5..e5d239f7ed2 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -81,6 +81,80 @@ public async Task UsingEntityParameterTwiceAsync() 1)); } + [Test] + public async Task UsingEntityParameterForCollectionAsync() + { + var item = await (db.OrderLines.FirstAsync()); + await (AssertTotalParametersAsync( + db.Orders.Where(o => o.OrderLines.Contains(item)), + 1)); + } + + [Test] + public async Task UsingProxyParameterForCollectionAsync() + { + var item = await (session.LoadAsync(10248)); + Assert.That(NHibernateUtil.IsInitialized(item), Is.False); + await (AssertTotalParametersAsync( + db.Customers.Where(o => o.Orders.Contains(item)), + 1)); + } + + [Test] + public async Task UsingFieldProxyParameterForCollectionAsync() + { + var item = await (session.Query().FirstAsync()); + await (AssertTotalParametersAsync( + session.Query().Where(o => o.RequiredRelatedItems.Contains(item)), + 1)); + } + + [Test] + public async Task UsingEntityParameterInSubQueryAsync() + { + var item = await (db.Customers.FirstAsync()); + var subQuery = db.Orders.Select(o => o.Customer).Where(o => o == item); + await (AssertTotalParametersAsync( + db.Orders.Where(o => subQuery.Contains(o.Customer)), + 1)); + } + + [Test] + public async Task UsingEntityParameterForCollectionSelectionAsync() + { + var item = await (db.OrderLines.FirstAsync()); + await (AssertTotalParametersAsync( + db.Orders.SelectMany(o => o.OrderLines).Where(o => o == item), + 1)); + } + + [Test] + public async Task UsingFieldProxyParameterForCollectionSelectionAsync() + { + var item = await (session.Query().FirstAsync()); + await (AssertTotalParametersAsync( + session.Query().SelectMany(o => o.RequiredRelatedItems).Where(o => o == item), + 1)); + } + + [Test] + public async Task UsingEntityListParameterForCollectionSelectionAsync() + { + var items = new[] {await (db.OrderLines.FirstAsync())}; + await (AssertTotalParametersAsync( + db.Orders.SelectMany(o => o.OrderLines).Where(o => items.Contains(o)), + 1)); + } + + [Test] + public async Task UsingFieldProxyListParameterForCollectionSelectionAsync() + { + var items = new[] {await (session.Query().FirstAsync())}; + await (AssertTotalParametersAsync( + session.Query().SelectMany(o => o.RequiredRelatedItems).Where(o => items.Contains(o)), + 1)); + } + [Test] public async Task UsingTwoEntityParametersAsync() { diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index 73028ba8598..fdbb1a73275 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -69,6 +69,80 @@ public void UsingEntityParameterTwice() 1); } + [Test] + public void UsingEntityParameterForCollection() + { + var item = db.OrderLines.First(); + AssertTotalParameters( + db.Orders.Where(o => o.OrderLines.Contains(item)), + 1); + } + + [Test] + public void UsingProxyParameterForCollection() + { + var item = session.Load(10248); + Assert.That(NHibernateUtil.IsInitialized(item), Is.False); + AssertTotalParameters( + db.Customers.Where(o => o.Orders.Contains(item)), + 1); + } + + [Test] + public void UsingFieldProxyParameterForCollection() + { + var item = session.Query().First(); + AssertTotalParameters( + session.Query().Where(o => o.RequiredRelatedItems.Contains(item)), + 1); + } + + [Test] + public void UsingEntityParameterInSubQuery() + { + var item = db.Customers.First(); + var subQuery = db.Orders.Select(o => o.Customer).Where(o => o == item); + AssertTotalParameters( + db.Orders.Where(o => subQuery.Contains(o.Customer)), + 1); + } + + [Test] + public void UsingEntityParameterForCollectionSelection() + { + var item = db.OrderLines.First(); + AssertTotalParameters( + db.Orders.SelectMany(o => o.OrderLines).Where(o => o == item), + 1); + } + + [Test] + public void UsingFieldProxyParameterForCollectionSelection() + { + var item = session.Query().First(); + AssertTotalParameters( + session.Query().SelectMany(o => o.RequiredRelatedItems).Where(o => o == item), + 1); + } + + [Test] + public void UsingEntityListParameterForCollectionSelection() + { + var items = new[] {db.OrderLines.First()}; + AssertTotalParameters( + db.Orders.SelectMany(o => o.OrderLines).Where(o => items.Contains(o)), + 1); + } + + [Test] + public void UsingFieldProxyListParameterForCollectionSelection() + { + var items = new[] {session.Query().First()}; + AssertTotalParameters( + session.Query().SelectMany(o => o.RequiredRelatedItems).Where(o => items.Contains(o)), + 1); + } + [Test] public void UsingTwoEntityParameters() { diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 37e3da19852..9cab3d6cd26 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -117,7 +117,7 @@ private static IType GetCandidateType( if (!ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var mappedType, out _, out _, out _)) continue; - if (mappedType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression)) + if (mappedType.IsCollectionType) { var collection = (IQueryableCollection) ((IAssociationType) mappedType).GetAssociatedJoinable(sessionFactory); mappedType = collection.ElementType; @@ -176,7 +176,6 @@ private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor new Dictionary>(); public readonly Dictionary> RelatedExpressions = new Dictionary>(); - public readonly HashSet SequenceSelectorExpressions = new HashSet(); public ConstantTypeLocatorVisitor( bool removeMappedAsCalls, @@ -282,41 +281,43 @@ protected override Expression VisitConstant(ConstantExpression node) } protected override Expression VisitSubQuery(SubQueryExpression node) + { + if (!TryLinkContainsMethod(node.QueryModel)) + { + node.QueryModel.TransformExpressions(Visit); + } + + return node; + } + + private bool TryLinkContainsMethod(QueryModel queryModel) { // ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of // ContainsResultOperator where the constant expression is dislocated from the related expression, // we have to manually link the related expressions. - if (node.QueryModel.ResultOperators.Count == 1 && - node.QueryModel.ResultOperators[0] is ContainsResultOperator containsOperator && - node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference && - querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause && - mainFromClause.FromExpression is ConstantExpression constantExpression) + if (queryModel.ResultOperators.Count != 1 || + !(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator) || + !(queryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference) || + !(querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause)) { - VisitConstant(constantExpression); - AddRelatedExpression(constantExpression, UnwrapUnary(Visit(containsOperator.Item))); - // Copy all found MemberExpressions to the constant expression - // (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2) - if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set)) - { - foreach (var nestedMemberExpression in set) - { - AddRelatedExpression(constantExpression, nestedMemberExpression); - } - } + return false; } - else - { - // In case a parameter is related to a sequence selector we will have to get the underlying item type - // (e.g. q.Where(o => o.Users.Any(u => u == user))) - if (node.QueryModel.ResultOperators.Any(o => o is ValueFromSequenceResultOperatorBase)) - { - SequenceSelectorExpressions.Add(node.QueryModel.SelectClause.Selector); - } - node.QueryModel.TransformExpressions(Visit); + var left = UnwrapUnary(Visit(mainFromClause.FromExpression)); + var right = UnwrapUnary(Visit(containsOperator.Item)); + // The constant is on the left side (e.g. db.Users.Where(o => users.Contains(o))) + // The constant is on the right side (e.g. db.Customers.Where(o => o.Orders.Contains(item))) + if (left.NodeType != ExpressionType.Constant && right.NodeType != ExpressionType.Constant) + { + return false; } - return node; + // Copy all found MemberExpressions to the constant expression + // (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2) + AddRelatedExpression(null, left, right); + AddRelatedExpression(null, right, left); + + return true; } private void VisitAssign(Expression leftNode, Expression rightNode) @@ -346,7 +347,7 @@ private void AddRelatedExpression(Expression node, Expression left, Expression r left is QuerySourceReferenceExpression) { AddRelatedExpression(right, left); - if (NonVoidOperators.Contains(node.NodeType)) + if (node != null && NonVoidOperators.Contains(node.NodeType)) { AddRelatedExpression(node, left); } @@ -359,7 +360,7 @@ private void AddRelatedExpression(Expression node, Expression left, Expression r foreach (var nestedMemberExpression in set) { AddRelatedExpression(right, nestedMemberExpression); - if (NonVoidOperators.Contains(node.NodeType)) + if (node != null && NonVoidOperators.Contains(node.NodeType)) { AddRelatedExpression(node, nestedMemberExpression); }