diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/ExpansionRegressionTests.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/ExpansionRegressionTests.cs new file mode 100644 index 00000000000..86c809a8c1d --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/ExpansionRegressionTests.cs @@ -0,0 +1,101 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + using System.Threading.Tasks; + [TestFixture] + public class ExpansionRegressionTestsAsync : GH1879BaseFixtureAsync + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Save(new Invoice { InvoiceNumber = 1, Amount = 10, SpecialAmount = 100, Paid = false }); + session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 100, Paid = true }); + session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = false }); + session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = true }); + + session.Flush(); + transaction.Commit(); + } + } + + protected override void Configure(Configuration configuration) + { + configuration.LinqToHqlGeneratorsRegistry(); + } + + private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry + { + public TestLinqToHqlGeneratorsRegistry() + { + this.Merge(new ObjectEquality()); + } + } + + private class ObjectEquality : IHqlGeneratorForMethod + { + public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.Equality(visitor.Visit(targetObject).AsExpression(), visitor.Visit(arguments[0]).AsExpression()); + } + + public IEnumerable SupportedMethods + { + get + { + yield return ReflectHelper.GetMethodDefinition(x => x.Equals(x)); + } + } + } + + [Test] + public async Task MethodShouldNotExpandForNonConditionalOrCoalesceAsync() + { + using (var session = OpenSession()) + { + Assert.That(await (session.Query().CountAsync(e => ((object)(e.Amount + e.SpecialAmount)).Equals(110))), Is.EqualTo(2)); + } + } + + [Test] + public async Task MethodShouldNotExpandForConditionalWithPropertyAccessorAsync() + { + using (var session = OpenSession()) + { + Assert.That(await (session.Query().CountAsync(e => ((object)(e.Paid ? e.Amount : e.SpecialAmount)).Equals(10))), Is.EqualTo(2)); + } + } + + [Test] + public async Task MethodShouldNotExpandForCoalesceWithPropertyAccessorAsync() + { + using (var session = OpenSession()) + { + Assert.That(await (session.Query().CountAsync(e => ((object)(e.SpecialAmount ?? e.Amount)).Equals(100))), Is.EqualTo(2)); + } + } + } +} diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs index 6eb360584df..d57b65898b5 100644 --- a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs @@ -62,6 +62,9 @@ protected override HbmMapping GetMappings() { rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); rc.Property(x => x.InvoiceNumber); + rc.Property(x => x.Amount); + rc.Property(x => x.SpecialAmount); + rc.Property(x => x.Paid); rc.ManyToOne(x => x.Project, m => m.Column("ProjectId")); rc.ManyToOne(x => x.Issue, m => m.Column("IssueId")); }); @@ -122,9 +125,9 @@ protected async Task AreEqualAsync( expectedResult = await (expectedQuery(session.Query()).ToListAsync(cancellationToken)); } catch (OperationCanceledException) { throw; } - catch + catch (Exception e) { - Assert.Ignore("Not currently supported query"); + Assert.Ignore($"Not currently supported query: {e}"); } var testResult = await (actualQuery(session.Query()).ToListAsync(cancellationToken)); diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs index 6bc31cfa9c7..924622c108a 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs @@ -58,5 +58,8 @@ public class Invoice public virtual int InvoiceNumber { get; set; } public virtual Project Project { get; set; } public virtual Issue Issue { get; set; } + public virtual int Amount { get; set; } + public virtual int? SpecialAmount { get; set; } + public virtual bool Paid { get; set; } } } diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/ExpansionRegressionTests.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/ExpansionRegressionTests.cs new file mode 100644 index 00000000000..834a06bea2a --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/ExpansionRegressionTests.cs @@ -0,0 +1,89 @@ +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + [TestFixture] + public class ExpansionRegressionTests : GH1879BaseFixture + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Save(new Invoice { InvoiceNumber = 1, Amount = 10, SpecialAmount = 100, Paid = false }); + session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 100, Paid = true }); + session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = false }); + session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = true }); + + session.Flush(); + transaction.Commit(); + } + } + + protected override void Configure(Configuration configuration) + { + configuration.LinqToHqlGeneratorsRegistry(); + } + + private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry + { + public TestLinqToHqlGeneratorsRegistry() + { + this.Merge(new ObjectEquality()); + } + } + + private class ObjectEquality : IHqlGeneratorForMethod + { + public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.Equality(visitor.Visit(targetObject).AsExpression(), visitor.Visit(arguments[0]).AsExpression()); + } + + public IEnumerable SupportedMethods + { + get + { + yield return ReflectHelper.GetMethodDefinition(x => x.Equals(x)); + } + } + } + + [Test] + public void MethodShouldNotExpandForNonConditionalOrCoalesce() + { + using (var session = OpenSession()) + { + Assert.That(session.Query().Count(e => ((object)(e.Amount + e.SpecialAmount)).Equals(110)), Is.EqualTo(2)); + } + } + + [Test] + public void MethodShouldNotExpandForConditionalWithPropertyAccessor() + { + using (var session = OpenSession()) + { + Assert.That(session.Query().Count(e => ((object)(e.Paid ? e.Amount : e.SpecialAmount)).Equals(10)), Is.EqualTo(2)); + } + } + + [Test] + public void MethodShouldNotExpandForCoalesceWithPropertyAccessor() + { + using (var session = OpenSession()) + { + Assert.That(session.Query().Count(e => ((object)(e.SpecialAmount ?? e.Amount)).Equals(100)), Is.EqualTo(2)); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs index 926cf7c3f6c..3fe9211a449 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs @@ -49,6 +49,9 @@ protected override HbmMapping GetMappings() { rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); rc.Property(x => x.InvoiceNumber); + rc.Property(x => x.Amount); + rc.Property(x => x.SpecialAmount); + rc.Property(x => x.Paid); rc.ManyToOne(x => x.Project, m => m.Column("ProjectId")); rc.ManyToOne(x => x.Issue, m => m.Column("IssueId")); }); @@ -108,9 +111,9 @@ protected void AreEqual( { expectedResult = expectedQuery(session.Query()).ToList(); } - catch + catch (Exception e) { - Assert.Ignore("Not currently supported query"); + Assert.Ignore($"Not currently supported query: {e}"); } var testResult = actualQuery(session.Query()).ToList(); diff --git a/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs b/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs index 86370dda35d..b83be7077b4 100644 --- a/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs +++ b/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs @@ -77,7 +77,7 @@ public void Transform(ResultOperatorBase resultOperator) protected override Expression VisitMember(MemberExpression node) { var result = (MemberExpression) base.VisitMember(node); - if (QueryReferenceCounter.CountReferences(result.Expression) > 1) + if (ShouldRewrite(result.Expression)) { return ConditionalQueryReferenceMemberExpressionRewriter.Rewrite(result.Expression, node); } @@ -90,39 +90,44 @@ protected override Expression VisitMethodCall(MethodCallExpression node) var isExtension = node.Method.GetCustomAttributes().Any(); var methodObject = isExtension ? node.Arguments[0] : node.Object; - if (methodObject != null && QueryReferenceCounter.CountReferences(methodObject) > 1) + if (ShouldRewrite(methodObject)) { return ConditionalQueryReferenceMethodCallExpressionRewriter.Rewrite(methodObject, node); } return result; } - } - - private class QueryReferenceCounter : RelinqExpressionVisitor - { - private readonly System.Type _queryType; - private int _queryReferenceCount; - private QueryReferenceCounter(System.Type queryType) + private bool ShouldRewrite(Expression expr, System.Type queryType = null) { - _queryType = queryType; - } + if (expr == null) + { + return false; + } + + // Strip Converts + while (expr.NodeType == ExpressionType.Convert || expr.NodeType == ExpressionType.ConvertChecked) + { + expr = ((UnaryExpression)expr).Operand; + } - protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) - { - if (_queryType.IsAssignableFrom(expression.Type)) + if (expr is QuerySourceReferenceExpression && queryType?.IsAssignableFrom(expr.Type) == true) { - _queryReferenceCount++; + return true; } - return base.VisitQuerySourceReference(expression); - } + queryType = queryType ?? expr.Type; - public static int CountReferences(Expression node) - { - var visitor = new QueryReferenceCounter(node.Type); - visitor.Visit(node); - return visitor._queryReferenceCount; + if (expr.NodeType == ExpressionType.Coalesce && expr is BinaryExpression coalesce) + { + return ShouldRewrite(coalesce.Left, queryType) && ShouldRewrite(coalesce.Right, queryType); + } + + if (expr.NodeType == ExpressionType.Conditional && expr is ConditionalExpression conditional) + { + return ShouldRewrite(conditional.IfFalse, queryType) && ShouldRewrite(conditional.IfTrue, queryType); + } + + return false; } }