From 727797a140b2f7552f77cbde46d2e2ed55ac883f Mon Sep 17 00:00:00 2001 From: Duncan M Date: Wed, 17 Jun 2015 14:07:07 -0600 Subject: [PATCH] NH-3800, NH-3681 - Fixed issue relating to group by and left outer joins - Added issue specific tests - Modifed the GroupBySelectClauseRewriter to dig a bit deeper to match the group result ElementSelector - Modifed the GroupBySelectClauseRewriter to allow member access of convert expressions - Added a step to the QueryModelVisitor to flatten array index expressions with constant indexers to the inner expression --- .../Linq/ByMethod/GroupByTests.cs | 18 +- .../NHSpecificTest/NH3800/Domain.cs | 48 +++++ .../NHSpecificTest/NH3800/Fixture.cs | 193 ++++++++++++++++++ .../NHSpecificTest/NH3800/Mappings.hbm.xml | 58 ++++++ src/NHibernate.Test/NHibernate.Test.csproj | 5 + src/NHibernate/Linq/ExpressionExtensions.cs | 23 ++- .../GroupBy/GroupBySelectClauseRewriter.cs | 17 +- .../ArrayIndexExpressionFlattener.cs | 40 ++++ .../Linq/Visitors/QueryModelVisitor.cs | 3 + src/NHibernate/NHibernate.csproj | 1 + 10 files changed, 397 insertions(+), 9 deletions(-) create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3800/Domain.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3800/Fixture.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3800/Mappings.hbm.xml create mode 100644 src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs diff --git a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs index 8a1c5adc8b3..064366d7bc6 100644 --- a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs @@ -507,14 +507,22 @@ public void ProjectingWithSubQueriesFilteredByTheAggregateKey() Assert.That(result[15].FirstOrder, Is.EqualTo(10255)); } - [Test(Description = "NH-3681"), KnownBug("NH-3681 not yet fixed", "NHibernate.HibernateException")] + [Test(Description = "NH-3681")] public void SelectManyGroupByAggregateProjection() { var result = (from o in db.Orders - from ol in o.OrderLines - group ol by ol.Product.ProductId - into grp - select new {ProductId = grp.Key, Sum = grp.Sum(x => x.UnitPrice)} + from ol in o.OrderLines + group ol by ol.Product.ProductId + into grp + select new + { + ProductId = grp.Key, + Sum = grp.Sum(x => x.UnitPrice), + Count = grp.Count(), + Avg = grp.Average(x => x.UnitPrice), + Min = grp.Min(x => x.UnitPrice), + Max = grp.Max(x => x.UnitPrice), + } ).ToList(); Assert.That(result.Count, Is.EqualTo(77)); diff --git a/src/NHibernate.Test/NHSpecificTest/NH3800/Domain.cs b/src/NHibernate.Test/NHSpecificTest/NH3800/Domain.cs new file mode 100644 index 00000000000..b94e1481bf1 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3800/Domain.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace NHibernate.Test.NHSpecificTest.NH3800 +{ + public class Project + { + public Project() + { + Components = new List(); + } + + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + public virtual IList Components { get; set; } + } + + public class Component + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + public virtual Project Project { get; set; } + } + + public class TimeRecord + { + public TimeRecord() + { + Components = new List(); + Tags = new List(); + } + + public virtual Guid Id { get; set; } + public virtual double TimeInHours { get; set; } + public virtual Project Project { get; set; } + public virtual IList Components { get; set; } + public virtual IList Tags { get; set; } + + } + + public class Tag + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3800/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/NH3800/Fixture.cs new file mode 100644 index 00000000000..e7f06f684da --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3800/Fixture.cs @@ -0,0 +1,193 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using System.Threading; +using NHibernate.Linq; +using NHibernate.Test.ExceptionsTest; +using NHibernate.Test.MappingByCode; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.NH3800 +{ + [TestFixture] + public class Fixture : BugTestCase + { + protected override void OnSetUp() + { + var tagA = new Tag() { Name = "A" }; + var tagB = new Tag() { Name = "B" }; + + var project1 = new Project { Name = "ProjectOne" }; + var compP1_x = new Component() { Name = "PONEx", Project = project1 }; + var compP1_y = new Component() { Name = "PONEy", Project = project1 }; + + var project2 = new Project { Name = "ProjectTwo" }; + var compP2_x = new Component() { Name = "PTWOx", Project = project2 }; + var compP2_y = new Component() { Name = "PTWOy", Project = project2 }; + + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Save(tagA); + session.Save(tagB); + session.Save(project1); + session.Save(compP1_x); + session.Save(compP1_y); + session.Save(project2); + session.Save(compP2_x); + session.Save(compP2_y); + + session.Save(new TimeRecord { TimeInHours = 1, Project = null, Components = { }, Tags = { tagA } }); + session.Save(new TimeRecord { TimeInHours = 2, Project = null, Components = { }, Tags = { tagB } }); + + session.Save(new TimeRecord { TimeInHours = 3, Project = project1, Tags = { tagA, tagB } }); + session.Save(new TimeRecord { TimeInHours = 4, Project = project1, Components = { compP1_x }, Tags = { tagB } }); + session.Save(new TimeRecord { TimeInHours = 5, Project = project1, Components = { compP1_y }, Tags = { tagA } }); + session.Save(new TimeRecord { TimeInHours = 6, Project = project1, Components = { compP1_x, compP1_y }, Tags = { } }); + + session.Save(new TimeRecord { TimeInHours = 7, Project = project2, Components = { }, Tags = { tagA, tagB } }); + session.Save(new TimeRecord { TimeInHours = 8, Project = project2, Components = { compP2_x }, Tags = { tagB } }); + session.Save(new TimeRecord { TimeInHours = 9, Project = project2, Components = { compP2_y }, Tags = { tagA } }); + session.Save(new TimeRecord { TimeInHours = 10, Project = project2, Components = { compP2_x, compP2_y }, Tags = { } }); + + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Delete("from TimeRecord"); + session.Delete("from Component"); + session.Delete("from Project"); + session.Delete("from Tag"); + + transaction.Commit(); + } + } + + [Test] + public void ExpectedHql() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var baseQuery = session.Query(); + + Assert.That(baseQuery.Sum(x => x.TimeInHours), Is.EqualTo(55)); + + var query = session.CreateQuery(@" + select c.Id, count(t), sum(cast(t.TimeInHours as big_decimal)) + from TimeRecord t + left join t.Components as c + group by c.Id"); + + var results = query.List(); + Assert.That(results.Select(x => x[1]), Is.EquivalentTo(new[] { 4, 2, 2, 2, 2 })); + Assert.That(results.Select(x => x[2]), Is.EquivalentTo(new[] { 13, 10, 11, 18, 19 })); + + Assert.That(results.Sum(x => (decimal?)x[2]), Is.EqualTo(71)); + + transaction.Rollback(); + } + } + + [Test] + public void PureLinq() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var baseQuery = session.Query(); + var query = from t in baseQuery + from c in t.Components.Select(x => (object)x.Id).DefaultIfEmpty() + let r = new object[] { c, t } + group r by r[0] + into g + select new[] { g.Key, g.Select(x => x[1]).Count(), g.Select(x => x[1]).Sum(x => (decimal?)((TimeRecord)x).TimeInHours) }; + + var results = query.ToList(); + Assert.That(results.Select(x => x[1]), Is.EquivalentTo(new[] { 4, 2, 2, 2, 2 })); + Assert.That(results.Select(x => x[2]), Is.EquivalentTo(new[] { 13, 10, 11, 18, 19 })); + + Assert.That(results.Sum(x => (decimal?)x[2]), Is.EqualTo(71)); + + transaction.Rollback(); + } + } + + [Test] + public void MethodGroup() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var baseQuery = session.Query(); + var query = baseQuery + .SelectMany(t => t.Components.Select(c => c.Id).DefaultIfEmpty().Select(c => new object[] { c, t })) + .GroupBy(g => g[0], g => (TimeRecord)g[1]) + .Select(g => new[] { g.Key, g.Count(), g.Sum(x => (decimal?)x.TimeInHours) }); + + var results = query.ToList(); + Assert.That(results.Select(x => x[1]), Is.EquivalentTo(new[] { 4, 2, 2, 2, 2 })); + Assert.That(results.Select(x => x[2]), Is.EquivalentTo(new[] { 13, 10, 11, 18, 19 })); + + Assert.That(results.Sum(x => (decimal?)x[2]), Is.EqualTo(71)); + + transaction.Rollback(); + } + } + + [Test] + public void ComplexExample() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var baseQuery = session.Query(); + + Assert.That(baseQuery.Sum(x => x.TimeInHours), Is.EqualTo(55)); + + var query = baseQuery.Select(t => new object[] { t }) + .SelectMany(t => ((TimeRecord)t[0]).Components.Select(c => (object)c.Id).DefaultIfEmpty().Select(c => new[] { t[0], c })) + .SelectMany(t => ((TimeRecord)t[0]).Tags.Select(x => (object)x.Id).DefaultIfEmpty().Select(x => new[] { t[0], t[1], x })) + .GroupBy(j => new[] { ((TimeRecord)j[0]).Project.Id, j[1], j[2] }, j => (TimeRecord)j[0]) + .Select(g => new object[] { g.Key, g.Count(), g.Sum(t => (decimal?)t.TimeInHours) }); + + var results = query.ToList(); + Assert.That(results.Select(x => x[1]), Is.EquivalentTo(new[] { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 })); + Assert.That(results.Select(x => x[2]), Is.EquivalentTo(new[] { 1, 2, 3, 3, 4, 5, 6, 6, 7, 7, 8, 9, 10, 10 })); + + Assert.That(results.Sum(x => (decimal?)x[2]), Is.EqualTo(81)); + + transaction.Rollback(); + } + } + + [Test] + public void OuterJoinGroupingWithSubQueryInProjection() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var baseQuery = session.Query(); + var query = baseQuery + .SelectMany(t => t.Components.Select(c => c.Name).DefaultIfEmpty().Select(c => new object[] { c, t })) + .GroupBy(g => g[0], g => (TimeRecord)g[1]) + .Select(g => new[] { g.Key, g.Count(), session.Query().Count(c => c.Name == (string)g.Key) }); + + var results = query.ToList(); + Assert.That(results.Select(x => x[1]), Is.EquivalentTo(new[] { 4, 2, 2, 2, 2 })); + Assert.That(results.Select(x => x[2]), Is.EquivalentTo(new[] { 0, 1, 1, 1, 1 })); + + transaction.Rollback(); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3800/Mappings.hbm.xml b/src/NHibernate.Test/NHSpecificTest/NH3800/Mappings.hbm.xml new file mode 100644 index 00000000000..d78da66a266 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3800/Mappings.hbm.xml @@ -0,0 +1,58 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index 3bab554137d..5e6331670a7 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -1263,6 +1263,8 @@ + + @@ -3147,6 +3149,9 @@ + + Designer + diff --git a/src/NHibernate/Linq/ExpressionExtensions.cs b/src/NHibernate/Linq/ExpressionExtensions.cs index d72296126fd..8c84ccfe302 100644 --- a/src/NHibernate/Linq/ExpressionExtensions.cs +++ b/src/NHibernate/Linq/ExpressionExtensions.cs @@ -1,4 +1,5 @@ -using System.Linq; +using System; +using System.Linq; using System.Linq.Expressions; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; @@ -32,5 +33,25 @@ public static bool IsGroupingKeyOf(this MemberExpression expression,GroupResultO return query.QueryModel.ResultOperators.Contains(groupBy); } + + public static bool IsGroupingElementOf(this QuerySourceReferenceExpression expression, GroupResultOperator groupBy) + { + var fromClause = expression.ReferencedQuerySource as MainFromClause; + if (fromClause == null) return false; + + var innerQuerySource = fromClause.FromExpression as QuerySourceReferenceExpression; + if (innerQuerySource == null) return false; + + if (innerQuerySource.ReferencedQuerySource.ItemName != groupBy.ItemName + || innerQuerySource.ReferencedQuerySource.ItemType != groupBy.ItemType) return false; + + var innerFromClause = innerQuerySource.ReferencedQuerySource as MainFromClause; + if (innerFromClause == null) return false; + + var query = innerFromClause.FromExpression as SubQueryExpression; + if (query == null) return false; + + return query.QueryModel.ResultOperators.Contains(groupBy); + } } } diff --git a/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs b/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs index 0af649b4def..eb2a633fe70 100644 --- a/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs @@ -31,7 +31,12 @@ private GroupBySelectClauseRewriter(GroupResultOperator groupBy, QueryModel mode protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) { - if (expression.ReferencedQuerySource == _groupBy) + if (!IsMemberOfModel(expression)) + { + return base.VisitQuerySourceReferenceExpression(expression); + } + + if (expression.IsGroupingElementOf(_groupBy)) { return _groupBy.ElementSelector; } @@ -59,7 +64,8 @@ protected override Expression VisitMemberExpression(MemberExpression expression) return base.VisitMemberExpression(expression); } - if (elementSelector is NewExpression && elementSelector.Type == expression.Expression.Type) + if ((elementSelector is NewExpression || elementSelector.NodeType == ExpressionType.Convert) + && elementSelector.Type == expression.Expression.Type) { //TODO: probably we should check this with a visitor return Expression.MakeMemberAccess(elementSelector, expression.Member); @@ -78,7 +84,12 @@ private bool IsMemberOfModel(MemberExpression expression) return false; } - var fromClause = querySourceRef.ReferencedQuerySource as FromClauseBase; + return IsMemberOfModel(querySourceRef); + } + + private bool IsMemberOfModel(QuerySourceReferenceExpression expression) + { + var fromClause = expression.ReferencedQuerySource as FromClauseBase; if (fromClause == null) { diff --git a/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs b/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs new file mode 100644 index 00000000000..0b739b078e3 --- /dev/null +++ b/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs @@ -0,0 +1,40 @@ +using System.Linq.Expressions; +using Remotion.Linq; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.ReWriters +{ + public class ArrayIndexExpressionFlattener : ExpressionTreeVisitor + { + public static void ReWrite(QueryModel model) + { + var visitor = new ArrayIndexExpressionFlattener(); + model.TransformExpressions(visitor.VisitExpression); + } + + protected override Expression VisitBinaryExpression(BinaryExpression expression) + { + var visitedExpression = base.VisitBinaryExpression(expression); + + if (visitedExpression.NodeType != ExpressionType.ArrayIndex) + return visitedExpression; + + var index = expression.Right as ConstantExpression; + if (index == null) + return visitedExpression; + + var expressionList = expression.Left as NewArrayExpression; + if (expressionList == null || expressionList.NodeType != ExpressionType.NewArrayInit) + return visitedExpression; + + return VisitExpression(expressionList.Expressions[(int)index.Value]); + } + + protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + { + ReWrite(expression.QueryModel); + return expression; // Note that we modifiy the (mutable) QueryModel, we return an unchanged expression + } + } +} \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 9ac7c237acc..63867610246 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -51,6 +51,9 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer // Flatten pointless subqueries QueryReferenceExpressionFlattener.ReWrite(queryModel); + // Flatten array index access to query references + ArrayIndexExpressionFlattener.ReWrite(queryModel); + // Add joins for references AddJoinsReWriter.ReWrite(queryModel, parameters.SessionFactory); diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index 3d83baf167f..e1905c7ca6c 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -303,6 +303,7 @@ +