diff --git a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs index a8aabd37b7e..50bbc101be8 100644 --- a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs @@ -549,6 +549,100 @@ public void GroupByComputedValueInObjectArray() Assert.AreEqual(830, orderGroups.Sum(g => g.Count)); } + [Test(Description = "NH-3474")] + public void GroupByConstant() + { + var totals = db.Orders.GroupBy(o => 1).Select(g => new { Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight) }).ToList(); + Assert.That(totals.Count, Is.EqualTo(1)); + Assert.That(totals, Has.All.With.Property("Key").EqualTo(1)); + } + + [Test(Description = "NH-3474")] + public void GroupByConstantAnonymousType() + { + var totals = db.Orders.GroupBy(o => new { A = 1 }).Select(g => new { Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight) }).ToList(); + Assert.That(totals.Count, Is.EqualTo(1)); + Assert.That(totals, Has.All.With.Property("Key").With.Property("A").EqualTo(1)); + } + + [Test(Description = "NH-3474")] + public void GroupByConstantArray() + { + var totals = db.Orders.GroupBy(o => new object[] { 1 }).Select(g => new { Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight) }).ToList(); + Assert.That(totals.Count, Is.EqualTo(1)); + Assert.That(totals, Has.All.With.Property("Key").EqualTo(new object[] { 1 })); + } + + [Test(Description = "NH-3474")] + public void GroupByKeyWithConstantInAnonymousType() + { + var totals = db.Orders.GroupBy(o => new { A = 1, B = o.Shipper.ShipperId }).Select(g => new { Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight) }).ToList(); + Assert.That(totals.Count, Is.EqualTo(3)); + Assert.That(totals, Has.All.With.Property("Key").With.Property("A").EqualTo(1)); + } + + [Test(Description = "NH-3474")] + public void GroupByKeyWithConstantInArray() + { + var totals = db.Orders.GroupBy(o => new[] { 1, o.Shipper.ShipperId }).Select(g => new { Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight) }).ToList(); + Assert.That(totals.Count, Is.EqualTo(3)); + Assert.That(totals, Has.All.With.Property("Key").Contains(1)); + } + + private int constKey; + [Test(Description = "NH-3474")] + public void GroupByKeyWithConstantFromVariable() + { + constKey = 1; + var q1 = db.Orders.GroupBy(o => constKey).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)}); + var q2 = db.Orders.GroupBy(o => new {A = constKey}).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)}); + var q3 = db.Orders.GroupBy(o => new object[] {constKey}).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)}); + var q4 = db.Orders.GroupBy(o => new {A = constKey, B = o.Shipper.ShipperId}).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)}); + var q5 = db.Orders.GroupBy(o => new[] {constKey, o.Shipper.ShipperId}).Select(g => new {Key = g.Key, Count = g.Count(), Sum = g.Sum(x => x.Freight)}); + + var r1_1 = q1.ToList(); + Assert.That(r1_1.Count, Is.EqualTo(1)); + Assert.That(r1_1, Has.All.With.Property("Key").EqualTo(1)); + + var r2_1 = q2.ToList(); + Assert.That(r2_1.Count, Is.EqualTo(1)); + Assert.That(r2_1, Has.All.With.Property("Key").With.Property("A").EqualTo(1)); + + var r3_1 = q3.ToList(); + Assert.That(r3_1.Count, Is.EqualTo(1)); + Assert.That(r3_1, Has.All.With.Property("Key").EqualTo(new object[] { 1 })); + + var r4_1 = q4.ToList(); + Assert.That(r4_1.Count, Is.EqualTo(3)); + Assert.That(r4_1, Has.All.With.Property("Key").With.Property("A").EqualTo(1)); + + var r5_1 = q5.ToList(); + Assert.That(r5_1.Count, Is.EqualTo(3)); + Assert.That(r5_1, Has.All.With.Property("Key").Contains(1)); + + constKey = 2; + + var r1_2 = q1.ToList(); + Assert.That(r1_2.Count, Is.EqualTo(1)); + Assert.That(r1_2, Has.All.With.Property("Key").EqualTo(2)); + + var r2_2 = q2.ToList(); + Assert.That(r2_2.Count, Is.EqualTo(1)); + Assert.That(r2_2, Has.All.With.Property("Key").With.Property("A").EqualTo(2)); + + var r3_2 = q3.ToList(); + Assert.That(r3_2.Count, Is.EqualTo(1)); + Assert.That(r3_2, Has.All.With.Property("Key").EqualTo(new object[] { constKey })); + + var r4_2 = q4.ToList(); + Assert.That(r4_2.Count, Is.EqualTo(3)); + Assert.That(r4_2, Has.All.With.Property("Key").With.Property("A").EqualTo(2)); + + var r5_2 = q5.ToList(); + Assert.That(r5_2.Count, Is.EqualTo(3)); + Assert.That(r5_2, Has.All.With.Property("Key").Contains(2)); + } + [Test(Description = "NH-3801")] public void GroupByComputedValueWithJoinOnObject() { diff --git a/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs b/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs index 3c3bf5e0d71..8150d0ae5df 100644 --- a/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using NHibernate.Linq.Clauses; using NHibernate.Linq.ReWriters; using NHibernate.Linq.Visitors; @@ -58,6 +59,7 @@ public static void ReWrite(QueryModel queryModel) if (groupBy != null) { FlattenSubQuery(queryModel, subQueryExpression.QueryModel, groupBy); + RemoveCostantGroupByKeys(queryModel, groupBy); } } } @@ -101,5 +103,22 @@ private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryMo // Replace the outer query source queryModel.MainFromClause = subQueryModel.MainFromClause; } + + private static void RemoveCostantGroupByKeys(QueryModel queryModel, GroupResultOperator groupBy) + { + var keys = groupBy.ExtractKeyExpressions().Where(x => !(x is ConstantExpression)).ToList(); + + if (!keys.Any()) + { + // Remove the Group By clause completely if all the keys are constant (redundant) + queryModel.ResultOperators.Remove(groupBy); + } + else + { + // Re-write the KeySelector as an object array of the non-constant keys + // This should be safe because we've already re-written the select clause using the original keys + groupBy.KeySelector = Expression.NewArrayInit(typeof (object), keys.Select(x => x.Type.IsValueType ? Expression.Convert(x, typeof(object)) : x)); + } + } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/GroupResultOperatorExtensions.cs b/src/NHibernate/Linq/GroupResultOperatorExtensions.cs new file mode 100644 index 00000000000..d53dc387b91 --- /dev/null +++ b/src/NHibernate/Linq/GroupResultOperatorExtensions.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Text; +using NHibernate.Util; +using Remotion.Linq.Clauses.ResultOperators; + +namespace NHibernate.Linq +{ + internal static class GroupResultOperatorExtensions + { + public static IEnumerable ExtractKeyExpressions(this GroupResultOperator groupResult) + { + if (groupResult.KeySelector is NewExpression) + return (groupResult.KeySelector as NewExpression).Arguments; + if (groupResult.KeySelector is NewArrayExpression) + return (groupResult.KeySelector as NewArrayExpression).Expressions; + return new [] { groupResult.KeySelector }; + } + } +} diff --git a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs index 0cc37ec40bf..1689bcb6671 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs @@ -1,3 +1,4 @@ +using System; using System.Collections; using System.Collections.Generic; using System.Linq; @@ -92,9 +93,22 @@ protected override Expression VisitConstantExpression(ConstantExpression express else { if (expression.Value == null) + { _string.Append("NULL"); + } + else if (expression.Type.IsArray) + { + // Const arrays all look the same (they just display the type of array and not the initializer contents + // Since the contents might be different, we need to include those as well so we don't use a cached query plan by mistake + _string.Append(expression.Value); + _string.Append(" {"); + _string.Append(String.Join(",", (object[]) expression.Value)); + _string.Append("}"); + } else + { _string.Append(expression.Value); + } } return base.VisitConstantExpression(expression); @@ -142,6 +156,7 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression exp case "Single": case "SingleOrDefault": case "Select": + case "GroupBy": insideSelectClause = true; break; default: diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index 020a5fca917..f8e02857ff6 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -299,6 +299,7 @@ +