Skip to content

Commit f227a6c

Browse files
committed
Merge branch 'master' into NH-3474 and fix conflicts
Conflicts: src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs src/NHibernate/Linq/GroupResultOperatorExtensions.cs
2 parents 593ec2b + 99e0bfc commit f227a6c

14 files changed

+221
-151
lines changed

src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,55 @@ public void GroupByKeyWithConstantFromVariable()
642642
Assert.That(r5_2.Count, Is.EqualTo(3));
643643
Assert.That(r5_2, Has.All.With.Property("Key").Contains(2));
644644
}
645+
646+
[Test(Description = "NH-3801")]
647+
public void GroupByComputedValueWithJoinOnObject()
648+
{
649+
var orderGroups = db.OrderLines.GroupBy(o => o.Order.Customer == null ? 0 : 1).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
650+
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
651+
}
652+
653+
[Test(Description = "NH-3801")]
654+
public void GroupByComputedValueWithJoinOnId()
655+
{
656+
var orderGroups = db.OrderLines.GroupBy(o => o.Order.Customer.CustomerId == null ? 0 : 1).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
657+
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
658+
}
659+
660+
[Test(Description = "NH-3801")]
661+
public void GroupByComputedValueInAnonymousTypeWithJoinOnObject()
662+
{
663+
var orderGroups = db.OrderLines.GroupBy(o => new { Key = o.Order.Customer == null ? 0 : 1 }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
664+
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
665+
}
666+
667+
[Test(Description = "NH-3801")]
668+
public void GroupByComputedValueInAnonymousTypeWithJoinOnId()
669+
{
670+
var orderGroups = db.OrderLines.GroupBy(o => new { Key = o.Order.Customer.CustomerId == null ? 0 : 1 }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
671+
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
672+
}
673+
674+
[Test(Description = "NH-3801")]
675+
public void GroupByComputedValueInObjectArrayWithJoinOnObject()
676+
{
677+
var orderGroups = db.OrderLines.GroupBy(o => new[] { o.Order.Customer == null ? 0 : 1 }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
678+
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
679+
}
680+
681+
[Test(Description = "NH-3801")]
682+
public void GroupByComputedValueInObjectArrayWithJoinOnId()
683+
{
684+
var orderGroups = db.OrderLines.GroupBy(o => new[] { o.Order.Customer.CustomerId == null ? 0 : 1 }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
685+
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
686+
}
687+
688+
[Test(Description = "NH-3801")]
689+
public void GroupByComputedValueInObjectArrayWithJoinInRightSideOfCase()
690+
{
691+
var orderGroups = db.OrderLines.GroupBy(o => new[] { o.Order.Customer.CustomerId == null ? "unknown" : o.Order.Customer.CompanyName }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
692+
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
693+
}
645694

646695
private static void CheckGrouping<TKey, TElement>(IEnumerable<IGrouping<TKey, TElement>> groupedItems, Func<TElement, TKey> groupBy)
647696
{

src/NHibernate.Test/Linq/JoinTests.cs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,58 @@ public void OrderLinesWithSelectingOrderIdAndDateShouldProduceOneJoin()
260260
}
261261
}
262262

263+
[Test(Description = "NH-3801")]
264+
public void OrderLinesWithSelectingCustomerIdInCaseShouldProduceOneJoin()
265+
{
266+
using (var spy = new SqlLogSpy())
267+
{
268+
(from l in db.OrderLines
269+
select new { CustomerKnown = l.Order.Customer.CustomerId == null ? 0 : 1, l.Order.OrderDate }).ToList();
270+
271+
var countJoins = CountJoins(spy);
272+
Assert.That(countJoins, Is.EqualTo(1));
273+
}
274+
}
275+
276+
[Test(Description = "NH-3801")]
277+
public void OrderLinesWithSelectingCustomerInCaseShouldProduceOneJoin()
278+
{
279+
using (var spy = new SqlLogSpy())
280+
{
281+
(from l in db.OrderLines
282+
select new { CustomerKnown = l.Order.Customer == null ? 0 : 1, l.Order.OrderDate }).ToList();
283+
284+
var countJoins = CountJoins(spy);
285+
Assert.That(countJoins, Is.EqualTo(1));
286+
}
287+
}
288+
289+
[Test(Description = "NH-3801")]
290+
public void OrderLinesWithSelectingCustomerNameInCaseShouldProduceTwoJoins()
291+
{
292+
using (var spy = new SqlLogSpy())
293+
{
294+
(from l in db.OrderLines
295+
select new { CustomerKnown = l.Order.Customer.CustomerId == null ? "unknown" : l.Order.Customer.CompanyName, l.Order.OrderDate }).ToList();
296+
297+
var countJoins = CountJoins(spy);
298+
Assert.That(countJoins, Is.EqualTo(2));
299+
}
300+
}
301+
302+
[Test(Description = "NH-3801")]
303+
public void OrderLinesWithSelectingCustomerNameInCaseShouldProduceTwoJoinsAlternate()
304+
{
305+
using (var spy = new SqlLogSpy())
306+
{
307+
(from l in db.OrderLines
308+
select new { CustomerKnown = l.Order.Customer == null ? "unknown" : l.Order.Customer.CompanyName, l.Order.OrderDate }).ToList();
309+
310+
var countJoins = CountJoins(spy);
311+
Assert.That(countJoins, Is.EqualTo(2));
312+
}
313+
}
314+
263315
private static int CountJoins(LogSpy sqlLog)
264316
{
265317
return Count(sqlLog, "join");

src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ private void DereferenceEntity(EntityType entityType, bool implicitJoin, string
407407
}
408408
else
409409
{
410-
joinIsNeeded = generateJoin || ( Walker.IsInSelect || Walker.IsInFrom );
410+
joinIsNeeded = generateJoin || ( (Walker.IsInSelect && !Walker.IsInCase ) || Walker.IsInFrom );
411411
}
412412

413413
if ( joinIsNeeded )

src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public static class AggregatingGroupByRewriter
4343
typeof (CacheableResultOperator)
4444
};
4545

46-
public static void ReWrite(QueryModel queryModel, IList<Expression> groupByKeys)
46+
public static void ReWrite(QueryModel queryModel)
4747
{
4848
var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression;
4949

@@ -59,11 +59,7 @@ public static void ReWrite(QueryModel queryModel, IList<Expression> groupByKeys)
5959
if (groupBy != null)
6060
{
6161
FlattenSubQuery(queryModel, subQueryExpression.QueryModel, groupBy);
62-
var extractedGroupByKeys = RemoveCostantGroupByKeys(queryModel, groupBy);
63-
foreach (var key in extractedGroupByKeys)
64-
{
65-
groupByKeys.Add(key);
66-
}
62+
RemoveCostantGroupByKeys(queryModel, groupBy);
6763
}
6864
}
6965
}
@@ -108,7 +104,7 @@ private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryMo
108104
queryModel.MainFromClause = subQueryModel.MainFromClause;
109105
}
110106

111-
private static IEnumerable<Expression> RemoveCostantGroupByKeys(QueryModel queryModel, GroupResultOperator groupBy)
107+
private static void RemoveCostantGroupByKeys(QueryModel queryModel, GroupResultOperator groupBy)
112108
{
113109
var keys = groupBy.ExtractKeyExpressions().Where(x => !(x is ConstantExpression)).ToList();
114110

@@ -123,8 +119,6 @@ private static IEnumerable<Expression> RemoveCostantGroupByKeys(QueryModel query
123119
// This should be safe because we've already re-written the select clause using the original keys
124120
groupBy.KeySelector = Expression.NewArrayInit(typeof (object), keys.Select(x => x.Type.IsValueType ? Expression.Convert(x, typeof(object)) : x));
125121
}
126-
127-
return keys;
128122
}
129123
}
130124
}

