Skip to content

NH-3801 #436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 7, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,55 @@ public void GroupByComputedValueInObjectArray()
Assert.AreEqual(830, orderGroups.Sum(g => g.Count));
}

[Test(Description = "NH-3801")]
public void GroupByComputedValueWithJoinOnObject()
{
var orderGroups = db.OrderLines.GroupBy(o => o.Order.Customer == null ? 0 : 1).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
}

[Test(Description = "NH-3801")]
public void GroupByComputedValueWithJoinOnId()
{
var orderGroups = db.OrderLines.GroupBy(o => o.Order.Customer.CustomerId == null ? 0 : 1).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
}

[Test(Description = "NH-3801")]
public void GroupByComputedValueInAnonymousTypeWithJoinOnObject()
{
var orderGroups = db.OrderLines.GroupBy(o => new { Key = o.Order.Customer == null ? 0 : 1 }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
}

[Test(Description = "NH-3801")]
public void GroupByComputedValueInAnonymousTypeWithJoinOnId()
{
var orderGroups = db.OrderLines.GroupBy(o => new { Key = o.Order.Customer.CustomerId == null ? 0 : 1 }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
}

[Test(Description = "NH-3801")]
public void GroupByComputedValueInObjectArrayWithJoinOnObject()
{
var orderGroups = db.OrderLines.GroupBy(o => new[] { o.Order.Customer == null ? 0 : 1 }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
}

[Test(Description = "NH-3801")]
public void GroupByComputedValueInObjectArrayWithJoinOnId()
{
var orderGroups = db.OrderLines.GroupBy(o => new[] { o.Order.Customer.CustomerId == null ? 0 : 1 }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
}

[Test(Description = "NH-3801")]
public void GroupByComputedValueInObjectArrayWithJoinInRightSideOfCase()
{
var orderGroups = db.OrderLines.GroupBy(o => new[] { o.Order.Customer.CustomerId == null ? "unknown" : o.Order.Customer.CompanyName }).Select(g => new { Key = g.Key, Count = g.Count() }).ToList();
Assert.AreEqual(2155, orderGroups.Sum(g => g.Count));
}

private static void CheckGrouping<TKey, TElement>(IEnumerable<IGrouping<TKey, TElement>> groupedItems, Func<TElement, TKey> groupBy)
{
var used = new HashSet<object>();
Expand Down
52 changes: 52 additions & 0 deletions src/NHibernate.Test/Linq/JoinTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,58 @@ public void OrderLinesWithSelectingOrderIdAndDateShouldProduceOneJoin()
}
}

[Test(Description = "NH-3801")]
public void OrderLinesWithSelectingCustomerIdInCaseShouldProduceOneJoin()
{
using (var spy = new SqlLogSpy())
{
(from l in db.OrderLines
select new { CustomerKnown = l.Order.Customer.CustomerId == null ? 0 : 1, l.Order.OrderDate }).ToList();

var countJoins = CountJoins(spy);
Assert.That(countJoins, Is.EqualTo(1));
}
}

[Test(Description = "NH-3801")]
public void OrderLinesWithSelectingCustomerInCaseShouldProduceOneJoin()
{
using (var spy = new SqlLogSpy())
{
(from l in db.OrderLines
select new { CustomerKnown = l.Order.Customer == null ? 0 : 1, l.Order.OrderDate }).ToList();

var countJoins = CountJoins(spy);
Assert.That(countJoins, Is.EqualTo(1));
}
}

[Test(Description = "NH-3801")]
public void OrderLinesWithSelectingCustomerNameInCaseShouldProduceTwoJoins()
{
using (var spy = new SqlLogSpy())
{
(from l in db.OrderLines
select new { CustomerKnown = l.Order.Customer.CustomerId == null ? "unknown" : l.Order.Customer.CompanyName, l.Order.OrderDate }).ToList();

var countJoins = CountJoins(spy);
Assert.That(countJoins, Is.EqualTo(2));
}
}

[Test(Description = "NH-3801")]
public void OrderLinesWithSelectingCustomerNameInCaseShouldProduceTwoJoinsAlternate()
{
using (var spy = new SqlLogSpy())
{
(from l in db.OrderLines
select new { CustomerKnown = l.Order.Customer == null ? "unknown" : l.Order.Customer.CompanyName, l.Order.OrderDate }).ToList();

var countJoins = CountJoins(spy);
Assert.That(countJoins, Is.EqualTo(2));
}
}

private static int CountJoins(LogSpy sqlLog)
{
return Count(sqlLog, "join");
Expand Down
2 changes: 1 addition & 1 deletion src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ private void DereferenceEntity(EntityType entityType, bool implicitJoin, string
}
else
{
joinIsNeeded = generateJoin || ( Walker.IsInSelect || Walker.IsInFrom );
joinIsNeeded = generateJoin || ( (Walker.IsInSelect && !Walker.IsInCase ) || Walker.IsInFrom );
}

if ( joinIsNeeded )
Expand Down
5 changes: 1 addition & 4 deletions src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using NHibernate.Linq.Clauses;
using NHibernate.Linq.ReWriters;
using NHibernate.Linq.Visitors;
using NHibernate.Util;
using Remotion.Linq;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Clauses.ResultOperators;
Expand Down Expand Up @@ -44,7 +42,7 @@ public static class AggregatingGroupByRewriter
typeof (CacheableResultOperator)
};

public static void ReWrite(QueryModel queryModel, IList<Expression> groupByKeys)
public static void ReWrite(QueryModel queryModel)
{
var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression;

Expand All @@ -59,7 +57,6 @@ public static void ReWrite(QueryModel queryModel, IList<Expression> groupByKeys)
var groupBy = operators[0] as GroupResultOperator;
if (groupBy != null)
{
groupBy.ExtractKeyExpressions(groupByKeys);
FlattenSubQuery(queryModel, subQueryExpression.QueryModel, groupBy);
}
}
Expand Down
23 changes: 0 additions & 23 deletions src/NHibernate/Linq/GroupResultOperatorExtensions.cs

This file was deleted.

17 changes: 8 additions & 9 deletions src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,36 @@ internal interface IIsEntityDecider
public class AddJoinsReWriter : QueryModelVisitorBase, IIsEntityDecider
{
private readonly ISessionFactoryImplementor _sessionFactory;
private readonly SelectJoinDetector _selectJoinDetector;
private readonly ResultOperatorAndOrderByJoinDetector _resultOperatorAndOrderByJoinDetector;
private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector;
private readonly WhereJoinDetector _whereJoinDetector;

private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel queryModel)
{
_sessionFactory = sessionFactory;
var joiner = new Joiner(queryModel);
_selectJoinDetector = new SelectJoinDetector(this, joiner);
_resultOperatorAndOrderByJoinDetector = new ResultOperatorAndOrderByJoinDetector(this, joiner);
_memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner);
_whereJoinDetector = new WhereJoinDetector(this, joiner);
}

public static void ReWrite(QueryModel queryModel, ISessionFactoryImplementor sessionFactory)
public static void ReWrite(QueryModel queryModel, VisitorParameters parameters)
{
new AddJoinsReWriter(sessionFactory, queryModel).VisitQueryModel(queryModel);
var visitor = new AddJoinsReWriter(parameters.SessionFactory, queryModel);
visitor.VisitQueryModel(queryModel);
}

public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel)
{
_selectJoinDetector.Transform(selectClause);
_memberExpressionJoinDetector.Transform(selectClause);
}

public override void VisitOrdering(Ordering ordering, QueryModel queryModel, OrderByClause orderByClause, int index)
{
_resultOperatorAndOrderByJoinDetector.Transform(ordering);
_memberExpressionJoinDetector.Transform(ordering);
}

public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index)
{
_resultOperatorAndOrderByJoinDetector.Transform(resultOperator);
_memberExpressionJoinDetector.Transform(resultOperator);
}

