From 5e6e99ffad3131ac96e68c611aebd56e9ee39eba Mon Sep 17 00:00:00 2001 From: Duncan M Date: Fri, 19 Oct 2018 15:47:57 -0600 Subject: [PATCH 1/5] GH-1879 - Allow Coalesce and Conditional logic on entity properties and collections (LINQ) --- .../CoalesceChildThenAccessExtensionMethod.cs | 158 ++++++++++++++ .../GH1879/CoalesceChildThenAccessMember.cs | 121 +++++++++++ .../GH1879/CoalesceChildThenAccessMethod.cs | 158 ++++++++++++++ .../CoalesceSiblingsThenAccessMember.cs | 106 +++++++++ .../GH1879/ConditionalThenAccessMember.cs | 141 ++++++++++++ .../GH1879/ConditionalThenMethodCall.cs | 108 ++++++++++ .../NHSpecificTest/GH1879/Entity.cs | 60 ++++++ .../NHSpecificTest/GH1879/FixtureByCode.cs | 107 +++++++++ .../NHSpecificTest/GH1879/TestExtensions.cs | 10 + .../ConditionalQueryReferenceExpander.cs | 204 ++++++++++++++++++ .../ReWriters/SubQueryConditionalExpander.cs | 134 ++++++++++++ .../Visitors/MemberExpressionJoinDetector.cs | 12 +- .../Linq/Visitors/QueryModelVisitor.cs | 6 + .../Linq/Visitors/WhereJoinDetector.cs | 8 + 14 files changed, 1330 insertions(+), 3 deletions(-) create mode 100644 src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessExtensionMethod.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessMember.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessMethod.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceSiblingsThenAccessMember.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH1879/ConditionalThenAccessMember.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH1879/ConditionalThenMethodCall.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH1879/TestExtensions.cs create mode 100644 src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs create mode 100644 src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessExtensionMethod.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessExtensionMethod.cs new file mode 100644 index 00000000000..d9976fe112e --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessExtensionMethod.cs @@ -0,0 +1,158 @@ +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + [TestFixture] + public class CoalesceChildThenAccessExtensionMethod : GH1879BaseFixture + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Albert" }; + var clientB = new Client { Name = "Bob" }; + var corpA = new CorporateClient { Name = "Alpha", CorporateId = "1234" }; + var corpB = new CorporateClient { Name = "Beta", CorporateId = "5647" }; + var clientZ = new Client { Name = null }; // A null value should propagate if the entity is non-null + session.Save(clientA); + session.Save(clientB); + session.Save(corpA); + session.Save(corpB); + session.Save(clientZ); + + var projectA = new Project { Name = "A", BillingClient = null, Client = clientA }; + var projectB = new Project { Name = "B", BillingClient = corpB, Client = clientA }; + var projectC = new Project { Name = "C", BillingClient = null, Client = clientB }; + var projectD = new Project { Name = "D", BillingClient = corpA, Client = clientB }; + var projectE = new Project { Name = "E", BillingClient = clientZ, Client = clientA }; + var projectZ = new Project { Name = "Z", BillingClient = null, Client = null }; + session.Save(projectA); + session.Save(projectB); + session.Save(projectC); + session.Save(projectD); + session.Save(projectE); + session.Save(projectZ); + + session.Save(new Issue { Name = "01", Project = null, Client = null }); + session.Save(new Issue { Name = "02", Project = null, Client = clientA }); + session.Save(new Issue { Name = "03", Project = null, Client = clientB }); + session.Save(new Issue { Name = "04", Project = projectC, Client = clientA }); + session.Save(new Issue { Name = "05", Project = projectA, Client = clientB }); + session.Save(new Issue { Name = "06", Project = projectB, Client = clientA }); + session.Save(new Issue { Name = "07", Project = projectD, Client = clientB }); + session.Save(new Issue { Name = "08", Project = projectZ, Client = corpA }); + session.Save(new Issue { Name = "09", Project = projectZ, Client = corpB }); + session.Save(new Issue { Name = "10", Project = projectE, Client = clientA }); + + session.Flush(); + transaction.Commit(); + } + } + + protected override void Configure(Configuration configuration) + { + configuration.LinqToHqlGeneratorsRegistry(); + } + + private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry + { + public TestLinqToHqlGeneratorsRegistry() + { + this.Merge(new TestHqlGeneratorForMethod()); + } + } + + private class TestHqlGeneratorForMethod : IHqlGeneratorForMethod + { + /// + public IEnumerable SupportedMethods => new [] + { + ReflectHelper.GetMethodDefinition(x => x.NameByExtension()), + }; + + /// + public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.Dot(visitor.Visit(arguments[0]).AsExpression(), treeBuilder.Ident("Name").AsExpression()); + } + } + + [Test] + public void WhereClause() + { + AreEqual( + // Actual + q => q.Where(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByExtension().StartsWith("A")), + // Expected + q => q.Where(i => (i.Project.BillingClient != null ? i.Project.BillingClient.NameByExtension() : i.Project.Client != null ? i.Project.Client.NameByExtension() : i.Client.NameByExtension()).StartsWith("A")) + ); + } + + [Test] + public void SelectClause() + { + AreEqual( + // Actual + q => q.OrderBy(i =>i.Name) + .Select(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByExtension()), + // Expected + q => q.OrderBy(i =>i.Name) + .Select(i => i.Project.BillingClient != null ? i.Project.BillingClient.NameByExtension() : i.Project.Client != null ? i.Project.Client.NameByExtension() : i.Client.NameByExtension()) + ); + } + + [Test] + public void SelectClauseToAnon() + { + AreEqual( + // Actual + q => q.OrderBy(i =>i.Name) + .Select(i => new { Key =i.Name, Client = (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByExtension() }), + // Expected + q => q.OrderBy(i =>i.Name) + .Select(i => new { Key =i.Name, Client = i.Project.BillingClient != null ? i.Project.BillingClient.NameByExtension() : i.Project.Client != null ? i.Project.Client.NameByExtension() : i.Client.NameByExtension() }) + ); + } + + [Test] + public void OrderByClause() + { + AreEqual( + // Actual + q => q.OrderBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByExtension() ?? "ZZZ") + .ThenBy(i =>i.Name) + .Select(i =>i.Name), + // Expected + q => q.OrderBy(i => (i.Project.BillingClient != null ? i.Project.BillingClient.NameByExtension() : i.Project.Client != null ? i.Project.Client.NameByExtension() : i.Client.NameByExtension()) ?? "ZZZ") + .ThenBy(i =>i.Name) + .Select(i =>i.Name) + ); + } + + [Test] + public void GroupByClause() + { + AreEqual( + // Actual + q => q.GroupBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByExtension()) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }), + // Expected + q => q.GroupBy(i => i.Project.BillingClient != null ? i.Project.BillingClient.NameByExtension() : i.Project.Client != null ? i.Project.Client.NameByExtension() : i.Client.NameByExtension()) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }) + ); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessMember.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessMember.cs new file mode 100644 index 00000000000..fb2b5595b06 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessMember.cs @@ -0,0 +1,121 @@ +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + [TestFixture] + public class CoalesceChildThenAccessMember : GH1879BaseFixture + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Albert" }; + var clientB = new Client { Name = "Bob" }; + var corpA = new CorporateClient { Name = "Alpha", CorporateId = "1234" }; + var corpB = new CorporateClient { Name = "Beta", CorporateId = "5647" }; + var clientZ = new Client { Name = null }; // A null value should propagate if the entity is non-null + session.Save(clientA); + session.Save(clientB); + session.Save(corpA); + session.Save(corpB); + session.Save(clientZ); + + var projectA = new Project { Name = "A", BillingClient = null, Client = clientA }; + var projectB = new Project { Name = "B", BillingClient = corpB, Client = clientA }; + var projectC = new Project { Name = "C", BillingClient = null, Client = clientB }; + var projectD = new Project { Name = "D", BillingClient = corpA, Client = clientB }; + var projectE = new Project { Name = "E", BillingClient = clientZ, Client = clientA }; + var projectZ = new Project { Name = "Z", BillingClient = null, Client = null }; + session.Save(projectA); + session.Save(projectB); + session.Save(projectC); + session.Save(projectD); + session.Save(projectE); + session.Save(projectZ); + + session.Save(new Issue { Name = "01", Project = null, Client = null }); + session.Save(new Issue { Name = "02", Project = null, Client = clientA }); + session.Save(new Issue { Name = "03", Project = null, Client = clientB }); + session.Save(new Issue { Name = "04", Project = projectC, Client = clientA }); + session.Save(new Issue { Name = "05", Project = projectA, Client = clientB }); + session.Save(new Issue { Name = "06", Project = projectB, Client = clientA }); + session.Save(new Issue { Name = "07", Project = projectD, Client = clientB }); + session.Save(new Issue { Name = "08", Project = projectZ, Client = corpA }); + session.Save(new Issue { Name = "09", Project = projectZ, Client = corpB }); + session.Save(new Issue { Name = "10", Project = projectE, Client = clientA }); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public void WhereClause() + { + AreEqual( + // Actual + q => q.Where(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).Name.StartsWith("A")), + // Expected + q => q.Where(i => (i.Project.BillingClient != null ? i.Project.BillingClient.Name : i.Project.Client != null ? i.Project.Client.Name : i.Client.Name).StartsWith("A")) + ); + } + + [Test] + public void SelectClause() + { + AreEqual( + // Actual + q => q.OrderBy(i => i.Name) + .Select(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).Name), + // Expected + q => q.OrderBy(i => i.Name) + .Select(i => i.Project.BillingClient != null ? i.Project.BillingClient.Name : i.Project.Client != null ? i.Project.Client.Name : i.Client.Name) + ); + } + + [Test] + public void SelectClauseToAnon() + { + AreEqual( + // Actual + q => q.OrderBy(i => i.Name) + .Select(i => new { Key = i.Name, Client = (i.Project.BillingClient ?? i.Project.Client ?? i.Client).Name }), + // Expected + q => q.OrderBy(i => i.Name) + .Select(i => new { Key = i.Name, Client = i.Project.BillingClient != null ? i.Project.BillingClient.Name : i.Project.Client != null ? i.Project.Client.Name : i.Client.Name }) + ); + } + + [Test] + public void OrderByClause() + { + AreEqual( + // Actual + q => q.OrderBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).Name ?? "ZZZ") + .ThenBy(i => i.Name) + .Select(i => i.Name), + // Expected + q => q.OrderBy(i => (i.Project.BillingClient != null ? i.Project.BillingClient.Name : i.Project.Client != null ? i.Project.Client.Name : i.Client.Name) ?? "ZZZ") + .ThenBy(i => i.Name) + .Select(i => i.Name) + ); + } + + [Test] + public void GroupByClause() + { + AreEqual( + // Actual + q => q.GroupBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }), + // Expected + q => q.GroupBy(i => i.Project.BillingClient != null ? i.Project.BillingClient.Name : i.Project.Client != null ? i.Project.Client.Name : i.Client.Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }) + ); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessMethod.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessMethod.cs new file mode 100644 index 00000000000..119edc3605b --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceChildThenAccessMethod.cs @@ -0,0 +1,158 @@ +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + [TestFixture] + public class CoalesceChildThenAccessMethod : GH1879BaseFixture + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Albert" }; + var clientB = new Client { Name = "Bob" }; + var corpA = new CorporateClient { Name = "Alpha", CorporateId = "1234" }; + var corpB = new CorporateClient { Name = "Beta", CorporateId = "5647" }; + var clientZ = new Client { Name = null }; // A null value should propagate if the entity is non-null + session.Save(clientA); + session.Save(clientB); + session.Save(corpA); + session.Save(corpB); + session.Save(clientZ); + + var projectA = new Project { Name = "A", BillingClient = null, Client = clientA }; + var projectB = new Project { Name = "B", BillingClient = corpB, Client = clientA }; + var projectC = new Project { Name = "C", BillingClient = null, Client = clientB }; + var projectD = new Project { Name = "D", BillingClient = corpA, Client = clientB }; + var projectE = new Project { Name = "E", BillingClient = clientZ, Client = clientA }; + var projectZ = new Project { Name = "Z", BillingClient = null, Client = null }; + session.Save(projectA); + session.Save(projectB); + session.Save(projectC); + session.Save(projectD); + session.Save(projectE); + session.Save(projectZ); + + session.Save(new Issue { Name = "01", Project = null, Client = null }); + session.Save(new Issue { Name = "02", Project = null, Client = clientA }); + session.Save(new Issue { Name = "03", Project = null, Client = clientB }); + session.Save(new Issue { Name = "04", Project = projectC, Client = clientA }); + session.Save(new Issue { Name = "05", Project = projectA, Client = clientB }); + session.Save(new Issue { Name = "06", Project = projectB, Client = clientA }); + session.Save(new Issue { Name = "07", Project = projectD, Client = clientB }); + session.Save(new Issue { Name = "08", Project = projectZ, Client = corpA }); + session.Save(new Issue { Name = "09", Project = projectZ, Client = corpB }); + session.Save(new Issue { Name = "10", Project = projectE, Client = clientA }); + + session.Flush(); + transaction.Commit(); + } + } + + protected override void Configure(Configuration configuration) + { + configuration.LinqToHqlGeneratorsRegistry(); + } + + private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry + { + public TestLinqToHqlGeneratorsRegistry() + { + this.Merge(new TestHqlGeneratorForMethod()); + } + } + + private class TestHqlGeneratorForMethod : IHqlGeneratorForMethod + { + /// + public IEnumerable SupportedMethods => new [] + { + ReflectHelper.GetMethodDefinition(x => x.NameByMethod()), + }; + + /// + public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.Dot(visitor.Visit(targetObject).AsExpression(), treeBuilder.Ident("Name").AsExpression()); + } + } + + [Test] + public void WhereClause() + { + AreEqual( + // Actual + q => q.Where(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByMethod().StartsWith("A")), + // Expected + q => q.Where(i => (i.Project.BillingClient != null ? i.Project.BillingClient.NameByMethod() : i.Project.Client != null ? i.Project.Client.NameByMethod() : i.Client.NameByMethod()).StartsWith("A")) + ); + } + + [Test] + public void SelectClause() + { + AreEqual( + // Actual + q => q.OrderBy(i =>i.Name) + .Select(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByMethod()), + // Expected + q => q.OrderBy(i =>i.Name) + .Select(i => i.Project.BillingClient != null ? i.Project.BillingClient.NameByMethod() : i.Project.Client != null ? i.Project.Client.NameByMethod() : i.Client.NameByMethod()) + ); + } + + [Test] + public void SelectClauseToAnon() + { + AreEqual( + // Actual + q => q.OrderBy(i =>i.Name) + .Select(i => new { Key =i.Name, Client = (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByMethod() }), + // Expected + q => q.OrderBy(i =>i.Name) + .Select(i => new { Key =i.Name, Client = i.Project.BillingClient != null ? i.Project.BillingClient.NameByMethod() : i.Project.Client != null ? i.Project.Client.NameByMethod() : i.Client.NameByMethod() }) + ); + } + + [Test] + public void OrderByClause() + { + AreEqual( + // Actual + q => q.OrderBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByMethod() ?? "ZZZ") + .ThenBy(i =>i.Name) + .Select(i =>i.Name), + // Expected + q => q.OrderBy(i => (i.Project.BillingClient != null ? i.Project.BillingClient.NameByMethod() : i.Project.Client != null ? i.Project.Client.NameByMethod() : i.Client.NameByMethod()) ?? "ZZZ") + .ThenBy(i =>i.Name) + .Select(i =>i.Name) + ); + } + + [Test] + public void GroupByClause() + { + AreEqual( + // Actual + q => q.GroupBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByMethod()) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }), + // Expected + q => q.GroupBy(i => i.Project.BillingClient != null ? i.Project.BillingClient.NameByMethod() : i.Project.Client != null ? i.Project.Client.NameByMethod() : i.Client.NameByMethod()) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }) + ); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceSiblingsThenAccessMember.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceSiblingsThenAccessMember.cs new file mode 100644 index 00000000000..d4724cc6ee9 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/CoalesceSiblingsThenAccessMember.cs @@ -0,0 +1,106 @@ +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + [TestFixture] + public class CoalesceSiblingsThenAccessMember : GH1879BaseFixture + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Albert" }; + var clientB = new Client { Name = "Bob" }; + var corpA = new CorporateClient { Name = "Alpha", CorporateId = "1234" }; + var corpB = new CorporateClient { Name = "Beta", CorporateId = "5647" }; + var clientZ = new Client { Name = null }; // A null value should propagate if the entity is non-null + session.Save(clientA); + session.Save(clientB); + session.Save(corpA); + session.Save(corpB); + session.Save(clientZ); + + session.Save(new Project { Name = "A", BillingClient = null, CorporateClient = null, Client = clientA }); + session.Save(new Project { Name = "B", BillingClient = null, CorporateClient = null, Client = clientB }); + session.Save(new Project { Name = "C", BillingClient = null, CorporateClient = corpA, Client = clientA }); + session.Save(new Project { Name = "D", BillingClient = null, CorporateClient = corpB, Client = clientA }); + session.Save(new Project { Name = "E", BillingClient = corpA, CorporateClient = null, Client = clientA }); + session.Save(new Project { Name = "F", BillingClient = clientB, CorporateClient = null, Client = clientA }); + session.Save(new Project { Name = "G", BillingClient = clientZ, CorporateClient = null, Client = clientA }); + session.Save(new Project { Name = "Z", BillingClient = null, CorporateClient = null, Client = null }); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public void WhereClause() + { + AreEqual( + // Actual + q => q.Where(p => (p.BillingClient ?? p.CorporateClient ?? p.Client).Name.StartsWith("A")), + // Expected + q => q.Where(p => (p.BillingClient != null ? p.BillingClient.Name : p.CorporateClient != null ? p.CorporateClient.Name : p.Client.Name).StartsWith("A")) + ); + } + + [Test] + public void SelectClause() + { + AreEqual( + // Actual + q => q.OrderBy(p => p.Name) + .Select(p => (p.BillingClient ?? p.CorporateClient ?? p.Client).Name), + // Expected + q => q.OrderBy(p => p.Name) + .Select(p => p.BillingClient != null ? p.BillingClient.Name : p.CorporateClient != null ? p.CorporateClient.Name : p.Client.Name) + ); + } + + [Test] + public void SelectClauseToAnon() + { + AreEqual( + // Actual + q => q.OrderBy(p => p.Name) + .Select(p => new { Project = p.Name, Client = (p.BillingClient ?? p.CorporateClient ?? p.Client).Name }), + // Expected + q => q.OrderBy(p => p.Name) + .Select(p => new { Project = p.Name, Client = p.BillingClient != null ? p.BillingClient.Name : p.CorporateClient != null ? p.CorporateClient.Name : p.Client.Name }) + ); + } + + [Test] + public void OrderByClause() + { + AreEqual( + // Actual + q => q.OrderBy(p => (p.BillingClient ?? p.CorporateClient ?? p.Client).Name ?? "ZZZ") + .ThenBy(p => p.Name) + .Select(p => p.Name), + // Expected + q => q.OrderBy(p => (p.BillingClient != null ? p.BillingClient.Name : p.CorporateClient != null ? p.CorporateClient.Name : p.Client.Name) ?? "ZZZ") + .ThenBy(p => p.Name) + .Select(p => p.Name) + ); + } + + [Test] + public void GroupByClause() + { + AreEqual( + // Actual + q => q.GroupBy(p => (p.BillingClient ?? p.CorporateClient ?? p.Client).Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }), + // Expected + q => q.GroupBy(p => p.BillingClient != null ? p.BillingClient.Name : p.CorporateClient != null ? p.CorporateClient.Name : p.Client.Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }) + ); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/ConditionalThenAccessMember.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/ConditionalThenAccessMember.cs new file mode 100644 index 00000000000..3b60d65a515 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/ConditionalThenAccessMember.cs @@ -0,0 +1,141 @@ +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + [TestFixture] + public class ConditionalThenAccessMember : GH1879BaseFixture + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Albert" }; + var clientB = new Client { Name = "Bob" }; + var clientC = new CorporateClient { Name = "Charlie", CorporateId = "1234" }; + session.Save(clientA); + session.Save(clientB); + session.Save(clientC); + + session.Save(new Project { Name = "A", EmailPref = EmailPref.Primary, Client = clientA, BillingClient = clientB, CorporateClient = clientC, }); + session.Save(new Project { Name = "B", EmailPref = EmailPref.Billing, Client = clientA, BillingClient = clientB, CorporateClient = clientC, }); + session.Save(new Project { Name = "C", EmailPref = EmailPref.Corp, Client = clientA, BillingClient = clientB, CorporateClient = clientC, }); + + session.Save(new Project { Name = "D", EmailPref = EmailPref.Primary, Client = null, BillingClient = clientB, CorporateClient = clientC, }); + session.Save(new Project { Name = "E", EmailPref = EmailPref.Billing, Client = clientA, BillingClient = null, CorporateClient = clientC, }); + session.Save(new Project { Name = "F", EmailPref = EmailPref.Corp, Client = clientA, BillingClient = clientB, CorporateClient = null, }); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public void WhereClause() + { + AreEqual( + // Actual + q => q.Where(p => (p.EmailPref == EmailPref.Primary + ? p.Client + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient + : p.BillingClient).Name.Length > 3), + // Expected + q => q.Where(p => (p.EmailPref == EmailPref.Primary + ? p.Client.Name + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient.Name + : p.BillingClient.Name).Length > 3) + ); + } + + [Test] + public void SelectClause() + { + AreEqual( + // Actual + q => q.OrderBy(p => p.Name) + .Select(p => (p.EmailPref == EmailPref.Primary + ? p.Client + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient + : p.BillingClient).Name), + // Expected + q => q.OrderBy(p => p.Name) + .Select(p => p.EmailPref == EmailPref.Primary + ? p.Client.Name + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient.Name + : p.BillingClient.Name) + ); + } + + [Test] + public void SelectClauseToAnon() + { + AreEqual( + // Actual + q => q.OrderBy(p => p.Name) + .Select(p => new { p.Name, Client = (p.EmailPref == EmailPref.Primary + ? p.Client + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient + : p.BillingClient).Name }), + // Expected + q => q.OrderBy(p => p.Name) + .Select(p => new { p.Name, Client = p.EmailPref == EmailPref.Primary + ? p.Client.Name + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient.Name + : p.BillingClient.Name }) + ); + } + + [Test] + public void OrderByClause() + { + AreEqual( + // Actual + q => q.OrderBy(p => (p.EmailPref == EmailPref.Primary + ? p.Client + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient + : p.BillingClient).Name ?? "ZZZ") + .ThenBy(p => p.Name) + .Select(p => p.Name), + // Expected + q => q.OrderBy(p => (p.EmailPref == EmailPref.Primary + ? p.Client.Name + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient.Name + : p.BillingClient.Name) ?? "ZZZ") + .ThenBy(p => p.Name) + .Select(p => p.Name) + ); + } + + [Test] + public void GroupByClause() + { + AreEqual( + // Actual + q => q.GroupBy(p => (p.EmailPref == EmailPref.Primary + ? p.Client + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient + : p.BillingClient).Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }), + // Expected + q => q.GroupBy(p => p.EmailPref == EmailPref.Primary + ? p.Client.Name + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient.Name + : p.BillingClient.Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }) + ); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/ConditionalThenMethodCall.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/ConditionalThenMethodCall.cs new file mode 100644 index 00000000000..12cbfb284df --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/ConditionalThenMethodCall.cs @@ -0,0 +1,108 @@ +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + [TestFixture] + public class ConditionalThenMethodCall : GH1879BaseFixture + { + /// + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Alpha" }; + var clientB = new Client { Name = "Beta" }; + session.Save(clientA); + session.Save(clientB); + + var issue1 = new Issue { Name = "1", Client = null }; + var issue2 = new Issue { Name = "2", Client = clientA }; + var issue3 = new Issue { Name = "3", Client = clientA }; + var issue4 = new Issue { Name = "4", Client = clientA }; + var issue5 = new Issue { Name = "5", Client = clientB }; + session.Save(issue1); + session.Save(issue2); + session.Save(issue3); + session.Save(issue4); + session.Save(issue5); + + session.Save(new Employee { Name = "Andy", ReviewAsPrimary = true, ReviewIssues = { issue1, issue2, issue5 }, WorkIssues = { issue3 } }); + session.Save(new Employee { Name = "Bart", ReviewAsPrimary = false, ReviewIssues = { issue3 }, WorkIssues = { issue4, issue5 } }); + session.Save(new Employee { Name = "Carl", ReviewAsPrimary = true, ReviewIssues = { issue3 }, WorkIssues = { issue1, issue4, issue5 } }); + session.Save(new Employee { Name = "Dorn", ReviewAsPrimary = false, ReviewIssues = { issue3 }, WorkIssues = { issue1, issue4 } }); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public void WhereClause() + { + AreEqual( + // Conditional style + q => q.Where(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.WorkIssues).Any(i => i.Client.Name == "Beta")), + // Expected + q => q.Where(e => e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta")) + ); + } + + [Test] + public void SelectClause() + { + AreEqual( + // Conditional style + q => q.OrderBy(e => e.Name) + .Select(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.WorkIssues).Any(i => i.Client.Name == "Beta")), + // Expected + q => q.OrderBy(e => e.Name) + .Select(e => e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta")) + ); + } + + [Test] + public void SelectClauseToAnon() + { + AreEqual( + // Conditional style + q => q.OrderBy(e => e.Name) + .Select(e => new { e.Name, Beta = (e.ReviewAsPrimary ? e.ReviewIssues : e.WorkIssues).Any(i => i.Client.Name == "Beta") }), + // Expected + q => q.OrderBy(e => e.Name) + .Select(e => new { e.Name, Beta = e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta") }) + ); + } + + [Test] + public void OrderByClause() + { + AreEqual( + // Conditional style + q => q.OrderBy(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.WorkIssues).Count()) + .ThenBy(p => p.Name) + .Select(p => p.Name), + // Expected + q => q.OrderBy(e => e.ReviewAsPrimary ? e.ReviewIssues.Count() : e.WorkIssues.Count()) + .ThenBy(p => p.Name) + .Select(p => p.Name) + ); + } + + [Test] + public void GroupByClause() + { + AreEqual( + // Conditional style + q => q.GroupBy(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.WorkIssues).Count()) + .OrderBy(x => x.Key) + .Select(grp => grp.Count()), + // Expected + q => q.GroupBy(e => e.ReviewAsPrimary ? e.ReviewIssues.Count() : e.WorkIssues.Count()) + .OrderBy(x => x.Key) + .Select(grp => grp.Count()) + ); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs new file mode 100644 index 00000000000..402395f974b --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs @@ -0,0 +1,60 @@ +using System; +using System.Collections.Generic; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + public class Client + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + public virtual string NameByMethod() => Name; + } + + public class CorporateClient : Client + { + public virtual string CorporateId { get; set; } + } + + public class Employee + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + + public virtual bool ReviewAsPrimary { get; set; } + public virtual ICollection WorkIssues { get; set; } = new List(); + public virtual ICollection ReviewIssues { get; set; } = new List(); + } + + public class Project + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + public virtual EmailPref EmailPref { get; set; } + public virtual Client Client { get; set; } + public virtual Client BillingClient { get; set; } + public virtual CorporateClient CorporateClient { get; set; } + } + + public enum EmailPref + { + Primary, + Billing, + Corp, + } + + public class Issue + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + public virtual Client Client { get; set; } + public virtual Project Project { get; set; } + } + + public class Invoice + { + public virtual Guid Id { get; set; } + public virtual int InvoiceNumber { get; set; } + public virtual Project Project { get; set; } + public virtual Issue Issue { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs new file mode 100644 index 00000000000..cbf5df62ec2 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Mapping.ByCode; +using NHibernate.Type; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + [TestFixture] + public abstract class GH1879BaseFixture : TestCaseMappingByCode + { + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.Name); + }); + mapper.JoinedSubclass(rc => + { + rc.Property(x => x.CorporateId); + }); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.Name); + rc.Property(x => x.EmailPref, m => m.Type>()); + rc.ManyToOne(x => x.Client, m => m.Column("ClientId")); + rc.ManyToOne(x => x.BillingClient, m => m.Column("BillingClientId")); + rc.ManyToOne(x => x.CorporateClient, m => m.Column("CorporateClientId")); + }); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.Name); + rc.ManyToOne(x => x.Project, m => m.Column("ProjectId")); + rc.ManyToOne(x => x.Client, m => m.Column("ClientId")); + }); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.InvoiceNumber); + rc.ManyToOne(x => x.Project, m => m.Column("ProjectId")); + rc.ManyToOne(x => x.Issue, m => m.Column("IssueId")); + }); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.Name); + rc.Property(x => x.ReviewAsPrimary); + rc.Set(x => x.WorkIssues, + m => + { + m.Table("EmployeesToWorkIssues"); + m.Cascade(Mapping.ByCode.Cascade.All | Mapping.ByCode.Cascade.DeleteOrphans); + m.Key(k => k.Column(c => c.Name("EmployeeId")) ); + }, + rel => rel.ManyToMany(m => m.Column("IssueId"))); + rc.Set(x => x.ReviewIssues, + m => + { + m.Table("EmployeesToReviewIssues"); + m.Cascade(Mapping.ByCode.Cascade.All | Mapping.ByCode.Cascade.DeleteOrphans); + m.Key(k => k.Column(c => c.Name("EmployeeId")) ); + }, + rel => rel.ManyToMany(m => m.Column("IssueId"))); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Delete("from System.Object"); + session.Flush(); + transaction.Commit(); + } + } + + protected void AreEqual( + Func, IQueryable> actualQuery, + Func, IQueryable> expectedQuery) + { + using (var session = OpenSession()) + { + IEnumerable expectedResult = null; + try + { + expectedResult = expectedQuery(session.Query()).ToList(); + } + catch + { + Assert.Ignore("Not currently supported query"); + } + + var testResult = actualQuery(session.Query()).ToList(); + Assert.That(testResult, Is.EqualTo(expectedResult)); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/TestExtensions.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/TestExtensions.cs new file mode 100644 index 00000000000..c9d68cdaf57 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/TestExtensions.cs @@ -0,0 +1,10 @@ +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + internal static class TestExtensions + { + public static string NameByExtension(this Client client) + { + return client.Name; + } + } +} diff --git a/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs b/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs new file mode 100644 index 00000000000..b3d878c107a --- /dev/null +++ b/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs @@ -0,0 +1,204 @@ +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using NHibernate.Linq.Clauses; +using NHibernate.Linq.Visitors; +using Remotion.Linq; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.ReWriters +{ + /// + /// Expands conditional and coalesce expressions that are merging QueryReferences so that they can be followed by + /// Member or Method calls. + /// Ex) query.Where(x => (x.OptionA ?? x.OptionB).Value == value); + /// query.Where(x => (x.UseA ? x.OptionA : x.OptionB).Value = value); + /// + internal class ConditionalQueryReferenceExpander : NhQueryModelVisitorBase + { + private readonly ConditionalQueryReferenceExpressionExpander _expander; + + private ConditionalQueryReferenceExpander() + { + _expander = new ConditionalQueryReferenceExpressionExpander(); + } + + public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) + { + _expander.Transform(selectClause); + } + + public override void VisitOrdering(Ordering ordering, QueryModel queryModel, OrderByClause orderByClause, int index) + { + _expander.Transform(ordering); + } + + public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index) + { + _expander.Transform(resultOperator); + } + + public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) + { + _expander.Transform(whereClause); + } + + public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel queryModel, int index) + { + _expander.Transform(havingClause); + } + + public static void ReWrite(QueryModel queryModel) + { + var visitor = new ConditionalQueryReferenceExpander(); + visitor.VisitQueryModel(queryModel); + } + + private class ConditionalQueryReferenceExpressionExpander : RelinqExpressionVisitor + { + public void Transform(IClause clause) + { + clause.TransformExpressions(Visit); + } + + public void Transform(Ordering ordering) + { + ordering.TransformExpressions(Visit); + } + + public void Transform(ResultOperatorBase resultOperator) + { + resultOperator.TransformExpressions(Visit); + } + + protected override Expression VisitMember(MemberExpression node) + { + var result = (MemberExpression)base.VisitMember(node); + if (QueryReferenceCounter.CountReferences(result.Expression) > 1) + { + return ConditionalQueryReferenceMemberExpressionRewriter.Rewrite(result.Expression, node); + } + return result; + } + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + var result = (MethodCallExpression)base.VisitMethodCall(node); + var isExtension = node.Method.GetCustomAttributes().Any(); + var methodObject = isExtension ? node.Arguments[0] : node.Object; + + if (methodObject != null && QueryReferenceCounter.CountReferences(methodObject) > 1) + { + return ConditionalQueryReferenceMethodCallExpressionRewriter.Rewrite(methodObject, node); + } + return result; + } + } + + private class QueryReferenceCounter : RelinqExpressionVisitor + { + private readonly System.Type _queryType; + private int _queryReferenceCount; + + private QueryReferenceCounter(System.Type queryType) + { + _queryType = queryType; + } + + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) + { + if (_queryType.IsAssignableFrom(expression.Type)) + { + _queryReferenceCount++; + } + + return base.VisitQuerySourceReference(expression); + } + + public static int CountReferences(Expression node) + { + var visitor = new QueryReferenceCounter(node.Type); + visitor.Visit(node); + return visitor._queryReferenceCount; + } + } + + private abstract class ConditionalQueryReferenceExpressionRewriter : RelinqExpressionVisitor + where T : Expression + where TVisitor : ConditionalQueryReferenceExpressionRewriter, new() + { + protected T OuterExpr { get; private set; } + + private bool _skipUpdate; + private System.Type _queryType; + + protected override Expression VisitBinary(BinaryExpression node) + { + if (node.NodeType != ExpressionType.Coalesce) + { + return base.VisitBinary(node); + } + + // Coalesce expressions must be rewritten to conditionals to keep their logical meaning + // (x ?? y).Prop --> x != null ? x.Prop : y.Prop + return Expression.Condition( + Expression.NotEqual(node.Left, Expression.Constant(null, node.Left.Type)), + Visit(node.Left), + Visit(node.Right)); + } + + protected override Expression VisitConditional(ConditionalExpression node) + { + _skipUpdate = true; + var test = Visit(node.Test); + _skipUpdate = false; + return Expression.Condition(test, Visit(node.IfTrue), Visit(node.IfFalse)); + } + + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) + { + if (!_skipUpdate && _queryType.IsAssignableFrom(expression.Type)) + { + return Rewrite(expression); + } + + return base.VisitQuerySourceReference(expression); + } + + protected abstract T Rewrite(QuerySourceReferenceExpression expression); + + public static Expression Rewrite(Expression expression, T outerExpr) + { + var visitor = new TVisitor { OuterExpr = outerExpr, _queryType = expression.Type }; + return visitor.Visit(expression); + } + } + + private class ConditionalQueryReferenceMemberExpressionRewriter : ConditionalQueryReferenceExpressionRewriter + { + protected override MemberExpression Rewrite(QuerySourceReferenceExpression expression) + { + return Expression.MakeMemberAccess(expression, OuterExpr.Member); + } + } + + private class ConditionalQueryReferenceMethodCallExpressionRewriter : ConditionalQueryReferenceExpressionRewriter + { + protected override MethodCallExpression Rewrite(QuerySourceReferenceExpression expression) + { + var isExtension = OuterExpr.Method.GetCustomAttributes().Any(); + if (isExtension) + { + var argList = OuterExpr.Arguments.ToArray(); + argList[0] = expression; + return Expression.Call(null, OuterExpr.Method, argList); + } + + return Expression.Call(expression, OuterExpr.Method, OuterExpr.Arguments); + } + } + } +} diff --git a/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs b/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs new file mode 100644 index 00000000000..8db8aa6ed70 --- /dev/null +++ b/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs @@ -0,0 +1,134 @@ +using System.Collections.Generic; +using System.Linq.Expressions; +using NHibernate.Linq.Clauses; +using NHibernate.Linq.Visitors; +using Remotion.Linq; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.ReWriters +{ + /// + /// Expands conditionals within subquery FROM clauses. + /// It does this by moving the conditional expression outside of the subquery and cloning the subquery, + /// replacing the FROM clause with the collection parts of the conditional. + /// + internal class SubQueryConditionalExpander : NhQueryModelVisitorBase + { + private readonly SubQueryConditionalExpressionExpander _expander = new SubQueryConditionalExpressionExpander(); + + public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) + { + _expander.Transform(selectClause); + } + + public override void VisitOrdering(Ordering ordering, QueryModel queryModel, OrderByClause orderByClause, int index) + { + _expander.Transform(ordering); + } + + public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index) + { + _expander.Transform(resultOperator); + } + + public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) + { + _expander.Transform(whereClause); + } + + public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel queryModel, int index) + { + _expander.Transform(havingClause); + } + + public static void ReWrite(QueryModel queryModel) + { + var visitor = new SubQueryConditionalExpander(); + visitor.VisitQueryModel(queryModel); + } + + private class SubQueryConditionalExpressionExpander : RelinqExpressionVisitor + { + public void Transform(IClause clause) + { + clause.TransformExpressions(Visit); + } + + public void Transform(Ordering ordering) + { + ordering.TransformExpressions(Visit); + } + + public void Transform(ResultOperatorBase resultOperator) + { + resultOperator.TransformExpressions(Visit); + } + + protected override Expression VisitSubQuery(SubQueryExpression expression) + { + var fromClauseExpander = new SubQueryFromClauseExpander(expression.QueryModel); + var fromExpr = fromClauseExpander.Visit(expression.QueryModel.MainFromClause.FromExpression); + return fromClauseExpander.Rewritten ? fromExpr : expression; + } + } + + private class SubQueryFromClauseExpander : RelinqExpressionVisitor + { + private readonly QueryModel _originalSubQueryModel; + private int _depth = -1; + private readonly IList _nominate = new List(); + + public bool Rewritten { get; private set; } + + public SubQueryFromClauseExpander(QueryModel originalSubQueryModel) + { + _originalSubQueryModel = originalSubQueryModel; + } + + protected override Expression VisitConditional(ConditionalExpression node) + { + if (_depth >= 0) + { + _nominate[_depth] = false; + } + + var newTest = this.Visit(node.Test); + _nominate.Insert(++_depth, false); + var newTrue = this.Visit(node.IfTrue); + if (_nominate[_depth]) + { + newTrue = BuildNewSubQuery(newTrue); + Rewritten = true; + } + _nominate.Insert(_depth, false); + var newFalse = this.Visit(node.IfFalse); + if (_nominate[_depth]) + { + newFalse = BuildNewSubQuery(newFalse); + Rewritten = true; + } + _nominate.RemoveAt(_depth--); + return Expression.Condition(newTest, newTrue, newFalse); + } + + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) + { + if (_depth >= 0) + { + _nominate[_depth] = true; + } + + return base.VisitQuerySourceReference(expression); + } + + private SubQueryExpression BuildNewSubQuery(Expression fromExpr) + { + var newSubQuery = _originalSubQueryModel.Clone(); + newSubQuery.MainFromClause.FromExpression = fromExpr; + return new SubQueryExpression(newSubQuery); + } + } + } +} diff --git a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs index 624a9aae8ad..934fba8ec94 100644 --- a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs @@ -55,6 +55,14 @@ protected override Expression VisitMember(MemberExpression expression) return result; } + protected override Expression VisitMethodCall(MethodCallExpression node) + { + _memberExpressionDepth++; + var result = base.VisitMethodCall(node); + _memberExpressionDepth--; + return result; + } + protected override Expression VisitSubQuery(SubQueryExpression expression) { expression.QueryModel.TransformExpressions(Visit); @@ -69,9 +77,7 @@ protected override Expression VisitConditional(ConditionalExpression expression) _requiresJoinForNonIdentifier = oldRequiresJoinForNonIdentifier; var newFalse = Visit(expression.IfFalse); var newTrue = Visit(expression.IfTrue); - if ((newTest != expression.Test) || (newFalse != expression.IfFalse) || (newTrue != expression.IfTrue)) - return Expression.Condition(newTest, newTrue, newFalse); - return expression; + return expression.Update(newTest, newTrue, newFalse); } protected override Expression VisitExtension(Expression expression) diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 36d40e2be3e..e549843db9e 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -27,6 +27,9 @@ public class QueryModelVisitor : NhQueryModelVisitorBase, INhQueryModelVisitor public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, VisitorParameters parameters, bool root, NhLinqExpressionReturnType? rootReturnType) { + // Expand conditionals in subquery FROM clauses into multiple subqueries + SubQueryConditionalExpander.ReWrite(queryModel); + NestedSelectRewriter.ReWrite(queryModel, parameters.SessionFactory); // Remove unnecessary body operators @@ -64,6 +67,9 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer // Add joins for references AddJoinsReWriter.ReWrite(queryModel, parameters); + // Expand coalesced and conditional joins to their logical equivalents + ConditionalQueryReferenceExpander.ReWrite(queryModel); + // Move OrderBy clauses to end MoveOrderByToEndRewriter.ReWrite(queryModel); diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index e86212d8e0f..66508f01eef 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -325,6 +325,14 @@ protected override Expression VisitMember(MemberExpression expression) return result; } + protected override Expression VisitMethodCall(MethodCallExpression node) + { + _memberExpressionDepth++; + var result = base.VisitMethodCall(node); + _memberExpressionDepth--; + return result; + } + private void SetResultValues(ExpressionValues values) { _handled.Pop(); From e943374bc1fded7b4500593079bd558f196518ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericdelaporte@users.noreply.github.com> Date: Tue, 23 Oct 2018 13:10:33 +0200 Subject: [PATCH 2/5] fixup! GH-1879 - Allow Coalesce and Conditional logic on entity properties and collections (LINQ) Add missing async generation --- .../CoalesceChildThenAccessExtensionMethod.cs | 169 ++++++++++++++++++ .../GH1879/CoalesceChildThenAccessMember.cs | 132 ++++++++++++++ .../GH1879/CoalesceChildThenAccessMethod.cs | 169 ++++++++++++++++++ .../CoalesceSiblingsThenAccessMember.cs | 117 ++++++++++++ .../GH1879/ConditionalThenAccessMember.cs | 152 ++++++++++++++++ .../GH1879/ConditionalThenMethodCall.cs | 119 ++++++++++++ .../NHSpecificTest/GH1879/FixtureByCode.cs | 121 +++++++++++++ 7 files changed, 979 insertions(+) create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessExtensionMethod.cs create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessMember.cs create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessMethod.cs create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceSiblingsThenAccessMember.cs create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH1879/ConditionalThenAccessMember.cs create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH1879/ConditionalThenMethodCall.cs create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessExtensionMethod.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessExtensionMethod.cs new file mode 100644 index 00000000000..f89d7434072 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessExtensionMethod.cs @@ -0,0 +1,169 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + using System.Threading.Tasks; + [TestFixture] + public class CoalesceChildThenAccessExtensionMethodAsync : GH1879BaseFixtureAsync + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Albert" }; + var clientB = new Client { Name = "Bob" }; + var corpA = new CorporateClient { Name = "Alpha", CorporateId = "1234" }; + var corpB = new CorporateClient { Name = "Beta", CorporateId = "5647" }; + var clientZ = new Client { Name = null }; // A null value should propagate if the entity is non-null + session.Save(clientA); + session.Save(clientB); + session.Save(corpA); + session.Save(corpB); + session.Save(clientZ); + + var projectA = new Project { Name = "A", BillingClient = null, Client = clientA }; + var projectB = new Project { Name = "B", BillingClient = corpB, Client = clientA }; + var projectC = new Project { Name = "C", BillingClient = null, Client = clientB }; + var projectD = new Project { Name = "D", BillingClient = corpA, Client = clientB }; + var projectE = new Project { Name = "E", BillingClient = clientZ, Client = clientA }; + var projectZ = new Project { Name = "Z", BillingClient = null, Client = null }; + session.Save(projectA); + session.Save(projectB); + session.Save(projectC); + session.Save(projectD); + session.Save(projectE); + session.Save(projectZ); + + session.Save(new Issue { Name = "01", Project = null, Client = null }); + session.Save(new Issue { Name = "02", Project = null, Client = clientA }); + session.Save(new Issue { Name = "03", Project = null, Client = clientB }); + session.Save(new Issue { Name = "04", Project = projectC, Client = clientA }); + session.Save(new Issue { Name = "05", Project = projectA, Client = clientB }); + session.Save(new Issue { Name = "06", Project = projectB, Client = clientA }); + session.Save(new Issue { Name = "07", Project = projectD, Client = clientB }); + session.Save(new Issue { Name = "08", Project = projectZ, Client = corpA }); + session.Save(new Issue { Name = "09", Project = projectZ, Client = corpB }); + session.Save(new Issue { Name = "10", Project = projectE, Client = clientA }); + + session.Flush(); + transaction.Commit(); + } + } + + protected override void Configure(Configuration configuration) + { + configuration.LinqToHqlGeneratorsRegistry(); + } + + private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry + { + public TestLinqToHqlGeneratorsRegistry() + { + this.Merge(new TestHqlGeneratorForMethod()); + } + } + + private class TestHqlGeneratorForMethod : IHqlGeneratorForMethod + { + /// + public IEnumerable SupportedMethods => new [] + { + ReflectHelper.GetMethodDefinition(x => x.NameByExtension()), + }; + + /// + public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.Dot(visitor.Visit(arguments[0]).AsExpression(), treeBuilder.Ident("Name").AsExpression()); + } + } + + [Test] + public async Task WhereClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.Where(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByExtension().StartsWith("A")), + // Expected + q => q.Where(i => (i.Project.BillingClient != null ? i.Project.BillingClient.NameByExtension() : i.Project.Client != null ? i.Project.Client.NameByExtension() : i.Client.NameByExtension()).StartsWith("A")) + )); + } + + [Test] + public async Task SelectClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(i =>i.Name) + .Select(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByExtension()), + // Expected + q => q.OrderBy(i =>i.Name) + .Select(i => i.Project.BillingClient != null ? i.Project.BillingClient.NameByExtension() : i.Project.Client != null ? i.Project.Client.NameByExtension() : i.Client.NameByExtension()) + )); + } + + [Test] + public async Task SelectClauseToAnonAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(i =>i.Name) + .Select(i => new { Key =i.Name, Client = (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByExtension() }), + // Expected + q => q.OrderBy(i =>i.Name) + .Select(i => new { Key =i.Name, Client = i.Project.BillingClient != null ? i.Project.BillingClient.NameByExtension() : i.Project.Client != null ? i.Project.Client.NameByExtension() : i.Client.NameByExtension() }) + )); + } + + [Test] + public async Task OrderByClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByExtension() ?? "ZZZ") + .ThenBy(i =>i.Name) + .Select(i =>i.Name), + // Expected + q => q.OrderBy(i => (i.Project.BillingClient != null ? i.Project.BillingClient.NameByExtension() : i.Project.Client != null ? i.Project.Client.NameByExtension() : i.Client.NameByExtension()) ?? "ZZZ") + .ThenBy(i =>i.Name) + .Select(i =>i.Name) + )); + } + + [Test] + public async Task GroupByClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.GroupBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByExtension()) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }), + // Expected + q => q.GroupBy(i => i.Project.BillingClient != null ? i.Project.BillingClient.NameByExtension() : i.Project.Client != null ? i.Project.Client.NameByExtension() : i.Client.NameByExtension()) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }) + )); + } + } +} diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessMember.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessMember.cs new file mode 100644 index 00000000000..02e98218972 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessMember.cs @@ -0,0 +1,132 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + using System.Threading.Tasks; + [TestFixture] + public class CoalesceChildThenAccessMemberAsync : GH1879BaseFixtureAsync + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Albert" }; + var clientB = new Client { Name = "Bob" }; + var corpA = new CorporateClient { Name = "Alpha", CorporateId = "1234" }; + var corpB = new CorporateClient { Name = "Beta", CorporateId = "5647" }; + var clientZ = new Client { Name = null }; // A null value should propagate if the entity is non-null + session.Save(clientA); + session.Save(clientB); + session.Save(corpA); + session.Save(corpB); + session.Save(clientZ); + + var projectA = new Project { Name = "A", BillingClient = null, Client = clientA }; + var projectB = new Project { Name = "B", BillingClient = corpB, Client = clientA }; + var projectC = new Project { Name = "C", BillingClient = null, Client = clientB }; + var projectD = new Project { Name = "D", BillingClient = corpA, Client = clientB }; + var projectE = new Project { Name = "E", BillingClient = clientZ, Client = clientA }; + var projectZ = new Project { Name = "Z", BillingClient = null, Client = null }; + session.Save(projectA); + session.Save(projectB); + session.Save(projectC); + session.Save(projectD); + session.Save(projectE); + session.Save(projectZ); + + session.Save(new Issue { Name = "01", Project = null, Client = null }); + session.Save(new Issue { Name = "02", Project = null, Client = clientA }); + session.Save(new Issue { Name = "03", Project = null, Client = clientB }); + session.Save(new Issue { Name = "04", Project = projectC, Client = clientA }); + session.Save(new Issue { Name = "05", Project = projectA, Client = clientB }); + session.Save(new Issue { Name = "06", Project = projectB, Client = clientA }); + session.Save(new Issue { Name = "07", Project = projectD, Client = clientB }); + session.Save(new Issue { Name = "08", Project = projectZ, Client = corpA }); + session.Save(new Issue { Name = "09", Project = projectZ, Client = corpB }); + session.Save(new Issue { Name = "10", Project = projectE, Client = clientA }); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public async Task WhereClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.Where(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).Name.StartsWith("A")), + // Expected + q => q.Where(i => (i.Project.BillingClient != null ? i.Project.BillingClient.Name : i.Project.Client != null ? i.Project.Client.Name : i.Client.Name).StartsWith("A")) + )); + } + + [Test] + public async Task SelectClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(i => i.Name) + .Select(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).Name), + // Expected + q => q.OrderBy(i => i.Name) + .Select(i => i.Project.BillingClient != null ? i.Project.BillingClient.Name : i.Project.Client != null ? i.Project.Client.Name : i.Client.Name) + )); + } + + [Test] + public async Task SelectClauseToAnonAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(i => i.Name) + .Select(i => new { Key = i.Name, Client = (i.Project.BillingClient ?? i.Project.Client ?? i.Client).Name }), + // Expected + q => q.OrderBy(i => i.Name) + .Select(i => new { Key = i.Name, Client = i.Project.BillingClient != null ? i.Project.BillingClient.Name : i.Project.Client != null ? i.Project.Client.Name : i.Client.Name }) + )); + } + + [Test] + public async Task OrderByClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).Name ?? "ZZZ") + .ThenBy(i => i.Name) + .Select(i => i.Name), + // Expected + q => q.OrderBy(i => (i.Project.BillingClient != null ? i.Project.BillingClient.Name : i.Project.Client != null ? i.Project.Client.Name : i.Client.Name) ?? "ZZZ") + .ThenBy(i => i.Name) + .Select(i => i.Name) + )); + } + + [Test] + public async Task GroupByClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.GroupBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }), + // Expected + q => q.GroupBy(i => i.Project.BillingClient != null ? i.Project.BillingClient.Name : i.Project.Client != null ? i.Project.Client.Name : i.Client.Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }) + )); + } + } +} diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessMethod.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessMethod.cs new file mode 100644 index 00000000000..2177730e7c1 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceChildThenAccessMethod.cs @@ -0,0 +1,169 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + using System.Threading.Tasks; + [TestFixture] + public class CoalesceChildThenAccessMethodAsync : GH1879BaseFixtureAsync + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Albert" }; + var clientB = new Client { Name = "Bob" }; + var corpA = new CorporateClient { Name = "Alpha", CorporateId = "1234" }; + var corpB = new CorporateClient { Name = "Beta", CorporateId = "5647" }; + var clientZ = new Client { Name = null }; // A null value should propagate if the entity is non-null + session.Save(clientA); + session.Save(clientB); + session.Save(corpA); + session.Save(corpB); + session.Save(clientZ); + + var projectA = new Project { Name = "A", BillingClient = null, Client = clientA }; + var projectB = new Project { Name = "B", BillingClient = corpB, Client = clientA }; + var projectC = new Project { Name = "C", BillingClient = null, Client = clientB }; + var projectD = new Project { Name = "D", BillingClient = corpA, Client = clientB }; + var projectE = new Project { Name = "E", BillingClient = clientZ, Client = clientA }; + var projectZ = new Project { Name = "Z", BillingClient = null, Client = null }; + session.Save(projectA); + session.Save(projectB); + session.Save(projectC); + session.Save(projectD); + session.Save(projectE); + session.Save(projectZ); + + session.Save(new Issue { Name = "01", Project = null, Client = null }); + session.Save(new Issue { Name = "02", Project = null, Client = clientA }); + session.Save(new Issue { Name = "03", Project = null, Client = clientB }); + session.Save(new Issue { Name = "04", Project = projectC, Client = clientA }); + session.Save(new Issue { Name = "05", Project = projectA, Client = clientB }); + session.Save(new Issue { Name = "06", Project = projectB, Client = clientA }); + session.Save(new Issue { Name = "07", Project = projectD, Client = clientB }); + session.Save(new Issue { Name = "08", Project = projectZ, Client = corpA }); + session.Save(new Issue { Name = "09", Project = projectZ, Client = corpB }); + session.Save(new Issue { Name = "10", Project = projectE, Client = clientA }); + + session.Flush(); + transaction.Commit(); + } + } + + protected override void Configure(Configuration configuration) + { + configuration.LinqToHqlGeneratorsRegistry(); + } + + private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry + { + public TestLinqToHqlGeneratorsRegistry() + { + this.Merge(new TestHqlGeneratorForMethod()); + } + } + + private class TestHqlGeneratorForMethod : IHqlGeneratorForMethod + { + /// + public IEnumerable SupportedMethods => new [] + { + ReflectHelper.GetMethodDefinition(x => x.NameByMethod()), + }; + + /// + public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.Dot(visitor.Visit(targetObject).AsExpression(), treeBuilder.Ident("Name").AsExpression()); + } + } + + [Test] + public async Task WhereClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.Where(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByMethod().StartsWith("A")), + // Expected + q => q.Where(i => (i.Project.BillingClient != null ? i.Project.BillingClient.NameByMethod() : i.Project.Client != null ? i.Project.Client.NameByMethod() : i.Client.NameByMethod()).StartsWith("A")) + )); + } + + [Test] + public async Task SelectClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(i =>i.Name) + .Select(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByMethod()), + // Expected + q => q.OrderBy(i =>i.Name) + .Select(i => i.Project.BillingClient != null ? i.Project.BillingClient.NameByMethod() : i.Project.Client != null ? i.Project.Client.NameByMethod() : i.Client.NameByMethod()) + )); + } + + [Test] + public async Task SelectClauseToAnonAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(i =>i.Name) + .Select(i => new { Key =i.Name, Client = (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByMethod() }), + // Expected + q => q.OrderBy(i =>i.Name) + .Select(i => new { Key =i.Name, Client = i.Project.BillingClient != null ? i.Project.BillingClient.NameByMethod() : i.Project.Client != null ? i.Project.Client.NameByMethod() : i.Client.NameByMethod() }) + )); + } + + [Test] + public async Task OrderByClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByMethod() ?? "ZZZ") + .ThenBy(i =>i.Name) + .Select(i =>i.Name), + // Expected + q => q.OrderBy(i => (i.Project.BillingClient != null ? i.Project.BillingClient.NameByMethod() : i.Project.Client != null ? i.Project.Client.NameByMethod() : i.Client.NameByMethod()) ?? "ZZZ") + .ThenBy(i =>i.Name) + .Select(i =>i.Name) + )); + } + + [Test] + public async Task GroupByClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.GroupBy(i => (i.Project.BillingClient ?? i.Project.Client ?? i.Client).NameByMethod()) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }), + // Expected + q => q.GroupBy(i => i.Project.BillingClient != null ? i.Project.BillingClient.NameByMethod() : i.Project.Client != null ? i.Project.Client.NameByMethod() : i.Client.NameByMethod()) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }) + )); + } + } +} diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceSiblingsThenAccessMember.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceSiblingsThenAccessMember.cs new file mode 100644 index 00000000000..10886b43bb3 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/CoalesceSiblingsThenAccessMember.cs @@ -0,0 +1,117 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + using System.Threading.Tasks; + [TestFixture] + public class CoalesceSiblingsThenAccessMemberAsync : GH1879BaseFixtureAsync + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Albert" }; + var clientB = new Client { Name = "Bob" }; + var corpA = new CorporateClient { Name = "Alpha", CorporateId = "1234" }; + var corpB = new CorporateClient { Name = "Beta", CorporateId = "5647" }; + var clientZ = new Client { Name = null }; // A null value should propagate if the entity is non-null + session.Save(clientA); + session.Save(clientB); + session.Save(corpA); + session.Save(corpB); + session.Save(clientZ); + + session.Save(new Project { Name = "A", BillingClient = null, CorporateClient = null, Client = clientA }); + session.Save(new Project { Name = "B", BillingClient = null, CorporateClient = null, Client = clientB }); + session.Save(new Project { Name = "C", BillingClient = null, CorporateClient = corpA, Client = clientA }); + session.Save(new Project { Name = "D", BillingClient = null, CorporateClient = corpB, Client = clientA }); + session.Save(new Project { Name = "E", BillingClient = corpA, CorporateClient = null, Client = clientA }); + session.Save(new Project { Name = "F", BillingClient = clientB, CorporateClient = null, Client = clientA }); + session.Save(new Project { Name = "G", BillingClient = clientZ, CorporateClient = null, Client = clientA }); + session.Save(new Project { Name = "Z", BillingClient = null, CorporateClient = null, Client = null }); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public async Task WhereClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.Where(p => (p.BillingClient ?? p.CorporateClient ?? p.Client).Name.StartsWith("A")), + // Expected + q => q.Where(p => (p.BillingClient != null ? p.BillingClient.Name : p.CorporateClient != null ? p.CorporateClient.Name : p.Client.Name).StartsWith("A")) + )); + } + + [Test] + public async Task SelectClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(p => p.Name) + .Select(p => (p.BillingClient ?? p.CorporateClient ?? p.Client).Name), + // Expected + q => q.OrderBy(p => p.Name) + .Select(p => p.BillingClient != null ? p.BillingClient.Name : p.CorporateClient != null ? p.CorporateClient.Name : p.Client.Name) + )); + } + + [Test] + public async Task SelectClauseToAnonAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(p => p.Name) + .Select(p => new { Project = p.Name, Client = (p.BillingClient ?? p.CorporateClient ?? p.Client).Name }), + // Expected + q => q.OrderBy(p => p.Name) + .Select(p => new { Project = p.Name, Client = p.BillingClient != null ? p.BillingClient.Name : p.CorporateClient != null ? p.CorporateClient.Name : p.Client.Name }) + )); + } + + [Test] + public async Task OrderByClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(p => (p.BillingClient ?? p.CorporateClient ?? p.Client).Name ?? "ZZZ") + .ThenBy(p => p.Name) + .Select(p => p.Name), + // Expected + q => q.OrderBy(p => (p.BillingClient != null ? p.BillingClient.Name : p.CorporateClient != null ? p.CorporateClient.Name : p.Client.Name) ?? "ZZZ") + .ThenBy(p => p.Name) + .Select(p => p.Name) + )); + } + + [Test] + public async Task GroupByClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.GroupBy(p => (p.BillingClient ?? p.CorporateClient ?? p.Client).Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }), + // Expected + q => q.GroupBy(p => p.BillingClient != null ? p.BillingClient.Name : p.CorporateClient != null ? p.CorporateClient.Name : p.Client.Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }) + )); + } + } +} diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/ConditionalThenAccessMember.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/ConditionalThenAccessMember.cs new file mode 100644 index 00000000000..03bbd16368b --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/ConditionalThenAccessMember.cs @@ -0,0 +1,152 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + using System.Threading.Tasks; + [TestFixture] + public class ConditionalThenAccessMemberAsync : GH1879BaseFixtureAsync + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Albert" }; + var clientB = new Client { Name = "Bob" }; + var clientC = new CorporateClient { Name = "Charlie", CorporateId = "1234" }; + session.Save(clientA); + session.Save(clientB); + session.Save(clientC); + + session.Save(new Project { Name = "A", EmailPref = EmailPref.Primary, Client = clientA, BillingClient = clientB, CorporateClient = clientC, }); + session.Save(new Project { Name = "B", EmailPref = EmailPref.Billing, Client = clientA, BillingClient = clientB, CorporateClient = clientC, }); + session.Save(new Project { Name = "C", EmailPref = EmailPref.Corp, Client = clientA, BillingClient = clientB, CorporateClient = clientC, }); + + session.Save(new Project { Name = "D", EmailPref = EmailPref.Primary, Client = null, BillingClient = clientB, CorporateClient = clientC, }); + session.Save(new Project { Name = "E", EmailPref = EmailPref.Billing, Client = clientA, BillingClient = null, CorporateClient = clientC, }); + session.Save(new Project { Name = "F", EmailPref = EmailPref.Corp, Client = clientA, BillingClient = clientB, CorporateClient = null, }); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public async Task WhereClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.Where(p => (p.EmailPref == EmailPref.Primary + ? p.Client + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient + : p.BillingClient).Name.Length > 3), + // Expected + q => q.Where(p => (p.EmailPref == EmailPref.Primary + ? p.Client.Name + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient.Name + : p.BillingClient.Name).Length > 3) + )); + } + + [Test] + public async Task SelectClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(p => p.Name) + .Select(p => (p.EmailPref == EmailPref.Primary + ? p.Client + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient + : p.BillingClient).Name), + // Expected + q => q.OrderBy(p => p.Name) + .Select(p => p.EmailPref == EmailPref.Primary + ? p.Client.Name + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient.Name + : p.BillingClient.Name) + )); + } + + [Test] + public async Task SelectClauseToAnonAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(p => p.Name) + .Select(p => new { p.Name, Client = (p.EmailPref == EmailPref.Primary + ? p.Client + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient + : p.BillingClient).Name }), + // Expected + q => q.OrderBy(p => p.Name) + .Select(p => new { p.Name, Client = p.EmailPref == EmailPref.Primary + ? p.Client.Name + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient.Name + : p.BillingClient.Name }) + )); + } + + [Test] + public async Task OrderByClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.OrderBy(p => (p.EmailPref == EmailPref.Primary + ? p.Client + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient + : p.BillingClient).Name ?? "ZZZ") + .ThenBy(p => p.Name) + .Select(p => p.Name), + // Expected + q => q.OrderBy(p => (p.EmailPref == EmailPref.Primary + ? p.Client.Name + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient.Name + : p.BillingClient.Name) ?? "ZZZ") + .ThenBy(p => p.Name) + .Select(p => p.Name) + )); + } + + [Test] + public async Task GroupByClauseAsync() + { + await (AreEqualAsync( + // Actual + q => q.GroupBy(p => (p.EmailPref == EmailPref.Primary + ? p.Client + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient + : p.BillingClient).Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }), + // Expected + q => q.GroupBy(p => p.EmailPref == EmailPref.Primary + ? p.Client.Name + : p.EmailPref == EmailPref.Corp + ? p.CorporateClient.Name + : p.BillingClient.Name) + .OrderBy(x => x.Key ?? "ZZZ") + .Select(grp => new { grp.Key, Count = grp.Count() }) + )); + } + } +} diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/ConditionalThenMethodCall.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/ConditionalThenMethodCall.cs new file mode 100644 index 00000000000..2bfd9b12f39 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/ConditionalThenMethodCall.cs @@ -0,0 +1,119 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + using System.Threading.Tasks; + [TestFixture] + public class ConditionalThenMethodCallAsync : GH1879BaseFixtureAsync + { + /// + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Alpha" }; + var clientB = new Client { Name = "Beta" }; + session.Save(clientA); + session.Save(clientB); + + var issue1 = new Issue { Name = "1", Client = null }; + var issue2 = new Issue { Name = "2", Client = clientA }; + var issue3 = new Issue { Name = "3", Client = clientA }; + var issue4 = new Issue { Name = "4", Client = clientA }; + var issue5 = new Issue { Name = "5", Client = clientB }; + session.Save(issue1); + session.Save(issue2); + session.Save(issue3); + session.Save(issue4); + session.Save(issue5); + + session.Save(new Employee { Name = "Andy", ReviewAsPrimary = true, ReviewIssues = { issue1, issue2, issue5 }, WorkIssues = { issue3 } }); + session.Save(new Employee { Name = "Bart", ReviewAsPrimary = false, ReviewIssues = { issue3 }, WorkIssues = { issue4, issue5 } }); + session.Save(new Employee { Name = "Carl", ReviewAsPrimary = true, ReviewIssues = { issue3 }, WorkIssues = { issue1, issue4, issue5 } }); + session.Save(new Employee { Name = "Dorn", ReviewAsPrimary = false, ReviewIssues = { issue3 }, WorkIssues = { issue1, issue4 } }); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public async Task WhereClauseAsync() + { + await (AreEqualAsync( + // Conditional style + q => q.Where(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.WorkIssues).Any(i => i.Client.Name == "Beta")), + // Expected + q => q.Where(e => e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta")) + )); + } + + [Test] + public async Task SelectClauseAsync() + { + await (AreEqualAsync( + // Conditional style + q => q.OrderBy(e => e.Name) + .Select(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.WorkIssues).Any(i => i.Client.Name == "Beta")), + // Expected + q => q.OrderBy(e => e.Name) + .Select(e => e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta")) + )); + } + + [Test] + public async Task SelectClauseToAnonAsync() + { + await (AreEqualAsync( + // Conditional style + q => q.OrderBy(e => e.Name) + .Select(e => new { e.Name, Beta = (e.ReviewAsPrimary ? e.ReviewIssues : e.WorkIssues).Any(i => i.Client.Name == "Beta") }), + // Expected + q => q.OrderBy(e => e.Name) + .Select(e => new { e.Name, Beta = e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta") }) + )); + } + + [Test] + public async Task OrderByClauseAsync() + { + await (AreEqualAsync( + // Conditional style + q => q.OrderBy(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.WorkIssues).Count()) + .ThenBy(p => p.Name) + .Select(p => p.Name), + // Expected + q => q.OrderBy(e => e.ReviewAsPrimary ? e.ReviewIssues.Count() : e.WorkIssues.Count()) + .ThenBy(p => p.Name) + .Select(p => p.Name) + )); + } + + [Test] + public async Task GroupByClauseAsync() + { + await (AreEqualAsync( + // Conditional style + q => q.GroupBy(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.WorkIssues).Count()) + .OrderBy(x => x.Key) + .Select(grp => grp.Count()), + // Expected + q => q.GroupBy(e => e.ReviewAsPrimary ? e.ReviewIssues.Count() : e.WorkIssues.Count()) + .OrderBy(x => x.Key) + .Select(grp => grp.Count()) + )); + } + } +} diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs new file mode 100644 index 00000000000..f1bb1e633e4 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs @@ -0,0 +1,121 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Collections.Generic; +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Mapping.ByCode; +using NHibernate.Type; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + using System.Threading.Tasks; + using System.Threading; + [TestFixture] + public abstract class GH1879BaseFixtureAsync : TestCaseMappingByCode + { + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.Name); + }); + mapper.JoinedSubclass(rc => + { + rc.Property(x => x.CorporateId); + }); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.Name); + rc.Property(x => x.EmailPref, m => m.Type>()); + rc.ManyToOne(x => x.Client, m => m.Column("ClientId")); + rc.ManyToOne(x => x.BillingClient, m => m.Column("BillingClientId")); + rc.ManyToOne(x => x.CorporateClient, m => m.Column("CorporateClientId")); + }); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.Name); + rc.ManyToOne(x => x.Project, m => m.Column("ProjectId")); + rc.ManyToOne(x => x.Client, m => m.Column("ClientId")); + }); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.InvoiceNumber); + rc.ManyToOne(x => x.Project, m => m.Column("ProjectId")); + rc.ManyToOne(x => x.Issue, m => m.Column("IssueId")); + }); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.Name); + rc.Property(x => x.ReviewAsPrimary); + rc.Set(x => x.WorkIssues, + m => + { + m.Table("EmployeesToWorkIssues"); + m.Cascade(Mapping.ByCode.Cascade.All | Mapping.ByCode.Cascade.DeleteOrphans); + m.Key(k => k.Column(c => c.Name("EmployeeId")) ); + }, + rel => rel.ManyToMany(m => m.Column("IssueId"))); + rc.Set(x => x.ReviewIssues, + m => + { + m.Table("EmployeesToReviewIssues"); + m.Cascade(Mapping.ByCode.Cascade.All | Mapping.ByCode.Cascade.DeleteOrphans); + m.Key(k => k.Column(c => c.Name("EmployeeId")) ); + }, + rel => rel.ManyToMany(m => m.Column("IssueId"))); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Delete("from System.Object"); + session.Flush(); + transaction.Commit(); + } + } + + protected async Task AreEqualAsync( + Func, IQueryable> actualQuery, + Func, IQueryable> expectedQuery, CancellationToken cancellationToken = default(CancellationToken)) + { + using (var session = OpenSession()) + { + IEnumerable expectedResult = null; + try + { + expectedResult = await (expectedQuery(session.Query()).ToListAsync(cancellationToken)); + } + catch (OperationCanceledException) { throw; } + catch + { + Assert.Ignore("Not currently supported query"); + } + + var testResult = await (actualQuery(session.Query()).ToListAsync(cancellationToken)); + Assert.That(testResult, Is.EqualTo(expectedResult)); + } + } + } +} From 14b67ac0edd9976efbca906a3cbaca9a3dd7bfab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericdelaporte@users.noreply.github.com> Date: Tue, 23 Oct 2018 13:27:02 +0200 Subject: [PATCH 3/5] fixup! GH-1879 - Allow Coalesce and Conditional logic on entity properties and collections (LINQ) Do additional minor cleanup --- .../Linq/ReWriters/ConditionalQueryReferenceExpander.cs | 6 +++--- .../Linq/ReWriters/SubQueryConditionalExpander.cs | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs b/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs index b3d878c107a..86370dda35d 100644 --- a/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs +++ b/src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs @@ -76,7 +76,7 @@ public void Transform(ResultOperatorBase resultOperator) protected override Expression VisitMember(MemberExpression node) { - var result = (MemberExpression)base.VisitMember(node); + var result = (MemberExpression) base.VisitMember(node); if (QueryReferenceCounter.CountReferences(result.Expression) > 1) { return ConditionalQueryReferenceMemberExpressionRewriter.Rewrite(result.Expression, node); @@ -89,7 +89,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node) var result = (MethodCallExpression)base.VisitMethodCall(node); var isExtension = node.Method.GetCustomAttributes().Any(); var methodObject = isExtension ? node.Arguments[0] : node.Object; - + if (methodObject != null && QueryReferenceCounter.CountReferences(methodObject) > 1) { return ConditionalQueryReferenceMethodCallExpressionRewriter.Rewrite(methodObject, node); @@ -107,7 +107,7 @@ private QueryReferenceCounter(System.Type queryType) { _queryType = queryType; } - + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { if (_queryType.IsAssignableFrom(expression.Type)) diff --git a/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs b/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs index 8db8aa6ed70..b38ba805a41 100644 --- a/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs +++ b/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs @@ -94,16 +94,16 @@ protected override Expression VisitConditional(ConditionalExpression node) _nominate[_depth] = false; } - var newTest = this.Visit(node.Test); + var newTest = Visit(node.Test); _nominate.Insert(++_depth, false); - var newTrue = this.Visit(node.IfTrue); + var newTrue = Visit(node.IfTrue); if (_nominate[_depth]) { newTrue = BuildNewSubQuery(newTrue); Rewritten = true; } _nominate.Insert(_depth, false); - var newFalse = this.Visit(node.IfFalse); + var newFalse = Visit(node.IfFalse); if (_nominate[_depth]) { newFalse = BuildNewSubQuery(newFalse); From 6de248b057a31907823e65099ba89fdb389c82cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericdelaporte@users.noreply.github.com> Date: Tue, 23 Oct 2018 14:28:00 +0200 Subject: [PATCH 4/5] fixup! GH-1879 - Allow Coalesce and Conditional logic on entity properties and collections (LINQ) Use a stack instead of a list with a pointer to "depth" --- .../ReWriters/SubQueryConditionalExpander.cs | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs b/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs index b38ba805a41..98402dc2f1f 100644 --- a/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs +++ b/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs @@ -77,8 +77,7 @@ protected override Expression VisitSubQuery(SubQueryExpression expression) private class SubQueryFromClauseExpander : RelinqExpressionVisitor { private readonly QueryModel _originalSubQueryModel; - private int _depth = -1; - private readonly IList _nominate = new List(); + private readonly Stack _nominate = new Stack(); public bool Rewritten { get; private set; } @@ -89,35 +88,36 @@ public SubQueryFromClauseExpander(QueryModel originalSubQueryModel) protected override Expression VisitConditional(ConditionalExpression node) { - if (_depth >= 0) + if (_nominate.Count > 0) { - _nominate[_depth] = false; + _nominate.Pop(); + _nominate.Push(false); } var newTest = Visit(node.Test); - _nominate.Insert(++_depth, false); + _nominate.Push(false); var newTrue = Visit(node.IfTrue); - if (_nominate[_depth]) + if (_nominate.Pop()) { newTrue = BuildNewSubQuery(newTrue); Rewritten = true; } - _nominate.Insert(_depth, false); + _nominate.Push(false); var newFalse = Visit(node.IfFalse); - if (_nominate[_depth]) + if (_nominate.Pop()) { newFalse = BuildNewSubQuery(newFalse); Rewritten = true; } - _nominate.RemoveAt(_depth--); return Expression.Condition(newTest, newTrue, newFalse); } protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { - if (_depth >= 0) + if (_nominate.Count > 0) { - _nominate[_depth] = true; + _nominate.Pop(); + _nominate.Push(true); } return base.VisitQuerySourceReference(expression); From 4d8821b95391eaf66c7cef2e4cd0022af6c20095 Mon Sep 17 00:00:00 2001 From: Duncan M Date: Tue, 23 Oct 2018 10:41:17 -0600 Subject: [PATCH 5/5] GH-1879 - Nested conditional subquery expansion tests and logic --- .../NHSpecificTest/GH1879/FixtureByCode.cs | 14 ++ .../GH1879/NestedConditionalThenMethodCall.cs | 126 ++++++++++++++++++ .../NHSpecificTest/GH1879/Entity.cs | 2 + .../NHSpecificTest/GH1879/FixtureByCode.cs | 14 ++ .../GH1879/NestedConditionalThenMethodCall.cs | 115 ++++++++++++++++ .../ReWriters/SubQueryConditionalExpander.cs | 72 ++++++++-- .../Linq/Visitors/QueryModelVisitor.cs | 6 +- 7 files changed, 339 insertions(+), 10 deletions(-) create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH1879/NestedConditionalThenMethodCall.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH1879/NestedConditionalThenMethodCall.cs diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs index f1bb1e633e4..6eb360584df 100644 --- a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs @@ -44,6 +44,12 @@ protected override HbmMapping GetMappings() rc.ManyToOne(x => x.Client, m => m.Column("ClientId")); rc.ManyToOne(x => x.BillingClient, m => m.Column("BillingClientId")); rc.ManyToOne(x => x.CorporateClient, m => m.Column("CorporateClientId")); + rc.Set(x => x.Issues, + m => + { + m.Key(k => k.Column(c => c.Name("ProjectId")) ); + }, + rel => rel.OneToMany()); }); mapper.Class(rc => { @@ -64,6 +70,14 @@ protected override HbmMapping GetMappings() rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); rc.Property(x => x.Name); rc.Property(x => x.ReviewAsPrimary); + rc.Set(x => x.Projects, + m => + { + m.Table("EmployeesToProjects"); + m.Cascade(Mapping.ByCode.Cascade.All | Mapping.ByCode.Cascade.DeleteOrphans); + m.Key(k => k.Column(c => c.Name("EmployeeId")) ); + }, + rel => rel.ManyToMany(m => m.Column("ProjectId"))); rc.Set(x => x.WorkIssues, m => { diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH1879/NestedConditionalThenMethodCall.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/NestedConditionalThenMethodCall.cs new file mode 100644 index 00000000000..b46423c9159 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH1879/NestedConditionalThenMethodCall.cs @@ -0,0 +1,126 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + using System.Threading.Tasks; + [TestFixture] + public class NestedConditionalThenMethodCallAsync : GH1879BaseFixtureAsync + { + /// + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Alpha" }; + var clientB = new Client { Name = "Beta" }; + session.Save(clientA); + session.Save(clientB); + + var projectA = new Project { Name = "Apple" }; + var projectB = new Project { Name = "Banana" }; + var projectC = new Project { Name = "Cherry" }; + session.Save(projectA); + session.Save(projectB); + session.Save(projectC); + + var issue1 = new Issue { Name = "1", Client = null, Project = null }; + var issue2 = new Issue { Name = "2", Client = clientA, Project = projectA }; + var issue3 = new Issue { Name = "3", Client = clientA, Project = projectA }; + var issue4 = new Issue { Name = "4", Client = clientA, Project = projectB }; + var issue5 = new Issue { Name = "5", Client = clientB, Project = projectC }; + session.Save(issue1); + session.Save(issue2); + session.Save(issue3); + session.Save(issue4); + session.Save(issue5); + + session.Save(new Employee { Name = "Andy", ReviewAsPrimary = true, ReviewIssues = { issue1, issue2, issue5 }, WorkIssues = { issue3 }, Projects = { projectA, projectB } }); + session.Save(new Employee { Name = "Bart", ReviewAsPrimary = false, ReviewIssues = { issue3 }, WorkIssues = { issue4, issue5 }, Projects = { projectB, projectC } }); + session.Save(new Employee { Name = "Carl", ReviewAsPrimary = true, ReviewIssues = { issue3 }, WorkIssues = { issue1, issue4, issue5 }, Projects = { projectC } }); + session.Save(new Employee { Name = "Dorn", ReviewAsPrimary = false, ReviewIssues = { issue3 }, WorkIssues = { issue1, issue4 }, Projects = { } }); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public async Task WhereClauseAsync() + { + await (AreEqualAsync( + // Conditional style + q => q.Where(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues) : e.WorkIssues).Any(i => i.Client.Name == "Beta")), + // Expected + q => q.Where(e => e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues).Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta")) + )); + } + + [Test] + public async Task SelectClauseAsync() + { + await (AreEqualAsync( + // Conditional style + q => q.OrderBy(e => e.Name) + .Select(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues) : e.WorkIssues).Any(i => i.Client.Name == "Beta")), + // Expected + q => q.OrderBy(e => e.Name) + .Select(e => e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues).Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta")) + )); + } + + [Test] + public async Task SelectClauseToAnonAsync() + { + await (AreEqualAsync( + // Conditional style + q => q.OrderBy(e => e.Name) + .Select(e => new { e.Name, Beta = (e.ReviewAsPrimary ? e.ReviewIssues : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues) : e.WorkIssues).Any(i => i.Client.Name == "Beta") }), + // Expected + q => q.OrderBy(e => e.Name) + .Select(e => new { e.Name, Beta = e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues).Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta") }) + )); + } + + [Test] + public async Task OrderByClauseAsync() + { + await (AreEqualAsync( + // Conditional style + q => q.OrderBy(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues) : e.WorkIssues).Count()) + .ThenBy(p => p.Name) + .Select(p => p.Name), + // Expected + q => q.OrderBy(e => e.ReviewAsPrimary ? e.ReviewIssues.Count() : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues).Count() : e.WorkIssues.Count()) + .ThenBy(p => p.Name) + .Select(p => p.Name) + )); + } + + [Test] + public async Task GroupByClauseAsync() + { + await (AreEqualAsync( + // Conditional style + q => q.GroupBy(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues) : e.WorkIssues).Count()) + .OrderBy(x => x.Key) + .Select(grp => grp.Count()), + // Expected + q => q.GroupBy(e => e.ReviewAsPrimary ? e.ReviewIssues.Count() : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues).Count() : e.WorkIssues.Count()) + .OrderBy(x => x.Key) + .Select(grp => grp.Count()) + )); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs index 402395f974b..6bc31cfa9c7 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs @@ -21,6 +21,7 @@ public class Employee public virtual string Name { get; set; } public virtual bool ReviewAsPrimary { get; set; } + public virtual ICollection Projects { get; set; } = new List(); public virtual ICollection WorkIssues { get; set; } = new List(); public virtual ICollection ReviewIssues { get; set; } = new List(); } @@ -33,6 +34,7 @@ public class Project public virtual Client Client { get; set; } public virtual Client BillingClient { get; set; } public virtual CorporateClient CorporateClient { get; set; } + public virtual ICollection Issues { get; set; } = new List(); } public enum EmailPref diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs index cbf5df62ec2..926cf7c3f6c 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs @@ -31,6 +31,12 @@ protected override HbmMapping GetMappings() rc.ManyToOne(x => x.Client, m => m.Column("ClientId")); rc.ManyToOne(x => x.BillingClient, m => m.Column("BillingClientId")); rc.ManyToOne(x => x.CorporateClient, m => m.Column("CorporateClientId")); + rc.Set(x => x.Issues, + m => + { + m.Key(k => k.Column(c => c.Name("ProjectId")) ); + }, + rel => rel.OneToMany()); }); mapper.Class(rc => { @@ -51,6 +57,14 @@ protected override HbmMapping GetMappings() rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); rc.Property(x => x.Name); rc.Property(x => x.ReviewAsPrimary); + rc.Set(x => x.Projects, + m => + { + m.Table("EmployeesToProjects"); + m.Cascade(Mapping.ByCode.Cascade.All | Mapping.ByCode.Cascade.DeleteOrphans); + m.Key(k => k.Column(c => c.Name("EmployeeId")) ); + }, + rel => rel.ManyToMany(m => m.Column("ProjectId"))); rc.Set(x => x.WorkIssues, m => { diff --git a/src/NHibernate.Test/NHSpecificTest/GH1879/NestedConditionalThenMethodCall.cs b/src/NHibernate.Test/NHSpecificTest/GH1879/NestedConditionalThenMethodCall.cs new file mode 100644 index 00000000000..6c47bfdb8b5 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH1879/NestedConditionalThenMethodCall.cs @@ -0,0 +1,115 @@ +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH1879 +{ + [TestFixture] + public class NestedConditionalThenMethodCall : GH1879BaseFixture + { + /// + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var clientA = new Client { Name = "Alpha" }; + var clientB = new Client { Name = "Beta" }; + session.Save(clientA); + session.Save(clientB); + + var projectA = new Project { Name = "Apple" }; + var projectB = new Project { Name = "Banana" }; + var projectC = new Project { Name = "Cherry" }; + session.Save(projectA); + session.Save(projectB); + session.Save(projectC); + + var issue1 = new Issue { Name = "1", Client = null, Project = null }; + var issue2 = new Issue { Name = "2", Client = clientA, Project = projectA }; + var issue3 = new Issue { Name = "3", Client = clientA, Project = projectA }; + var issue4 = new Issue { Name = "4", Client = clientA, Project = projectB }; + var issue5 = new Issue { Name = "5", Client = clientB, Project = projectC }; + session.Save(issue1); + session.Save(issue2); + session.Save(issue3); + session.Save(issue4); + session.Save(issue5); + + session.Save(new Employee { Name = "Andy", ReviewAsPrimary = true, ReviewIssues = { issue1, issue2, issue5 }, WorkIssues = { issue3 }, Projects = { projectA, projectB } }); + session.Save(new Employee { Name = "Bart", ReviewAsPrimary = false, ReviewIssues = { issue3 }, WorkIssues = { issue4, issue5 }, Projects = { projectB, projectC } }); + session.Save(new Employee { Name = "Carl", ReviewAsPrimary = true, ReviewIssues = { issue3 }, WorkIssues = { issue1, issue4, issue5 }, Projects = { projectC } }); + session.Save(new Employee { Name = "Dorn", ReviewAsPrimary = false, ReviewIssues = { issue3 }, WorkIssues = { issue1, issue4 }, Projects = { } }); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public void WhereClause() + { + AreEqual( + // Conditional style + q => q.Where(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues) : e.WorkIssues).Any(i => i.Client.Name == "Beta")), + // Expected + q => q.Where(e => e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues).Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta")) + ); + } + + [Test] + public void SelectClause() + { + AreEqual( + // Conditional style + q => q.OrderBy(e => e.Name) + .Select(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues) : e.WorkIssues).Any(i => i.Client.Name == "Beta")), + // Expected + q => q.OrderBy(e => e.Name) + .Select(e => e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues).Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta")) + ); + } + + [Test] + public void SelectClauseToAnon() + { + AreEqual( + // Conditional style + q => q.OrderBy(e => e.Name) + .Select(e => new { e.Name, Beta = (e.ReviewAsPrimary ? e.ReviewIssues : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues) : e.WorkIssues).Any(i => i.Client.Name == "Beta") }), + // Expected + q => q.OrderBy(e => e.Name) + .Select(e => new { e.Name, Beta = e.ReviewAsPrimary ? e.ReviewIssues.Any(i => i.Client.Name == "Beta") : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues).Any(i => i.Client.Name == "Beta") : e.WorkIssues.Any(i => i.Client.Name == "Beta") }) + ); + } + + [Test] + public void OrderByClause() + { + AreEqual( + // Conditional style + q => q.OrderBy(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues) : e.WorkIssues).Count()) + .ThenBy(p => p.Name) + .Select(p => p.Name), + // Expected + q => q.OrderBy(e => e.ReviewAsPrimary ? e.ReviewIssues.Count() : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues).Count() : e.WorkIssues.Count()) + .ThenBy(p => p.Name) + .Select(p => p.Name) + ); + } + + [Test] + public void GroupByClause() + { + AreEqual( + // Conditional style + q => q.GroupBy(e => (e.ReviewAsPrimary ? e.ReviewIssues : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues) : e.WorkIssues).Count()) + .OrderBy(x => x.Key) + .Select(grp => grp.Count()), + // Expected + q => q.GroupBy(e => e.ReviewAsPrimary ? e.ReviewIssues.Count() : e.Projects.Any() ? e.Projects.SelectMany(x => x.Issues).Count() : e.WorkIssues.Count()) + .OrderBy(x => x.Key) + .Select(grp => grp.Count()) + ); + } + } +} diff --git a/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs b/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs index 98402dc2f1f..3c6d3a4231f 100644 --- a/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs +++ b/src/NHibernate/Linq/ReWriters/SubQueryConditionalExpander.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using NHibernate.Linq.Clauses; using NHibernate.Linq.Visitors; @@ -77,6 +78,7 @@ protected override Expression VisitSubQuery(SubQueryExpression expression) private class SubQueryFromClauseExpander : RelinqExpressionVisitor { private readonly QueryModel _originalSubQueryModel; + private readonly System.Type _subQueryCollectionType; private readonly Stack _nominate = new Stack(); public bool Rewritten { get; private set; } @@ -84,31 +86,32 @@ private class SubQueryFromClauseExpander : RelinqExpressionVisitor public SubQueryFromClauseExpander(QueryModel originalSubQueryModel) { _originalSubQueryModel = originalSubQueryModel; + _subQueryCollectionType = typeof(IEnumerable<>).MakeGenericType(_originalSubQueryModel.MainFromClause.ItemType); } protected override Expression VisitConditional(ConditionalExpression node) { - if (_nominate.Count > 0) - { - _nominate.Pop(); - _nominate.Push(false); - } - + _nominate.Push(false); var newTest = Visit(node.Test); + _nominate.Pop(); + + _nominate.Push(false); var newTrue = Visit(node.IfTrue); - if (_nominate.Pop()) + if (_nominate.Pop() && _subQueryCollectionType.IsAssignableFrom(newTrue.Type)) { newTrue = BuildNewSubQuery(newTrue); Rewritten = true; } + _nominate.Push(false); var newFalse = Visit(node.IfFalse); - if (_nominate.Pop()) + if (_nominate.Pop() && _subQueryCollectionType.IsAssignableFrom(newFalse.Type)) { newFalse = BuildNewSubQuery(newFalse); Rewritten = true; } + return Expression.Condition(newTest, newTrue, newFalse); } @@ -123,12 +126,63 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr return base.VisitQuerySourceReference(expression); } + protected override Expression VisitSubQuery(SubQueryExpression expression) + { + SubQueryConditionalExpander.ReWrite(expression.QueryModel); + if (_nominate.Count > 0 && _subQueryCollectionType.IsAssignableFrom(expression.Type)) + { + _nominate.Pop(); + _nominate.Push(true); + } + return base.VisitSubQuery(expression); + } + private SubQueryExpression BuildNewSubQuery(Expression fromExpr) { var newSubQuery = _originalSubQueryModel.Clone(); - newSubQuery.MainFromClause.FromExpression = fromExpr; + + if (fromExpr is SubQueryExpression innerSubQuery) + { + newSubQuery.MainFromClause = innerSubQuery.QueryModel.MainFromClause; + foreach (var bodyClause in newSubQuery.BodyClauses) + { + bodyClause.TransformExpressions(expr => QueryReferenceUpdater.Update(expr, newSubQuery.SelectClause.Selector, innerSubQuery.QueryModel.SelectClause.Selector)); + } + newSubQuery.SelectClause = innerSubQuery.QueryModel.SelectClause; + foreach (var bodyClause in innerSubQuery.QueryModel.BodyClauses.Reverse()) + { + newSubQuery.BodyClauses.Insert(0, bodyClause); + } + } + else + { + newSubQuery.MainFromClause.FromExpression = fromExpr; + } + return new SubQueryExpression(newSubQuery); } } + + private class QueryReferenceUpdater : RelinqExpressionVisitor + { + private readonly Expression _original; + private readonly Expression _new; + + private QueryReferenceUpdater(Expression original, Expression @new) + { + _original = original; + _new = @new; + } + + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) + { + return Equals(expression, _original) ? _new : base.VisitQuerySourceReference(expression); + } + + public static Expression Update(Expression toUpdate, Expression original, Expression @new) + { + return new QueryReferenceUpdater(original, @new).Visit(toUpdate); + } + } } } diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index e549843db9e..0fa769938a5 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -28,7 +28,11 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer NhLinqExpressionReturnType? rootReturnType) { // Expand conditionals in subquery FROM clauses into multiple subqueries - SubQueryConditionalExpander.ReWrite(queryModel); + if (root) + { + // This expander works recursively + SubQueryConditionalExpander.ReWrite(queryModel); + } NestedSelectRewriter.ReWrite(queryModel, parameters.SessionFactory);