src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,37 +15,36 @@ internal interface IIsEntityDecider
1515
public class AddJoinsReWriter : QueryModelVisitorBase, IIsEntityDecider
1616
{
1717
private readonly ISessionFactoryImplementor _sessionFactory;
18-
private readonly SelectJoinDetector _selectJoinDetector;
19-
private readonly ResultOperatorAndOrderByJoinDetector _resultOperatorAndOrderByJoinDetector;
18+
private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector;
2019
private readonly WhereJoinDetector _whereJoinDetector;
2120

2221
private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel queryModel)
2322
{
2423
_sessionFactory = sessionFactory;
2524
var joiner = new Joiner(queryModel);
26-
_selectJoinDetector = new SelectJoinDetector(this, joiner);
27-
_resultOperatorAndOrderByJoinDetector = new ResultOperatorAndOrderByJoinDetector(this, joiner);
25+
_memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner);
2826
_whereJoinDetector = new WhereJoinDetector(this, joiner);
2927
}
3028

31-
public static void ReWrite(QueryModel queryModel, ISessionFactoryImplementor sessionFactory)
29+
public static void ReWrite(QueryModel queryModel, VisitorParameters parameters)
3230
{
33-
new AddJoinsReWriter(sessionFactory, queryModel).VisitQueryModel(queryModel);
31+
var visitor = new AddJoinsReWriter(parameters.SessionFactory, queryModel);
32+
visitor.VisitQueryModel(queryModel);
3433
}
3534

