Skip to content

Commit 0465dc4

Browse files
committed
Merge branch 'NH-3474'
2 parents 99e0bfc + 7740493 commit 0465dc4

File tree

5 files changed

+189
-4
lines changed

5 files changed

+189
-4
lines changed

src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,123 @@ public void GroupByComputedValueInObjectArray()
549549
Assert.AreEqual(830, orderGroups.Sum(g => g.Count));
550550
}
551551

552+
[Test(Description = "NH-3474")]
553+
public void GroupByConstant()
554+
{
555+
var totals = db.Orders.GroupBy(o => 1).Select(g => new { Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight) }).ToList();
556+
Assert.That(totals.Count, Is.EqualTo(1));
557+
Assert.That(totals, Has.All.With.Property("Key").EqualTo(1));
558+
}
559+
560+
[Test(Description = "NH-3474")]
561+
public void GroupByConstantAnonymousType()
562+
{
563+
var totals = db.Orders.GroupBy(o => new { A = 1 }).Select(g => new { Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight) }).ToList();
564+
Assert.That(totals.Count, Is.EqualTo(1));
565+
Assert.That(totals, Has.All.With.Property("Key").With.Property("A").EqualTo(1));
566+
}
567+
568+
[Test(Description = "NH-3474")]
569+
public void GroupByConstantArray()
570+
{
571+
var totals = db.Orders.GroupBy(o => new object[] { 1 }).Select(g => new { Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight) }).ToList();
572+
Assert.That(totals.Count, Is.EqualTo(1));
573+
Assert.That(totals, Has.All.With.Property("Key").EqualTo(new object[] { 1 }));
574+
}
575+
576+
[Test(Description = "NH-3474")]
577+
public void GroupByKeyWithConstantInAnonymousType()
578+
{
579+
var totals = db.Orders.GroupBy(o => new { A = 1, B = o.Shipper.ShipperId }).Select(g => new { Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight) }).ToList();
580+
Assert.That(totals.Count, Is.EqualTo(3));
581+
Assert.That(totals, Has.All.With.Property("Key").With.Property("A").EqualTo(1));
582+
}
583+
584+
[Test(Description = "NH-3474")]
585+
public void GroupByKeyWithConstantInArray()
586+
{
587+
var totals = db.Orders.GroupBy(o => new[] { 1, o.Shipper.ShipperId }).Select(g => new { Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight) }).ToList();
588+
Assert.That(totals.Count, Is.EqualTo(3));
589+
Assert.That(totals, Has.All.With.Property("Key").Contains(1));
590+
}
591+
592+
private int constKey;
593+
[Test(Description = "NH-3474")]
594+
public void GroupByKeyWithConstantFromVariable()
595+
{
596+
constKey = 1;
597+
var q1 = db.Orders.GroupBy(o => constKey).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)});
598+
var q1a = db.Orders.GroupBy(o => "").Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)});
599+
var q2 = db.Orders.GroupBy(o => new {A = constKey}).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)});
600+
var q3 = db.Orders.GroupBy(o => new object[] {constKey}).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)});
601+
var q3a = db.Orders.GroupBy(o => (IEnumerable<object>) new object[] {constKey}).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)});
602+
var q4 = db.Orders.GroupBy(o => new {A = constKey, B = o.Shipper.ShipperId}).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)});
603+
var q5 = db.Orders.GroupBy(o => new[] {constKey, o.Shipper.ShipperId}).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)});
604+
var q5a = db.Orders.GroupBy(o => (IEnumerable<int>) new[] {constKey, o.Shipper.ShipperId}).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)});
605+
606+
var r1_1 = q1.ToList();
607+
Assert.That(r1_1.Count, Is.EqualTo(1));
608+
Assert.That(r1_1, Has.All.With.Property("Key").EqualTo(1));
609+
610+
var r1a_1 = q1a.ToList();
611+
Assert.That(r1a_1.Count, Is.EqualTo(1));
612+
Assert.That(r1a_1, Has.All.With.Property("Key").EqualTo(""));
613+
614+
var r2_1 = q2.ToList();
615+
Assert.That(r2_1.Count, Is.EqualTo(1));
616+
Assert.That(r2_1, Has.All.With.Property("Key").With.Property("A").EqualTo(1));
617+
618+
var r3_1 = q3.ToList();
619+
Assert.That(r3_1.Count, Is.EqualTo(1));
620+
Assert.That(r3_1, Has.All.With.Property("Key").EquivalentTo(new object[] { 1 }));
621+
622+
var r3a_1 = q3a.ToList();
623+
Assert.That(r3a_1.Count, Is.EqualTo(1));
624+
Assert.That(r3a_1, Has.All.With.Property("Key").EquivalentTo(new object[] { 1 }));
625+
626+
var r4_1 = q4.ToList();
627+
Assert.That(r4_1.Count, Is.EqualTo(3));
628+
Assert.That(r4_1, Has.All.With.Property("Key").With.Property("A").EqualTo(1));
629+
630+
var r5_1 = q5.ToList();
631+
Assert.That(r5_1.Count, Is.EqualTo(3));
632+
Assert.That(r5_1, Has.All.With.Property("Key").Contains(1));
633+
634+
var r6_1 = q5a.ToList();
635+
Assert.That(r6_1.Count, Is.EqualTo(3));
636+
Assert.That(r6_1, Has.All.With.Property("Key").Contains(1));
637+
638+
constKey = 2;
639+
640+
var r1_2 = q1.ToList();
641+
Assert.That(r1_2.Count, Is.EqualTo(1));
642+
Assert.That(r1_2, Has.All.With.Property("Key").EqualTo(2));
643+
644+
var r2_2 = q2.ToList();
645+
Assert.That(r2_2.Count, Is.EqualTo(1));
646+
Assert.That(r2_2, Has.All.With.Property("Key").With.Property("A").EqualTo(2));
647+
648+
var r3_2 = q3.ToList();
649+
Assert.That(r3_2.Count, Is.EqualTo(1));
650+
Assert.That(r3_2, Has.All.With.Property("Key").EquivalentTo(new object[] { 2 }));
651+
652+
var r3a_2 = q3a.ToList();
653+
Assert.That(r3a_2.Count, Is.EqualTo(1));
654+
Assert.That(r3a_2, Has.All.With.Property("Key").EquivalentTo(new object[] { 2 }));
655+
656+
var r4_2 = q4.ToList();
657+
Assert.That(r4_2.Count, Is.EqualTo(3));
658+
Assert.That(r4_2, Has.All.With.Property("Key").With.Property("A").EqualTo(2));
659+
660+
var r5_2 = q5.ToList();
661+
Assert.That(r5_2.Count, Is.EqualTo(3));
662+
Assert.That(r5_2, Has.All.With.Property("Key").Contains(2));
663+
664+
var r6_2 = q5.ToList();
665+
Assert.That(r6_2.Count, Is.EqualTo(3));
666+
Assert.That(r6_2, Has.All.With.Property("Key").Contains(2));
667+
}
668+
552669
[Test(Description = "NH-3801")]
553670
public void GroupByComputedValueWithJoinOnObject()
554671
{

src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Linq;
4+
using System.Linq.Expressions;
45
using NHibernate.Linq.Clauses;
56
using NHibernate.Linq.ReWriters;
67
using NHibernate.Linq.Visitors;
@@ -58,6 +59,7 @@ public static void ReWrite(QueryModel queryModel)
5859
if (groupBy != null)
5960
{
6061
FlattenSubQuery(queryModel, subQueryExpression.QueryModel, groupBy);
62+
RemoveCostantGroupByKeys(queryModel, groupBy);
6163
}
6264
}
6365
}
@@ -101,5 +103,22 @@ private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryMo
101103
// Replace the outer query source
102104
queryModel.MainFromClause = subQueryModel.MainFromClause;
103105
}
106+
107+
private static void RemoveCostantGroupByKeys(QueryModel queryModel, GroupResultOperator groupBy)
108+
{
109+
var keys = groupBy.ExtractKeyExpressions().Where(x => !(x is ConstantExpression)).ToList();
110+
111+
if (!keys.Any())
112+
{
113+
// Remove the Group By clause completely if all the keys are constant (redundant)
114+
queryModel.ResultOperators.Remove(groupBy);
115+
}
116+
else
117+
{
118+
// Re-write the KeySelector as an object array of the non-constant keys
119+
// This should be safe because we've already re-written the select clause using the original keys
120+
groupBy.KeySelector = Expression.NewArrayInit(typeof (object), keys.Select(x => x.Type.IsValueType ? Expression.Convert(x, typeof(object)) : x));
121+
}
122+
}
104123
}
105124
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Linq.Expressions;
5+
using System.Text;
6+
using NHibernate.Util;
7+
using Remotion.Linq.Clauses.ResultOperators;
8+
9+
namespace NHibernate.Linq
10+
{
11+
internal static class GroupResultOperatorExtensions
12+
{
13+
public static IEnumerable<Expression> ExtractKeyExpressions(this GroupResultOperator groupResult)
14+
{
15+
if (groupResult.KeySelector is NewExpression)
16+
return (groupResult.KeySelector as NewExpression).Arguments;
17+
if (groupResult.KeySelector is NewArrayExpression)
18+
return (groupResult.KeySelector as NewArrayExpression).Expressions;
19+
return new [] { groupResult.KeySelector };
20+
}
21+
}
22+
}

