From db2155eadd1d8540bd530e403f3068d00f1b317a Mon Sep 17 00:00:00 2001 From: Gunnar Liljas Date: Thu, 22 Apr 2021 23:16:15 +0200 Subject: [PATCH 1/6] Add more left join support --- .../Linq/ByMethod/JoinTests.cs | 99 ++++++++++++++++++- ...JoinAggregateDetectionQueryModelVisitor.cs | 54 ++++++++++ .../GroupJoinAggregateDetectionVisitor.cs | 3 + .../NonAggregatingGroupJoinRewriter.cs | 2 +- 4 files changed, 152 insertions(+), 6 deletions(-) create mode 100644 src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index a7a4750b87a..371edc8161e 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -1,8 +1,4 @@ -using System; -using System.Linq; -using System.Reflection; -using NHibernate.Cfg; -using NHibernate.Engine.Query; +using System.Linq; using NHibernate.Linq; using NHibernate.Util; using NSubstitute; @@ -103,6 +99,99 @@ public void LeftJoinExtensionMethodWithMultipleKeyProperties() } } + [Test] + public void LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Where(x => x.mammal.SerialNumber.StartsWith("9")) + .Select(x => new { SerialNumber = x.animal.SerialNumber }) + .ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [KnownBug("GH-XXXX")] + public void NestedLeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly() + { + using (var sqlSpy = new SqlLogSpy()) + { + var innerAnimals = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Where(x => x.mammal.SerialNumber.StartsWith("9")) + .Select(x=>x.animal); + + var animals = db.Animals + .LeftJoin( + innerAnimals, + x => x.Id, + x => x.Id, + (animal, animal2) => new { animal, animal2 }) + .Select(x => new { SerialNumber = x.animal2.SerialNumber }) + .ToList(); + + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [KnownBug("GH-XXXX")] + public void LeftJoinExtensionMethodWithNoUseOfOuterReference() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Select(x => x.animal) + .ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(6)); + } + } + + [Test] + public void LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnly() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .OrderBy(x => x.mammal.SerialNumber ?? "z") + .Select(x => new { SerialNumber = x.animal.SerialNumber }) + .ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(6)); + Assert.That(animals[0].SerialNumber, Is.EqualTo("1121")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [TestCase(false)] [TestCase(true)] public void CrossJoinWithPredicateInWhereStatement(bool useCrossJoin) diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs new file mode 100644 index 00000000000..1e4b1ab9a0c --- /dev/null +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs @@ -0,0 +1,54 @@ +using System.Collections.Generic; +using System.Linq.Expressions; +using NHibernate.Linq.Visitors; +using Remotion.Linq; +using Remotion.Linq.Clauses; + +namespace NHibernate.Linq.GroupJoin +{ + internal class GroupJoinAggregateDetectionQueryModelVisitor : NhQueryModelVisitorBase + { + private HashSet _groupJoinClauses; + private readonly List _nonAggregatingExpressions = new List(); + private readonly List _nonAggregatingGroupJoins = new List(); + private readonly List _aggregatingGroupJoins = new List(); + private GroupJoinAggregateDetectionQueryModelVisitor(IEnumerable groupJoinClauses) + { + _groupJoinClauses = new HashSet(groupJoinClauses); + } + + public static IsAggregatingResults Visit(IEnumerable groupJoinClause, QueryModel queryModel) + { + var visitor = new GroupJoinAggregateDetectionQueryModelVisitor(groupJoinClause); + + visitor.VisitQueryModel(queryModel); + + return new IsAggregatingResults { NonAggregatingClauses = visitor._nonAggregatingGroupJoins, AggregatingClauses = visitor._aggregatingGroupJoins, NonAggregatingExpressions = visitor._nonAggregatingExpressions }; + } + public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) + { + var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, whereClause.Predicate); + AddResults(results); + } + + public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) + { + var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, selectClause.Selector); + AddResults(results); + } + + public override void VisitOrdering(Ordering ordering, QueryModel queryModel, OrderByClause orderByClause, int index) + { + var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, ordering.Expression); + AddResults(results); + } + + + private void AddResults(IsAggregatingResults results) + { + _nonAggregatingExpressions.AddRange(results.NonAggregatingExpressions); + _nonAggregatingGroupJoins.AddRange(results.NonAggregatingClauses); + _aggregatingGroupJoins.AddRange(results.AggregatingClauses); + } + } +} diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs index 20f2331bf9c..cb94233e24f 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs @@ -1,8 +1,10 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq.Expressions; using NHibernate.Linq.Expressions; using NHibernate.Linq.Visitors; +using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; @@ -34,6 +36,7 @@ public static IsAggregatingResults Visit(IEnumerable groupJoinC protected override Expression VisitSubQuery(SubQueryExpression expression) { + //Visit the entire query model? Visit(expression.QueryModel.SelectClause.Selector); return expression; } diff --git a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs index d56a9bedff4..5d49beda9e3 100644 --- a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs @@ -159,7 +159,7 @@ private bool IsHierarchicalJoin(GroupJoinClause nonAggregatingJoin, QuerySourceU // TODO - rename this and share with the AggregatingGroupJoinRewriter private IsAggregatingResults GetGroupJoinInformation(IEnumerable clause) { - return GroupJoinAggregateDetectionVisitor.Visit(clause, _model.SelectClause.Selector); + return GroupJoinAggregateDetectionQueryModelVisitor.Visit(clause, _model); } } From d18b50e824b24e8fb09f6cd03aa64ff64f66d9a6 Mon Sep 17 00:00:00 2001 From: Gunnar Liljas Date: Fri, 23 Apr 2021 10:02:08 +0200 Subject: [PATCH 2/6] Minor fixes --- src/NHibernate.Test/Linq/ByMethod/JoinTests.cs | 5 ++--- .../GroupJoinAggregateDetectionQueryModelVisitor.cs | 3 --- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index 371edc8161e..6e5e9a4de72 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -120,7 +120,7 @@ public void LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly() } } - [KnownBug("GH-XXXX")] + [KnownBug("GH-2379")] public void NestedLeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly() { using (var sqlSpy = new SqlLogSpy()) @@ -143,14 +143,13 @@ public void NestedLeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly() .Select(x => new { SerialNumber = x.animal2.SerialNumber }) .ToList(); - var sql = sqlSpy.GetWholeLog(); Assert.That(animals.Count, Is.EqualTo(1)); Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); } } - [KnownBug("GH-XXXX")] + [KnownBug("GH-2738")] public void LeftJoinExtensionMethodWithNoUseOfOuterReference() { using (var sqlSpy = new SqlLogSpy()) diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs index 1e4b1ab9a0c..84361931935 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs @@ -16,7 +16,6 @@ private GroupJoinAggregateDetectionQueryModelVisitor(IEnumerable(groupJoinClauses); } - public static IsAggregatingResults Visit(IEnumerable groupJoinClause, QueryModel queryModel) { var visitor = new GroupJoinAggregateDetectionQueryModelVisitor(groupJoinClause); @@ -30,7 +29,6 @@ public override void VisitWhereClause(WhereClause whereClause, QueryModel queryM var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, whereClause.Predicate); AddResults(results); } - public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) { var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, selectClause.Selector); @@ -43,7 +41,6 @@ public override void VisitOrdering(Ordering ordering, QueryModel queryModel, Ord AddResults(results); } - private void AddResults(IsAggregatingResults results) { _nonAggregatingExpressions.AddRange(results.NonAggregatingExpressions); From bb05060b31284c24b908f2a54b40f3e7431667ef Mon Sep 17 00:00:00 2001 From: maca88 Date: Wed, 26 May 2021 22:36:38 +0200 Subject: [PATCH 3/6] Fix test with no use of outer reference --- .../Async/Linq/ByMethod/JoinTests.cs | 97 ++++++++++++++++++- .../Linq/ByMethod/JoinTests.cs | 6 +- .../NonAggregatingGroupJoinRewriter.cs | 18 ++++ 3 files changed, 114 insertions(+), 7 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs index 9fda2e873c9..71cfde341de 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs @@ -8,11 +8,7 @@ //------------------------------------------------------------------------------ -using System; using System.Linq; -using System.Reflection; -using NHibernate.Cfg; -using NHibernate.Engine.Query; using NHibernate.Linq; using NHibernate.Util; using NSubstitute; @@ -21,6 +17,7 @@ namespace NHibernate.Test.Linq.ByMethod { using System.Threading.Tasks; + using System.Threading; [TestFixture] public class JoinTestsAsync : LinqTestCase { @@ -114,6 +111,98 @@ public async Task LeftJoinExtensionMethodWithMultipleKeyPropertiesAsync() } } + [Test] + public async Task LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnlyAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = await (db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Where(x => x.mammal.SerialNumber.StartsWith("9")) + .Select(x => new { SerialNumber = x.animal.SerialNumber }) + .ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [KnownBug("GH-2379")] + public async Task NestedLeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnlyAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + using (var sqlSpy = new SqlLogSpy()) + { + var innerAnimals = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Where(x => x.mammal.SerialNumber.StartsWith("9")) + .Select(x=>x.animal); + + var animals = await (db.Animals + .LeftJoin( + innerAnimals, + x => x.Id, + x => x.Id, + (animal, animal2) => new { animal, animal2 }) + .Select(x => new { SerialNumber = x.animal2.SerialNumber }) + .ToListAsync(cancellationToken)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [Test] + public async Task LeftJoinExtensionMethodWithNoUseOfOuterReferenceAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = await (db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Select(x => x.animal) + .ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(5)); + } + } + + [Test] + public async Task LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnlyAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = await (db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .OrderBy(x => x.mammal.SerialNumber ?? "z") + .Select(x => new { SerialNumber = x.animal.SerialNumber }) + .ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(6)); + Assert.That(animals[0].SerialNumber, Is.EqualTo("1121")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [TestCase(false)] [TestCase(true)] public async Task CrossJoinWithPredicateInWhereStatementAsync(bool useCrossJoin) diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index 6e5e9a4de72..8e19e1854c9 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -149,7 +149,7 @@ public void NestedLeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly() } } - [KnownBug("GH-2738")] + [Test] public void LeftJoinExtensionMethodWithNoUseOfOuterReference() { using (var sqlSpy = new SqlLogSpy()) @@ -164,8 +164,8 @@ public void LeftJoinExtensionMethodWithNoUseOfOuterReference() .ToList(); var sql = sqlSpy.GetWholeLog(); - Assert.That(animals.Count, Is.EqualTo(1)); - Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(6)); + Assert.That(animals.Count, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(5)); } } diff --git a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs index 5d49beda9e3..eb633bac558 100644 --- a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs @@ -91,6 +91,24 @@ private void ReWrite() throw new NotSupportedException(); } } + + // Remove not used group joins + foreach (var groupJoinClause in _groupJoinClauses) + { + if (aggregateDetectorResults.NonAggregatingClauses.Contains(groupJoinClause) || aggregateDetectorResults.AggregatingClauses.Contains(groupJoinClause)) + { + continue; + } + + var locator = new QuerySourceUsageLocator(groupJoinClause); + foreach (var bodyClause in _model.BodyClauses) + { + locator.Search(bodyClause); + } + + _model.BodyClauses.Remove((IBodyClause) locator.Usages[0]); + _model.BodyClauses.Remove(groupJoinClause); + } } private void ProcessFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator) From f54994e6e418e609f5b29429bf32205cee0f74d8 Mon Sep 17 00:00:00 2001 From: maca88 Date: Wed, 26 May 2021 22:42:37 +0200 Subject: [PATCH 4/6] Optimize GroupJoinAggregateDetectionQueryModelVisitor --- ...JoinAggregateDetectionQueryModelVisitor.cs | 29 +++++++------------ .../GroupJoinAggregateDetectionVisitor.cs | 19 ++++++++++-- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs index 84361931935..8a63b3ac810 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs @@ -8,44 +8,35 @@ namespace NHibernate.Linq.GroupJoin { internal class GroupJoinAggregateDetectionQueryModelVisitor : NhQueryModelVisitorBase { - private HashSet _groupJoinClauses; - private readonly List _nonAggregatingExpressions = new List(); - private readonly List _nonAggregatingGroupJoins = new List(); - private readonly List _aggregatingGroupJoins = new List(); + private readonly GroupJoinAggregateDetectionVisitor _groupJoinAggregateDetectionVisitor; + private GroupJoinAggregateDetectionQueryModelVisitor(IEnumerable groupJoinClauses) { - _groupJoinClauses = new HashSet(groupJoinClauses); + _groupJoinAggregateDetectionVisitor = new GroupJoinAggregateDetectionVisitor(groupJoinClauses); } + public static IsAggregatingResults Visit(IEnumerable groupJoinClause, QueryModel queryModel) { var visitor = new GroupJoinAggregateDetectionQueryModelVisitor(groupJoinClause); visitor.VisitQueryModel(queryModel); - return new IsAggregatingResults { NonAggregatingClauses = visitor._nonAggregatingGroupJoins, AggregatingClauses = visitor._aggregatingGroupJoins, NonAggregatingExpressions = visitor._nonAggregatingExpressions }; + return visitor._groupJoinAggregateDetectionVisitor.GetResults(); } + public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) { - var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, whereClause.Predicate); - AddResults(results); + _groupJoinAggregateDetectionVisitor.Visit(whereClause.Predicate); } + public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) { - var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, selectClause.Selector); - AddResults(results); + _groupJoinAggregateDetectionVisitor.Visit(selectClause.Selector); } public override void VisitOrdering(Ordering ordering, QueryModel queryModel, OrderByClause orderByClause, int index) { - var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, ordering.Expression); - AddResults(results); - } - - private void AddResults(IsAggregatingResults results) - { - _nonAggregatingExpressions.AddRange(results.NonAggregatingExpressions); - _nonAggregatingGroupJoins.AddRange(results.NonAggregatingClauses); - _aggregatingGroupJoins.AddRange(results.AggregatingClauses); + _groupJoinAggregateDetectionVisitor.Visit(ordering.Expression); } } } diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs index cb94233e24f..bb2ac628955 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs @@ -20,7 +20,7 @@ internal class GroupJoinAggregateDetectionVisitor : NhExpressionVisitor private readonly List _nonAggregatingGroupJoins = new List(); private readonly List _aggregatingGroupJoins = new List(); - private GroupJoinAggregateDetectionVisitor(IEnumerable groupJoinClause) + internal GroupJoinAggregateDetectionVisitor(IEnumerable groupJoinClause) { _groupJoinClauses = new HashSet(groupJoinClause); } @@ -31,7 +31,7 @@ public static IsAggregatingResults Visit(IEnumerable groupJoinC visitor.Visit(selectExpression); - return new IsAggregatingResults { NonAggregatingClauses = visitor._nonAggregatingGroupJoins, AggregatingClauses = visitor._aggregatingGroupJoins, NonAggregatingExpressions = visitor._nonAggregatingExpressions }; + return GetResults(visitor); } protected override Expression VisitSubQuery(SubQueryExpression expression) @@ -92,6 +92,21 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr return base.VisitQuerySourceReference(expression); } + internal IsAggregatingResults GetResults() + { + return GetResults(this); + } + + private static IsAggregatingResults GetResults(GroupJoinAggregateDetectionVisitor visitor) + { + return new IsAggregatingResults + { + NonAggregatingClauses = visitor._nonAggregatingGroupJoins, + AggregatingClauses = visitor._aggregatingGroupJoins, + NonAggregatingExpressions = visitor._nonAggregatingExpressions + }; + } + internal class StackFlag { public bool FlagIsTrue { get; private set; } From 0a30e7c1cd3746dd65efe959863b871c85c66453 Mon Sep 17 00:00:00 2001 From: maca88 Date: Fri, 28 May 2021 23:40:05 +0200 Subject: [PATCH 5/6] Fix for aggregated queries, #2672 --- .../Async/Linq/ByMethod/JoinTests.cs | 65 +++++++++++++++- .../Linq/ByMethod/JoinTests.cs | 65 +++++++++++++++- .../GroupJoin/AggregatingGroupJoinRewriter.cs | 7 +- .../NonAggregatingGroupJoinRewriter.cs | 78 ++++++++----------- 4 files changed, 164 insertions(+), 51 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs index 71cfde341de..084f6cd3bf7 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs @@ -132,7 +132,27 @@ public async Task LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnlyAsyn } } - [KnownBug("GH-2379")] + [Test] + public async Task LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnlyCountAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = await (db.Orders + .LeftJoin( + db.OrderLines, + x => x, + x => x.Order, + (order, line) => new { order, line }) + + .Select(x => new { x.order.OrderId, x.line.Discount }) + .CountAsync()); + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(2155)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [KnownBug("GH-2739")] public async Task NestedLeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnlyAsync(CancellationToken cancellationToken = default(CancellationToken)) { using (var sqlSpy = new SqlLogSpy()) @@ -177,7 +197,27 @@ public async Task LeftJoinExtensionMethodWithNoUseOfOuterReferenceAsync() var sql = sqlSpy.GetWholeLog(); Assert.That(animals.Count, Is.EqualTo(6)); - Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(5)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(6)); + } + } + + [Test] + public async Task LeftJoinExtensionMethodWithNoUseOfOuterReferenceCountAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = await (db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new {animal, mammal}) + .Select(x => x.animal) + .CountAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); } } @@ -203,6 +243,27 @@ public async Task LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnlyAs } } + [Test] + public async Task LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnlyCountAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = await (db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new {animal, mammal}) + .OrderBy(x => x.mammal.SerialNumber ?? "z") + .Select(x => new {SerialNumber = x.animal.SerialNumber}) + .CountAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [TestCase(false)] [TestCase(true)] public async Task CrossJoinWithPredicateInWhereStatementAsync(bool useCrossJoin) diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index 8e19e1854c9..0c6b3f5834d 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -120,7 +120,27 @@ public void LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly() } } - [KnownBug("GH-2379")] + [Test] + public void LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnlyCount() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = db.Orders + .LeftJoin( + db.OrderLines, + x => x, + x => x.Order, + (order, line) => new { order, line }) + + .Select(x => new { x.order.OrderId, x.line.Discount }) + .Count(); + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(2155)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [KnownBug("GH-2739")] public void NestedLeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly() { using (var sqlSpy = new SqlLogSpy()) @@ -165,7 +185,27 @@ public void LeftJoinExtensionMethodWithNoUseOfOuterReference() var sql = sqlSpy.GetWholeLog(); Assert.That(animals.Count, Is.EqualTo(6)); - Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(5)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(6)); + } + } + + [Test] + public void LeftJoinExtensionMethodWithNoUseOfOuterReferenceCount() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new {animal, mammal}) + .Select(x => x.animal) + .Count(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); } } @@ -191,6 +231,27 @@ public void LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnly() } } + [Test] + public void LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnlyCount() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new {animal, mammal}) + .OrderBy(x => x.mammal.SerialNumber ?? "z") + .Select(x => new {SerialNumber = x.animal.SerialNumber}) + .Count(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [TestCase(false)] [TestCase(true)] public void CrossJoinWithPredicateInWhereStatement(bool useCrossJoin) diff --git a/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs b/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs index 678711730fb..3226f718600 100644 --- a/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using System.Linq; +using NHibernate.Linq.Visitors; using Remotion.Linq; using Remotion.Linq.Clauses; @@ -43,6 +44,8 @@ public static void ReWrite(QueryModel model) if (aggregateDetectorResults.AggregatingClauses.Count > 0) { + NonAggregatingGroupJoinRewriter.RewriteGroupJoins(aggregateDetectorResults.AggregatingClauses, model); + // Re-write the select expression model.SelectClause.TransformExpressions(s => GroupJoinSelectClauseRewriter.ReWrite(s, aggregateDetectorResults)); @@ -56,7 +59,7 @@ public static void ReWrite(QueryModel model) private static IsAggregatingResults IsAggregatingGroupJoin(QueryModel model, IEnumerable clause) { - return GroupJoinAggregateDetectionVisitor.Visit(clause, model.SelectClause.Selector); + return GroupJoinAggregateDetectionQueryModelVisitor.Visit(clause, model); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs index eb633bac558..61281622a69 100644 --- a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs @@ -36,11 +36,9 @@ public static void ReWrite(QueryModel model) new NonAggregatingGroupJoinRewriter(model, clauses).ReWrite(); } - private void ReWrite() + internal static void RewriteGroupJoins(IEnumerable groupJoins, QueryModel model) { - var aggregateDetectorResults = GetGroupJoinInformation(_groupJoinClauses); - - foreach (var nonAggregatingJoin in aggregateDetectorResults.NonAggregatingClauses) + foreach (var groupJoin in groupJoins) { // Group joins get processed (currently) in one of three ways. // Option 1: results of group join are not further referenced outside of the final projection. @@ -67,22 +65,20 @@ private void ReWrite() // select new { o.OrderNumber, x.VendorName, y.StatusName } // This is used to repesent an outer join, and again the "from" is removing the hierarchy. So // simply change the group join to an outer join - - var locator = new QuerySourceUsageLocator(nonAggregatingJoin); - - foreach (var bodyClause in _model.BodyClauses) + var locator = new QuerySourceUsageLocator(groupJoin); + foreach (var bodyClause in model.BodyClauses) { locator.Search(bodyClause); } - if (IsHierarchicalJoin(nonAggregatingJoin, locator)) + if (IsHierarchicalJoin(locator)) { } - else if (IsFlattenedJoin(nonAggregatingJoin, locator)) + else if (IsFlattenedJoin(locator)) { - ProcessFlattenedJoin(nonAggregatingJoin, locator); + ProcessFlattenedJoin(groupJoin, locator, model); } - else if (IsOuterJoin(nonAggregatingJoin)) + else if (IsOuterJoin(locator)) { } else @@ -91,8 +87,17 @@ private void ReWrite() throw new NotSupportedException(); } } + } - // Remove not used group joins + private void ReWrite() + { + var aggregateDetectorResults = GetGroupJoinInformation(_groupJoinClauses); + RewriteGroupJoins(aggregateDetectorResults.NonAggregatingClauses, _model); + RewriteGroupJoins(GetNotUsedGroupJoins(aggregateDetectorResults), _model); + } + + private IEnumerable GetNotUsedGroupJoins(IsAggregatingResults aggregateDetectorResults) + { foreach (var groupJoinClause in _groupJoinClauses) { if (aggregateDetectorResults.NonAggregatingClauses.Contains(groupJoinClause) || aggregateDetectorResults.AggregatingClauses.Contains(groupJoinClause)) @@ -100,18 +105,11 @@ private void ReWrite() continue; } - var locator = new QuerySourceUsageLocator(groupJoinClause); - foreach (var bodyClause in _model.BodyClauses) - { - locator.Search(bodyClause); - } - - _model.BodyClauses.Remove((IBodyClause) locator.Usages[0]); - _model.BodyClauses.Remove(groupJoinClause); + yield return groupJoinClause; } } - private void ProcessFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator) + private static void ProcessFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator, QueryModel model) { var nhJoin = locator.LeftJoin ? new NhOuterJoinClause(nonAggregatingJoin.JoinClause) @@ -121,55 +119,45 @@ private void ProcessFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourc // 1. Remove the group join and replace it with a join // 2. Remove the corresponding "from" clause (the thing that was doing the flattening) // 3. Rewrite the query model to reference the "join" rather than the "from" clause - SwapClause(nonAggregatingJoin, (IBodyClause) nhJoin); + SwapClause(nonAggregatingJoin, (IBodyClause) nhJoin, model); + + model.BodyClauses.Remove((IBodyClause) locator.Usages[0]); - _model.BodyClauses.Remove((IBodyClause) locator.Usages[0]); - SwapQuerySourceVisitor querySourceSwapper; if (locator.LeftJoin) { // As we wrapped the join clause we have to update all references to the wrapped clause querySourceSwapper = new SwapQuerySourceVisitor(nonAggregatingJoin.JoinClause, nhJoin); - _model.TransformExpressions(querySourceSwapper.Swap); + model.TransformExpressions(querySourceSwapper.Swap); } querySourceSwapper = new SwapQuerySourceVisitor(locator.Usages[0], nhJoin); - _model.TransformExpressions(querySourceSwapper.Swap); + model.TransformExpressions(querySourceSwapper.Swap); } // TODO - store the indexes of the join clauses when we find them, then can remove this loop - private void SwapClause(IBodyClause oldClause, IBodyClause newClause) + private static void SwapClause(IBodyClause oldClause, IBodyClause newClause, QueryModel model) { - for (int i = 0; i < _model.BodyClauses.Count; i++) + for (int i = 0; i < model.BodyClauses.Count; i++) { - if (_model.BodyClauses[i] == oldClause) + if (model.BodyClauses[i] == oldClause) { - _model.BodyClauses[i] = newClause; + model.BodyClauses[i] = newClause; } } } - private bool IsOuterJoin(GroupJoinClause nonAggregatingJoin) + private static bool IsOuterJoin(QuerySourceUsageLocator locator) { return false; } - private bool IsFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator) + private static bool IsFlattenedJoin(QuerySourceUsageLocator locator) { - if (locator.Usages.Count == 1) - { - var from = locator.Usages[0] as AdditionalFromClause; - - if (from != null) - { - return true; - } - } - - return false; + return locator.Usages.Count == 1 && locator.Usages[0] is AdditionalFromClause; } - private bool IsHierarchicalJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator) + private static bool IsHierarchicalJoin(QuerySourceUsageLocator locator) { return locator.Usages.Count == 0; } From cf8f0826afa261d98a31aed335f7679d8a59c27e Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 31 May 2021 22:37:34 +0200 Subject: [PATCH 6/6] Remove comment --- .../Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs index bb2ac628955..9dc2a9b4701 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs @@ -36,7 +36,6 @@ public static IsAggregatingResults Visit(IEnumerable groupJoinC protected override Expression VisitSubQuery(SubQueryExpression expression) { - //Visit the entire query model? Visit(expression.QueryModel.SelectClause.Selector); return expression; }