3635
public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel)
3736
{
38-
_selectJoinDetector.Transform(selectClause);
37+
_memberExpressionJoinDetector.Transform(selectClause);
3938
}
4039

4140
public override void VisitOrdering(Ordering ordering, QueryModel queryModel, OrderByClause orderByClause, int index)
4241
{
43-
_resultOperatorAndOrderByJoinDetector.Transform(ordering);
42+
_memberExpressionJoinDetector.Transform(ordering);
4443
}
4544

4645
public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index)
4746
{
48-
_resultOperatorAndOrderByJoinDetector.Transform(resultOperator);
47+
_memberExpressionJoinDetector.Transform(resultOperator);
4948
}
5049

5150
public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index)

src/NHibernate/Linq/Visitors/JoinBuilder.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ public void MakeInnerIfJoined(string key)
6060
public bool CanAddJoin(Expression expression)
6161
{
6262
var source = QuerySourceExtractor.GetQuerySource(expression);
63-
64-
if (_queryModel.MainFromClause == source)
63+
64+
if (_queryModel.MainFromClause == source)
6565
return true;
66-
66+
6767
var bodyClause = source as IBodyClause;
68-
if (bodyClause != null && _queryModel.BodyClauses.Contains(bodyClause))
68+
if (bodyClause != null && _queryModel.BodyClauses.Contains(bodyClause))
6969
return true;
70-
70+
7171
var resultOperatorBase = source as ResultOperatorBase;
7272
return resultOperatorBase != null && _queryModel.ResultOperators.Contains(resultOperatorBase);
7373
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
using System.Collections;
2+
using System.Collections.Generic;
3+
using System.Linq.Expressions;
4+
using NHibernate.Linq.ReWriters;
5+
using Remotion.Linq.Clauses;
6+
using Remotion.Linq.Clauses.Expressions;
7+
using Remotion.Linq.Parsing;
8+
9+
namespace NHibernate.Linq.Visitors
10+
{
11+
/// <summary>
12+
/// Detects joins in Select, OrderBy and Results (GroupBy) clauses.
13+
/// Replaces them with appropriate joins, maintaining reference equality between different clauses.
14+
/// This allows extracted GroupBy key expression to also be replaced so that they can continue to match replaced Select expressions
15+
/// </summary>
16+
internal class MemberExpressionJoinDetector : ExpressionTreeVisitor
17+
{
18+
private readonly IIsEntityDecider _isEntityDecider;
19+
private readonly IJoiner _joiner;
20+
21+
private bool _requiresJoinForNonIdentifier;
22+
private bool _hasIdentifier;
23+
private int _memberExpressionDepth;
24+
25+
public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner)
26+
{
27+
_isEntityDecider = isEntityDecider;
28+
_joiner = joiner;
29+
}
30+
31+
protected override Expression VisitMemberExpression(MemberExpression expression)
32+
{
33+
var isIdentifier = _isEntityDecider.IsIdentifier(expression.Expression.Type, expression.Member.Name);
34+
if (isIdentifier)
35+
_hasIdentifier = true;
36+
if (!isIdentifier)
37+
_memberExpressionDepth++;
38+
39+
var result = base.VisitMemberExpression(expression);
40+
41+
if (!isIdentifier)
42+
_memberExpressionDepth--;
43+
44+
if (_isEntityDecider.IsEntity(expression.Type) &&
45+
((_requiresJoinForNonIdentifier && !_hasIdentifier) || _memberExpressionDepth > 0) &&
46+
_joiner.CanAddJoin(expression))
47+
{
48+
var key = ExpressionKeyVisitor.Visit(expression, null);
49+
return _joiner.AddJoin(result, key);
50+
}
51+
52+
return result;
53+
}
54+
55+
protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
56+
{
57+
expression.QueryModel.TransformExpressions(VisitExpression);
58+
return expression;
59+
}
60+
61+
protected override Expression VisitConditionalExpression(ConditionalExpression expression)
62+
{
63+
var oldRequiresJoinForNonIdentifier = _requiresJoinForNonIdentifier;
64+
_requiresJoinForNonIdentifier = false;
65+
var newTest = VisitExpression(expression.Test);
66+
_requiresJoinForNonIdentifier = oldRequiresJoinForNonIdentifier;
67+
var newFalse = VisitExpression(expression.IfFalse);
68+
var newTrue = VisitExpression(expression.IfTrue);
69+
if ((newTest != expression.Test) || (newFalse != expression.IfFalse) || (newTrue != expression.IfTrue))
70+
return Expression.Condition(newTest, newTrue, newFalse);
71+
return expression;
72+
}
73+
74+
public void Transform(SelectClause selectClause)
75+
{
76+
_requiresJoinForNonIdentifier = true;
77+
selectClause.TransformExpressions(VisitExpression);
78+
_requiresJoinForNonIdentifier = false;
79+
}
80+
81+
public void Transform(ResultOperatorBase resultOperator)
82+
{
83+
resultOperator.TransformExpressions(VisitExpression);
84+
}
85+
86+
public void Transform(Ordering ordering)
87+
{
88+
ordering.TransformExpressions(VisitExpression);
89+
}
90+
}
91+
}

src/NHibernate/Linq/Visitors/QueryModelVisitor.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
3434
NonAggregatingGroupByRewriter.ReWrite(queryModel);
3535

3636
// Rewrite aggregate group-by statements
37-
AggregatingGroupByRewriter.ReWrite(queryModel, parameters.GroupByKeys);
37+
AggregatingGroupByRewriter.ReWrite(queryModel);
3838

3939
// Rewrite aggregating group-joins
4040
AggregatingGroupJoinRewriter.ReWrite(queryModel);
@@ -57,7 +57,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
5757
ArrayIndexExpressionFlattener.ReWrite(queryModel);
5858

5959
// Add joins for references
60-
AddJoinsReWriter.ReWrite(queryModel, parameters.SessionFactory);
60+
AddJoinsReWriter.ReWrite(queryModel, parameters);
6161

6262
// Move OrderBy clauses to end
6363
MoveOrderByToEndRewriter.ReWrite(queryModel);
@@ -238,7 +238,7 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que
238238
{
239239
CurrentEvaluationType = selectClause.GetOutputDataInfo();
240240

241-
var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters, VisitorParameters.GroupByKeys);
241+
var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters);
242242

243243
visitor.Visit(selectClause.Selector);
244244

0 commit comments

Comments
 (0)