src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using System.Collections;
23
using System.Collections.Generic;
34
using System.Linq;
@@ -83,18 +84,42 @@ protected override Expression VisitConstantExpression(ConstantExpression express
8384
{
8485
// Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not.
8586
if (param.Value == null)
87+
{
8688
_string.Append("NULL");
87-
if (param.Value is IEnumerable && !((IEnumerable)param.Value).Cast<object>().Any())
88-
_string.Append("EmptyList");
89+
}
8990
else
90-
_string.Append(param.Name);
91+
{
92+
var value = param.Value as IEnumerable;
93+
if (value != null && !(value is string) && !value.Cast<object>().Any())
94+
{
95+
_string.Append("EmptyList");
96+
}
97+
else
98+
{
99+
_string.Append(param.Name);
100+
}
101+
}
91102
}
92103
else
93104
{
94105
if (expression.Value == null)
106+
{
95107
_string.Append("NULL");
108+
}
96109
else
97-
_string.Append(expression.Value);
110+
{
111+
var value = expression.Value as IEnumerable;
112+
if (value != null && !(value is string) && !(value is IQueryable))
113+
{
114+
_string.Append("{");
115+
_string.Append(String.Join(",", value.Cast<object>()));
116+
_string.Append("}");
117+
}
118+
else
119+
{
120+
_string.Append(expression.Value);
121+
}
122+
}
98123
}
99124

100125
return base.VisitConstantExpression(expression);
@@ -142,6 +167,7 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression exp
142167
case "Single":
143168
case "SingleOrDefault":
144169
case "Select":
170+
case "GroupBy":
145171
insideSelectClause = true;
146172
break;
147173
default:

src/NHibernate/NHibernate.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@
299299
<Compile Include="Linq\Functions\EqualsGenerator.cs" />
300300
<Compile Include="Linq\GroupBy\KeySelectorVisitor.cs" />
301301
<Compile Include="Linq\GroupBy\PagingRewriter.cs" />
302+
<Compile Include="Linq\GroupResultOperatorExtensions.cs" />
302303
<Compile Include="Linq\NestedSelects\NestedSelectDetector.cs" />
303304
<Compile Include="Linq\NestedSelects\Tuple.cs" />
304305
<Compile Include="Linq\NestedSelects\SelectClauseRewriter.cs" />

0 commit comments

Comments
 (0)