diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs index e51ac68f1e6..623df4936f5 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs @@ -8,11 +8,7 @@ //------------------------------------------------------------------------------ -using System; using System.Linq; -using System.Reflection; -using NHibernate.Cfg; -using NHibernate.Engine.Query; using NHibernate.Linq; using NHibernate.Util; using NSubstitute; @@ -21,6 +17,7 @@ namespace NHibernate.Test.Linq.ByMethod { using System.Threading.Tasks; + using System.Threading; [TestFixture] public class JoinTestsAsync : LinqTestCase { @@ -114,6 +111,159 @@ public async Task LeftJoinExtensionMethodWithMultipleKeyPropertiesAsync() } } + [Test] + public async Task LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnlyAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = await (db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Where(x => x.mammal.SerialNumber.StartsWith("9")) + .Select(x => new { SerialNumber = x.animal.SerialNumber }) + .ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [Test] + public async Task LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnlyCountAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = await (db.Orders + .LeftJoin( + db.OrderLines, + x => x, + x => x.Order, + (order, line) => new { order, line }) + + .Select(x => new { x.order.OrderId, x.line.Discount }) + .CountAsync()); + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(2155)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [KnownBug("GH-2739")] + public async Task NestedLeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnlyAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + using (var sqlSpy = new SqlLogSpy()) + { + var innerAnimals = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Where(x => x.mammal.SerialNumber.StartsWith("9")) + .Select(x=>x.animal); + + var animals = await (db.Animals + .LeftJoin( + innerAnimals, + x => x.Id, + x => x.Id, + (animal, animal2) => new { animal, animal2 }) + .Select(x => new { SerialNumber = x.animal2.SerialNumber }) + .ToListAsync(cancellationToken)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [Test] + public async Task LeftJoinExtensionMethodWithNoUseOfOuterReferenceAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = await (db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Select(x => x.animal) + .ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(6)); + } + } + + [Test] + public async Task LeftJoinExtensionMethodWithNoUseOfOuterReferenceCountAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = await (db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new {animal, mammal}) + .Select(x => x.animal) + .CountAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [Test] + public async Task LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnlyAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = await (db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .OrderBy(x => x.mammal.SerialNumber ?? "z") + .Select(x => new { SerialNumber = x.animal.SerialNumber }) + .ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(6)); + Assert.That(animals[0].SerialNumber, Is.EqualTo("1121")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [Test] + public async Task LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnlyCountAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = await (db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new {animal, mammal}) + .OrderBy(x => x.mammal.SerialNumber ?? "z") + .Select(x => new {SerialNumber = x.animal.SerialNumber}) + .CountAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [TestCase(false)] [TestCase(true)] public async Task CrossJoinWithPredicateInWhereStatementAsync(bool useCrossJoin) diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index bba8b52a638..99c8451c9ac 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -1,8 +1,4 @@ -using System; -using System.Linq; -using System.Reflection; -using NHibernate.Cfg; -using NHibernate.Engine.Query; +using System.Linq; using NHibernate.Linq; using NHibernate.Util; using NSubstitute; @@ -103,6 +99,159 @@ public void LeftJoinExtensionMethodWithMultipleKeyProperties() } } + [Test] + public void LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Where(x => x.mammal.SerialNumber.StartsWith("9")) + .Select(x => new { SerialNumber = x.animal.SerialNumber }) + .ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [Test] + public void LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnlyCount() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = db.Orders + .LeftJoin( + db.OrderLines, + x => x, + x => x.Order, + (order, line) => new { order, line }) + + .Select(x => new { x.order.OrderId, x.line.Discount }) + .Count(); + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(2155)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [KnownBug("GH-2739")] + public void NestedLeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly() + { + using (var sqlSpy = new SqlLogSpy()) + { + var innerAnimals = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Where(x => x.mammal.SerialNumber.StartsWith("9")) + .Select(x=>x.animal); + + var animals = db.Animals + .LeftJoin( + innerAnimals, + x => x.Id, + x => x.Id, + (animal, animal2) => new { animal, animal2 }) + .Select(x => new { SerialNumber = x.animal2.SerialNumber }) + .ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [Test] + public void LeftJoinExtensionMethodWithNoUseOfOuterReference() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .Select(x => x.animal) + .ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(6)); + } + } + + [Test] + public void LeftJoinExtensionMethodWithNoUseOfOuterReferenceCount() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new {animal, mammal}) + .Select(x => x.animal) + .Count(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [Test] + public void LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnly() + { + using (var sqlSpy = new SqlLogSpy()) + { + var animals = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new { animal, mammal }) + .OrderBy(x => x.mammal.SerialNumber ?? "z") + .Select(x => new { SerialNumber = x.animal.SerialNumber }) + .ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(animals.Count, Is.EqualTo(6)); + Assert.That(animals[0].SerialNumber, Is.EqualTo("1121")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + + [Test] + public void LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnlyCount() + { + using (var sqlSpy = new SqlLogSpy()) + { + var total = db.Animals + .LeftJoin( + db.Mammals, + x => x.Id, + x => x.Id, + (animal, mammal) => new {animal, mammal}) + .OrderBy(x => x.mammal.SerialNumber ?? "z") + .Select(x => new {SerialNumber = x.animal.SerialNumber}) + .Count(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(total, Is.EqualTo(6)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [TestCase(false)] [TestCase(true)] public void CrossJoinWithPredicateInWhereStatement(bool useCrossJoin) diff --git a/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs b/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs index 678711730fb..3226f718600 100644 --- a/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using System.Linq; +using NHibernate.Linq.Visitors; using Remotion.Linq; using Remotion.Linq.Clauses; @@ -43,6 +44,8 @@ public static void ReWrite(QueryModel model) if (aggregateDetectorResults.AggregatingClauses.Count > 0) { + NonAggregatingGroupJoinRewriter.RewriteGroupJoins(aggregateDetectorResults.AggregatingClauses, model); + // Re-write the select expression model.SelectClause.TransformExpressions(s => GroupJoinSelectClauseRewriter.ReWrite(s, aggregateDetectorResults)); @@ -56,7 +59,7 @@ public static void ReWrite(QueryModel model) private static IsAggregatingResults IsAggregatingGroupJoin(QueryModel model, IEnumerable clause) { - return GroupJoinAggregateDetectionVisitor.Visit(clause, model.SelectClause.Selector); + return GroupJoinAggregateDetectionQueryModelVisitor.Visit(clause, model); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs new file mode 100644 index 00000000000..8a63b3ac810 --- /dev/null +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionQueryModelVisitor.cs @@ -0,0 +1,42 @@ +using System.Collections.Generic; +using System.Linq.Expressions; +using NHibernate.Linq.Visitors; +using Remotion.Linq; +using Remotion.Linq.Clauses; + +namespace NHibernate.Linq.GroupJoin +{ + internal class GroupJoinAggregateDetectionQueryModelVisitor : NhQueryModelVisitorBase + { + private readonly GroupJoinAggregateDetectionVisitor _groupJoinAggregateDetectionVisitor; + + private GroupJoinAggregateDetectionQueryModelVisitor(IEnumerable groupJoinClauses) + { + _groupJoinAggregateDetectionVisitor = new GroupJoinAggregateDetectionVisitor(groupJoinClauses); + } + + public static IsAggregatingResults Visit(IEnumerable groupJoinClause, QueryModel queryModel) + { + var visitor = new GroupJoinAggregateDetectionQueryModelVisitor(groupJoinClause); + + visitor.VisitQueryModel(queryModel); + + return visitor._groupJoinAggregateDetectionVisitor.GetResults(); + } + + public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) + { + _groupJoinAggregateDetectionVisitor.Visit(whereClause.Predicate); + } + + public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) + { + _groupJoinAggregateDetectionVisitor.Visit(selectClause.Selector); + } + + public override void VisitOrdering(Ordering ordering, QueryModel queryModel, OrderByClause orderByClause, int index) + { + _groupJoinAggregateDetectionVisitor.Visit(ordering.Expression); + } + } +} diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs index 20f2331bf9c..9dc2a9b4701 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs @@ -1,8 +1,10 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq.Expressions; using NHibernate.Linq.Expressions; using NHibernate.Linq.Visitors; +using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; @@ -18,7 +20,7 @@ internal class GroupJoinAggregateDetectionVisitor : NhExpressionVisitor private readonly List _nonAggregatingGroupJoins = new List(); private readonly List _aggregatingGroupJoins = new List(); - private GroupJoinAggregateDetectionVisitor(IEnumerable groupJoinClause) + internal GroupJoinAggregateDetectionVisitor(IEnumerable groupJoinClause) { _groupJoinClauses = new HashSet(groupJoinClause); } @@ -29,7 +31,7 @@ public static IsAggregatingResults Visit(IEnumerable groupJoinC visitor.Visit(selectExpression); - return new IsAggregatingResults { NonAggregatingClauses = visitor._nonAggregatingGroupJoins, AggregatingClauses = visitor._aggregatingGroupJoins, NonAggregatingExpressions = visitor._nonAggregatingExpressions }; + return GetResults(visitor); } protected override Expression VisitSubQuery(SubQueryExpression expression) @@ -89,6 +91,21 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr return base.VisitQuerySourceReference(expression); } + internal IsAggregatingResults GetResults() + { + return GetResults(this); + } + + private static IsAggregatingResults GetResults(GroupJoinAggregateDetectionVisitor visitor) + { + return new IsAggregatingResults + { + NonAggregatingClauses = visitor._nonAggregatingGroupJoins, + AggregatingClauses = visitor._aggregatingGroupJoins, + NonAggregatingExpressions = visitor._nonAggregatingExpressions + }; + } + internal class StackFlag { public bool FlagIsTrue { get; private set; } diff --git a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs index d56a9bedff4..61281622a69 100644 --- a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs @@ -36,11 +36,9 @@ public static void ReWrite(QueryModel model) new NonAggregatingGroupJoinRewriter(model, clauses).ReWrite(); } - private void ReWrite() + internal static void RewriteGroupJoins(IEnumerable groupJoins, QueryModel model) { - var aggregateDetectorResults = GetGroupJoinInformation(_groupJoinClauses); - - foreach (var nonAggregatingJoin in aggregateDetectorResults.NonAggregatingClauses) + foreach (var groupJoin in groupJoins) { // Group joins get processed (currently) in one of three ways. // Option 1: results of group join are not further referenced outside of the final projection. @@ -67,22 +65,20 @@ private void ReWrite() // select new { o.OrderNumber, x.VendorName, y.StatusName } // This is used to repesent an outer join, and again the "from" is removing the hierarchy. So // simply change the group join to an outer join - - var locator = new QuerySourceUsageLocator(nonAggregatingJoin); - - foreach (var bodyClause in _model.BodyClauses) + var locator = new QuerySourceUsageLocator(groupJoin); + foreach (var bodyClause in model.BodyClauses) { locator.Search(bodyClause); } - if (IsHierarchicalJoin(nonAggregatingJoin, locator)) + if (IsHierarchicalJoin(locator)) { } - else if (IsFlattenedJoin(nonAggregatingJoin, locator)) + else if (IsFlattenedJoin(locator)) { - ProcessFlattenedJoin(nonAggregatingJoin, locator); + ProcessFlattenedJoin(groupJoin, locator, model); } - else if (IsOuterJoin(nonAggregatingJoin)) + else if (IsOuterJoin(locator)) { } else @@ -93,7 +89,27 @@ private void ReWrite() } } - private void ProcessFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator) + private void ReWrite() + { + var aggregateDetectorResults = GetGroupJoinInformation(_groupJoinClauses); + RewriteGroupJoins(aggregateDetectorResults.NonAggregatingClauses, _model); + RewriteGroupJoins(GetNotUsedGroupJoins(aggregateDetectorResults), _model); + } + + private IEnumerable GetNotUsedGroupJoins(IsAggregatingResults aggregateDetectorResults) + { + foreach (var groupJoinClause in _groupJoinClauses) + { + if (aggregateDetectorResults.NonAggregatingClauses.Contains(groupJoinClause) || aggregateDetectorResults.AggregatingClauses.Contains(groupJoinClause)) + { + continue; + } + + yield return groupJoinClause; + } + } + + private static void ProcessFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator, QueryModel model) { var nhJoin = locator.LeftJoin ? new NhOuterJoinClause(nonAggregatingJoin.JoinClause) @@ -103,55 +119,45 @@ private void ProcessFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourc // 1. Remove the group join and replace it with a join // 2. Remove the corresponding "from" clause (the thing that was doing the flattening) // 3. Rewrite the query model to reference the "join" rather than the "from" clause - SwapClause(nonAggregatingJoin, (IBodyClause) nhJoin); + SwapClause(nonAggregatingJoin, (IBodyClause) nhJoin, model); + + model.BodyClauses.Remove((IBodyClause) locator.Usages[0]); - _model.BodyClauses.Remove((IBodyClause) locator.Usages[0]); - SwapQuerySourceVisitor querySourceSwapper; if (locator.LeftJoin) { // As we wrapped the join clause we have to update all references to the wrapped clause querySourceSwapper = new SwapQuerySourceVisitor(nonAggregatingJoin.JoinClause, nhJoin); - _model.TransformExpressions(querySourceSwapper.Swap); + model.TransformExpressions(querySourceSwapper.Swap); } querySourceSwapper = new SwapQuerySourceVisitor(locator.Usages[0], nhJoin); - _model.TransformExpressions(querySourceSwapper.Swap); + model.TransformExpressions(querySourceSwapper.Swap); } // TODO - store the indexes of the join clauses when we find them, then can remove this loop - private void SwapClause(IBodyClause oldClause, IBodyClause newClause) + private static void SwapClause(IBodyClause oldClause, IBodyClause newClause, QueryModel model) { - for (int i = 0; i < _model.BodyClauses.Count; i++) + for (int i = 0; i < model.BodyClauses.Count; i++) { - if (_model.BodyClauses[i] == oldClause) + if (model.BodyClauses[i] == oldClause) { - _model.BodyClauses[i] = newClause; + model.BodyClauses[i] = newClause; } } } - private bool IsOuterJoin(GroupJoinClause nonAggregatingJoin) + private static bool IsOuterJoin(QuerySourceUsageLocator locator) { return false; } - private bool IsFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator) + private static bool IsFlattenedJoin(QuerySourceUsageLocator locator) { - if (locator.Usages.Count == 1) - { - var from = locator.Usages[0] as AdditionalFromClause; - - if (from != null) - { - return true; - } - } - - return false; + return locator.Usages.Count == 1 && locator.Usages[0] is AdditionalFromClause; } - private bool IsHierarchicalJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator) + private static bool IsHierarchicalJoin(QuerySourceUsageLocator locator) { return locator.Usages.Count == 0; } @@ -159,7 +165,7 @@ private bool IsHierarchicalJoin(GroupJoinClause nonAggregatingJoin, QuerySourceU // TODO - rename this and share with the AggregatingGroupJoinRewriter private IsAggregatingResults GetGroupJoinInformation(IEnumerable clause) { - return GroupJoinAggregateDetectionVisitor.Visit(clause, _model.SelectClause.Selector); + return GroupJoinAggregateDetectionQueryModelVisitor.Visit(clause, _model); } }