Skip to content

Commit f20a5e3

Browse files
committed
Merge pull request #432 from PleasantD/NH-3797
NH-3797 - Use group by keys as first class HqlCandidates for select clause rewriting
2 parents 7a43e33 + 76d1605 commit f20a5e3

File tree

7 files changed

+64
-7
lines changed

7 files changed

+64
-7
lines changed

src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,27 @@ into grp
528528
Assert.That(result.Count, Is.EqualTo(77));
529529
}
530530

531+
[Test(Description = "NH-3797")]
532+
public void GroupByComputedValue()
533+
{
534+
var orderGroups = db.Orders.GroupBy(o => o.Customer.CustomerId == null ? 0 : 1).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
535+
Assert.AreEqual(830, orderGroups.Sum(g => g.Count));
536+
}
537+
538+
[Test(Description = "NH-3797")]
539+
public void GroupByComputedValueInAnonymousType()
540+
{
541+
var orderGroups = db.Orders.GroupBy(o => new { Key = o.Customer.CustomerId == null ? 0 : 1 }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
542+
Assert.AreEqual(830, orderGroups.Sum(g => g.Count));
543+
}
544+
545+
[Test(Description = "NH-3797")]
546+
public void GroupByComputedValueInObjectArray()
547+
{
548+
var orderGroups = db.Orders.GroupBy(o => new[] { o.Customer.CustomerId == null ? 0 : 1, }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
549+
Assert.AreEqual(830, orderGroups.Sum(g => g.Count));
550+
}
551+
531552
private static void CheckGrouping<TKey, TElement>(IEnumerable<IGrouping<TKey, TElement>> groupedItems, Func<TElement, TKey> groupBy)
532553
{
533554
var used = new HashSet<object>();

src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Linq;
4+
using System.Linq.Expressions;
45
using NHibernate.Linq.Clauses;
56
using NHibernate.Linq.ReWriters;
67
using NHibernate.Linq.Visitors;
8+
using NHibernate.Util;
79
using Remotion.Linq;
810
using Remotion.Linq.Clauses.Expressions;
911
using Remotion.Linq.Clauses.ResultOperators;
@@ -42,7 +44,7 @@ public static class AggregatingGroupByRewriter
4244
typeof (CacheableResultOperator)
4345
};
4446

45-
public static void ReWrite(QueryModel queryModel)
47+
public static void ReWrite(QueryModel queryModel, IList<Expression> groupByKeys)
4648
{
4749
var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression;
4850

@@ -57,6 +59,7 @@ public static void ReWrite(QueryModel queryModel)
5759
var groupBy = operators[0] as GroupResultOperator;
5860
if (groupBy != null)
5961
{
62+
groupBy.ExtractKeyExpressions(groupByKeys);
6063
FlattenSubQuery(queryModel, subQueryExpression.QueryModel, groupBy);
6164
}
6265
}
@@ -91,7 +94,7 @@ private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryMo
9194
queryModel.BodyClauses.Add(bodyClause);
9295

9396
// Replace the outer select clause...
94-
queryModel.SelectClause.TransformExpressions(s =>
97+
queryModel.SelectClause.TransformExpressions(s =>
9598
GroupBySelectClauseRewriter.ReWrite(s, groupBy, subQueryModel));
9699

97100
// Point all query source references to the outer from clause
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Linq.Expressions;
5+
using System.Text;
6+
using NHibernate.Util;
7+
using Remotion.Linq.Clauses.ResultOperators;
8+
9+
namespace NHibernate.Linq
10+
{
11+
internal static class GroupResultOperatorExtensions
12+
{
13+
public static void ExtractKeyExpressions(this GroupResultOperator groupResult, IList<Expression> groupByKeys)
14+
{
15+
if (groupResult.KeySelector is NewExpression)
16+
(groupResult.KeySelector as NewExpression).Arguments.ForEach(groupByKeys.Add);
17+
else if (groupResult.KeySelector is NewArrayExpression)
18+
(groupResult.KeySelector as NewArrayExpression).Expressions.ForEach(groupByKeys.Add);
19+
else
20+
groupByKeys.Add(groupResult.KeySelector);
21+
}
22+
}
23+
}

src/NHibernate/Linq/Visitors/QueryModelVisitor.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System;
2+
using System.Collections;
3+
using System.Collections.Generic;
24
using System.Linq.Expressions;
35
using NHibernate.Hql.Ast;
46
using NHibernate.Linq.Clauses;
@@ -32,7 +34,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
3234
NonAggregatingGroupByRewriter.ReWrite(queryModel);
3335

3436
// Rewrite aggregate group-by statements
35-
AggregatingGroupByRewriter.ReWrite(queryModel);
37+
AggregatingGroupByRewriter.ReWrite(queryModel, parameters.GroupByKeys);
3638

3739
// Rewrite aggregating group-joins
3840
AggregatingGroupJoinRewriter.ReWrite(queryModel);
@@ -77,7 +79,10 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
7779
// Identify and name query sources
7880
QuerySourceIdentifier.Visit(parameters.QuerySourceNamer, queryModel);
7981

80-
var visitor = new QueryModelVisitor(parameters, root, queryModel) { RewrittenOperatorResult = result };
82+
var visitor = new QueryModelVisitor(parameters, root, queryModel)
83+
{
84+
RewrittenOperatorResult = result,
85+
};
8186
visitor.Visit();
8287

8388
return visitor._hqlTree.GetTranslation();
@@ -233,7 +238,7 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que
233238
{
234239
CurrentEvaluationType = selectClause.GetOutputDataInfo();
235240

236-
var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters);
241+
var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters, VisitorParameters.GroupByKeys);
237242

238243
visitor.Visit(selectClause.Selector);
239244

src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ public class SelectClauseVisitor : ExpressionTreeVisitor
1818
private List<HqlExpression> _hqlTreeNodes = new List<HqlExpression>();
1919
private readonly HqlGeneratorExpressionTreeVisitor _hqlVisitor;
2020

21-
public SelectClauseVisitor(System.Type inputType, VisitorParameters parameters)
21+
public SelectClauseVisitor(System.Type inputType, VisitorParameters parameters, IEnumerable<Expression> groupByKeys)
2222
{
2323
_inputParameter = Expression.Parameter(inputType, "input");
2424
_parameters = parameters;
2525
_hqlVisitor = new HqlGeneratorExpressionTreeVisitor(_parameters);
26+
_hqlNodes = new HashSet<Expression>(groupByKeys);
2627
}
2728

2829
public LambdaExpression ProjectionExpression { get; private set; }
@@ -43,7 +44,7 @@ public void Visit(Expression expression)
4344
// Find the sub trees that can be expressed purely in HQL
4445
var nominator = new SelectClauseHqlNominator(_parameters);
4546
nominator.Visit(expression);
46-
_hqlNodes = nominator.HqlCandidates;
47+
_hqlNodes.UnionWith(nominator.HqlCandidates);
4748

4849
// Linq2SQL ignores calls to local methods. Linq2EF seems to not support
4950
// calls to local methods at all. For NHibernate we support local methods,

src/NHibernate/Linq/Visitors/VisitorParameters.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ public class VisitorParameters
1616

1717
public QuerySourceNamer QuerySourceNamer { get; set; }
1818

19+
public IList<Expression> GroupByKeys { get; private set; }
20+
1921
public VisitorParameters(
2022
ISessionFactoryImplementor sessionFactory,
2123
IDictionary<ConstantExpression, NamedParameter> constantToParameterMap,
@@ -26,6 +28,7 @@ public VisitorParameters(
2628
ConstantToParameterMap = constantToParameterMap;
2729
RequiredHqlParameters = requiredHqlParameters;
2830
QuerySourceNamer = querySourceNamer;
31+
GroupByKeys = new List<Expression>();
2932
}
3033
}
3134
}

src/NHibernate/NHibernate.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@
299299
<Compile Include="Linq\Functions\EqualsGenerator.cs" />
300300
<Compile Include="Linq\GroupBy\KeySelectorVisitor.cs" />
301301
<Compile Include="Linq\GroupBy\PagingRewriter.cs" />
302+
<Compile Include="Linq\GroupResultOperatorExtensions.cs" />
302303
<Compile Include="Linq\NestedSelects\NestedSelectDetector.cs" />
303304
<Compile Include="Linq\NestedSelects\Tuple.cs" />
304305
<Compile Include="Linq\NestedSelects\SelectClauseRewriter.cs" />

0 commit comments

Comments
 (0)