public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index)
Expand Down
10 changes: 5 additions & 5 deletions src/NHibernate/Linq/Visitors/JoinBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ public void MakeInnerIfJoined(string key)
public bool CanAddJoin(Expression expression)
{
var source = QuerySourceExtractor.GetQuerySource(expression);
if (_queryModel.MainFromClause == source)

if (_queryModel.MainFromClause == source)
return true;

var bodyClause = source as IBodyClause;
if (bodyClause != null && _queryModel.BodyClauses.Contains(bodyClause))
if (bodyClause != null && _queryModel.BodyClauses.Contains(bodyClause))
return true;

var resultOperatorBase = source as ResultOperatorBase;
return resultOperatorBase != null && _queryModel.ResultOperators.Contains(resultOperatorBase);
}
Expand Down
91 changes: 91 additions & 0 deletions src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq.Expressions;
using NHibernate.Linq.ReWriters;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Parsing;

namespace NHibernate.Linq.Visitors
{
/// <summary>
/// Detects joins in Select, OrderBy and Results (GroupBy) clauses.
/// Replaces them with appropriate joins, maintaining reference equality between different clauses.
/// This allows extracted GroupBy key expression to also be replaced so that they can continue to match replaced Select expressions
/// </summary>
internal class MemberExpressionJoinDetector : ExpressionTreeVisitor
{
private readonly IIsEntityDecider _isEntityDecider;
private readonly IJoiner _joiner;

private bool _requiresJoinForNonIdentifier;
private bool _hasIdentifier;
private int _memberExpressionDepth;

public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner)
{
_isEntityDecider = isEntityDecider;
_joiner = joiner;
}

protected override Expression VisitMemberExpression(MemberExpression expression)
{
var isIdentifier = _isEntityDecider.IsIdentifier(expression.Expression.Type, expression.Member.Name);
if (isIdentifier)
_hasIdentifier = true;
if (!isIdentifier)
_memberExpressionDepth++;

var result = base.VisitMemberExpression(expression);

if (!isIdentifier)
_memberExpressionDepth--;

if (_isEntityDecider.IsEntity(expression.Type) &&
((_requiresJoinForNonIdentifier && !_hasIdentifier) || _memberExpressionDepth > 0) &&
_joiner.CanAddJoin(expression))
{
var key = ExpressionKeyVisitor.Visit(expression, null);
return _joiner.AddJoin(result, key);
}

return result;
}

protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
{
expression.QueryModel.TransformExpressions(VisitExpression);
return expression;
}

protected override Expression VisitConditionalExpression(ConditionalExpression expression)
{
var oldRequiresJoinForNonIdentifier = _requiresJoinForNonIdentifier;
_requiresJoinForNonIdentifier = false;
var newTest = VisitExpression(expression.Test);
_requiresJoinForNonIdentifier = oldRequiresJoinForNonIdentifier;
var newFalse = VisitExpression(expression.IfFalse);
var newTrue = VisitExpression(expression.IfTrue);
if ((newTest != expression.Test) || (newFalse != expression.IfFalse) || (newTrue != expression.IfTrue))
return Expression.Condition(newTest, newTrue, newFalse);
return expression;
}

public void Transform(SelectClause selectClause)
{
_requiresJoinForNonIdentifier = true;
selectClause.TransformExpressions(VisitExpression);
_requiresJoinForNonIdentifier = false;
}

public void Transform(ResultOperatorBase resultOperator)
{
resultOperator.TransformExpressions(VisitExpression);
}

public void Transform(Ordering ordering)
{
ordering.TransformExpressions(VisitExpression);
}
}
}
6 changes: 3 additions & 3 deletions src/NHibernate/Linq/Visitors/QueryModelVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
NonAggregatingGroupByRewriter.ReWrite(queryModel);

// Rewrite aggregate group-by statements
AggregatingGroupByRewriter.ReWrite(queryModel, parameters.GroupByKeys);
AggregatingGroupByRewriter.ReWrite(queryModel);

// Rewrite aggregating group-joins
AggregatingGroupJoinRewriter.ReWrite(queryModel);
Expand All @@ -57,7 +57,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
ArrayIndexExpressionFlattener.ReWrite(queryModel);

// Add joins for references
AddJoinsReWriter.ReWrite(queryModel, parameters.SessionFactory);
AddJoinsReWriter.ReWrite(queryModel, parameters);

// Move OrderBy clauses to end
MoveOrderByToEndRewriter.ReWrite(queryModel);
Expand Down Expand Up @@ -238,7 +238,7 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que
{
CurrentEvaluationType = selectClause.GetOutputDataInfo();

var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters, VisitorParameters.GroupByKeys);
var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters);

visitor.Visit(selectClause.Selector);

Expand Down
Loading