From 76d1605351e20272678d99d11d35d2096f205f59 Mon Sep 17 00:00:00 2001 From: Duncan M Date: Mon, 8 Jun 2015 12:06:46 -0600 Subject: [PATCH] NH-3797 - Use group by keys as first class HqlCandidates for select clause rewriting - Added test cases - Added the GroupByKeys expression list into the VisitorParameters class - Created an extension method to simplify the extraction of group by key expressions --- .../Linq/ByMethod/GroupByTests.cs | 21 +++++++++++++++++ .../GroupBy/AggregatingGroupByRewriter.cs | 7 ++++-- .../Linq/GroupResultOperatorExtensions.cs | 23 +++++++++++++++++++ .../Linq/Visitors/QueryModelVisitor.cs | 11 ++++++--- .../Linq/Visitors/SelectClauseVisitor.cs | 5 ++-- .../Linq/Visitors/VisitorParameters.cs | 3 +++ src/NHibernate/NHibernate.csproj | 1 + 7 files changed, 64 insertions(+), 7 deletions(-) create mode 100644 src/NHibernate/Linq/GroupResultOperatorExtensions.cs diff --git a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs index 8a1c5adc8b3..35c63590533 100644 --- a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs @@ -520,6 +520,27 @@ into grp Assert.That(result.Count, Is.EqualTo(77)); } + [Test(Description = "NH-3797")] + public void GroupByComputedValue() + { + var orderGroups = db.Orders.GroupBy(o => o.Customer.CustomerId == null ? 0 : 1).Select(g => new { Key = g.Key, Count = g.Count() }).ToList(); + Assert.AreEqual(830, orderGroups.Sum(g => g.Count)); + } + + [Test(Description = "NH-3797")] + public void GroupByComputedValueInAnonymousType() + { + var orderGroups = db.Orders.GroupBy(o => new { Key = o.Customer.CustomerId == null ? 0 : 1 }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList(); + Assert.AreEqual(830, orderGroups.Sum(g => g.Count)); + } + + [Test(Description = "NH-3797")] + public void GroupByComputedValueInObjectArray() + { + var orderGroups = db.Orders.GroupBy(o => new[] { o.Customer.CustomerId == null ? 0 : 1, }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList(); + Assert.AreEqual(830, orderGroups.Sum(g => g.Count)); + } + private static void CheckGrouping(IEnumerable> groupedItems, Func groupBy) { var used = new HashSet(); diff --git a/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs b/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs index 11bf22bfb0d..795edf52dc5 100644 --- a/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs @@ -1,9 +1,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using NHibernate.Linq.Clauses; using NHibernate.Linq.ReWriters; using NHibernate.Linq.Visitors; +using NHibernate.Util; using Remotion.Linq; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; @@ -42,7 +44,7 @@ public static class AggregatingGroupByRewriter typeof (CacheableResultOperator) }; - public static void ReWrite(QueryModel queryModel) + public static void ReWrite(QueryModel queryModel, IList groupByKeys) { var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression; @@ -57,6 +59,7 @@ public static void ReWrite(QueryModel queryModel) var groupBy = operators[0] as GroupResultOperator; if (groupBy != null) { + groupBy.ExtractKeyExpressions(groupByKeys); FlattenSubQuery(queryModel, subQueryExpression.QueryModel, groupBy); } } @@ -91,7 +94,7 @@ private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryMo queryModel.BodyClauses.Add(bodyClause); // Replace the outer select clause... - queryModel.SelectClause.TransformExpressions(s => + queryModel.SelectClause.TransformExpressions(s => GroupBySelectClauseRewriter.ReWrite(s, groupBy, subQueryModel)); // Point all query source references to the outer from clause diff --git a/src/NHibernate/Linq/GroupResultOperatorExtensions.cs b/src/NHibernate/Linq/GroupResultOperatorExtensions.cs new file mode 100644 index 00000000000..2a6ee7d42c3 --- /dev/null +++ b/src/NHibernate/Linq/GroupResultOperatorExtensions.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Text; +using NHibernate.Util; +using Remotion.Linq.Clauses.ResultOperators; + +namespace NHibernate.Linq +{ + internal static class GroupResultOperatorExtensions + { + public static void ExtractKeyExpressions(this GroupResultOperator groupResult, IList groupByKeys) + { + if (groupResult.KeySelector is NewExpression) + (groupResult.KeySelector as NewExpression).Arguments.ForEach(groupByKeys.Add); + else if (groupResult.KeySelector is NewArrayExpression) + (groupResult.KeySelector as NewArrayExpression).Expressions.ForEach(groupByKeys.Add); + else + groupByKeys.Add(groupResult.KeySelector); + } + } +} diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 9ac7c237acc..89fb4ae6ffb 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -1,4 +1,6 @@ using System; +using System.Collections; +using System.Collections.Generic; using System.Linq.Expressions; using NHibernate.Hql.Ast; using NHibernate.Linq.Clauses; @@ -32,7 +34,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer NonAggregatingGroupByRewriter.ReWrite(queryModel); // Rewrite aggregate group-by statements - AggregatingGroupByRewriter.ReWrite(queryModel); + AggregatingGroupByRewriter.ReWrite(queryModel, parameters.GroupByKeys); // Rewrite aggregating group-joins AggregatingGroupJoinRewriter.ReWrite(queryModel); @@ -74,7 +76,10 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer // Identify and name query sources QuerySourceIdentifier.Visit(parameters.QuerySourceNamer, queryModel); - var visitor = new QueryModelVisitor(parameters, root, queryModel) { RewrittenOperatorResult = result }; + var visitor = new QueryModelVisitor(parameters, root, queryModel) + { + RewrittenOperatorResult = result, + }; visitor.Visit(); return visitor._hqlTree.GetTranslation(); @@ -230,7 +235,7 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que { CurrentEvaluationType = selectClause.GetOutputDataInfo(); - var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters); + var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters, VisitorParameters.GroupByKeys); visitor.Visit(selectClause.Selector); diff --git a/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs b/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs index 2692fb7ed43..73d2eaedf15 100644 --- a/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs +++ b/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs @@ -18,11 +18,12 @@ public class SelectClauseVisitor : ExpressionTreeVisitor private List _hqlTreeNodes = new List(); private readonly HqlGeneratorExpressionTreeVisitor _hqlVisitor; - public SelectClauseVisitor(System.Type inputType, VisitorParameters parameters) + public SelectClauseVisitor(System.Type inputType, VisitorParameters parameters, IEnumerable groupByKeys) { _inputParameter = Expression.Parameter(inputType, "input"); _parameters = parameters; _hqlVisitor = new HqlGeneratorExpressionTreeVisitor(_parameters); + _hqlNodes = new HashSet(groupByKeys); } public LambdaExpression ProjectionExpression { get; private set; } @@ -43,7 +44,7 @@ public void Visit(Expression expression) // Find the sub trees that can be expressed purely in HQL var nominator = new SelectClauseHqlNominator(_parameters); nominator.Visit(expression); - _hqlNodes = nominator.HqlCandidates; + _hqlNodes.UnionWith(nominator.HqlCandidates); // Linq2SQL ignores calls to local methods. Linq2EF seems to not support // calls to local methods at all. For NHibernate we support local methods, diff --git a/src/NHibernate/Linq/Visitors/VisitorParameters.cs b/src/NHibernate/Linq/Visitors/VisitorParameters.cs index 27ef5de7029..03aba50ea8d 100644 --- a/src/NHibernate/Linq/Visitors/VisitorParameters.cs +++ b/src/NHibernate/Linq/Visitors/VisitorParameters.cs @@ -16,6 +16,8 @@ public class VisitorParameters public QuerySourceNamer QuerySourceNamer { get; set; } + public IList GroupByKeys { get; private set; } + public VisitorParameters( ISessionFactoryImplementor sessionFactory, IDictionary constantToParameterMap, @@ -26,6 +28,7 @@ public VisitorParameters( ConstantToParameterMap = constantToParameterMap; RequiredHqlParameters = requiredHqlParameters; QuerySourceNamer = querySourceNamer; + GroupByKeys = new List(); } } } \ No newline at end of file diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index 3d83baf167f..fb455561390 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -299,6 +299,7 @@ +