Skip to content

Commit b57e654

Browse files
committed
Avoid cross join for associations inside an outer key selector
1 parent b17820a commit b57e654

File tree

7 files changed

+165
-24
lines changed

7 files changed

+165
-24
lines changed

src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,13 @@ from c in db.Customers
965965
join o in db.Orders on c.CustomerId equals o.Customer.CustomerId
966966
select new { c.ContactName, o.OrderId };
967967

968-
await (ObjectDumper.WriteAsync(q));
968+
using (var sqlSpy = new SqlLogSpy())
969+
{
970+
await (ObjectDumper.WriteAsync(q));
971+
972+
var sql = sqlSpy.GetWholeLog();
973+
Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1));
974+
}
969975
}
970976

971977
[Category("JOIN")]
@@ -1005,7 +1011,9 @@ public async Task DLinqJoin5dAsync(bool useCrossJoin)
10051011

10061012
var q =
10071013
from c in db.Customers
1008-
join o in db.Orders on new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null }
1014+
join o in db.Orders on
1015+
new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals
1016+
new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null }
10091017
select new { c.ContactName, o.OrderId };
10101018

10111019
using (var substitute = SubstituteDialect())
@@ -1041,6 +1049,27 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId
10411049
}
10421050
}
10431051

1052+
[Category("JOIN")]
1053+
[TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")]
1054+
public async Task DLinqJoin5fAsync()
1055+
{
1056+
var q =
1057+
from o in db.Orders
1058+
join c in db.Customers on
1059+
new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } equals
1060+
new { c.CustomerId, HasContractTitle = c.ContactTitle != null }
1061+
select new { c.ContactName, o.OrderId };
1062+
1063+
using (var sqlSpy = new SqlLogSpy())
1064+
{
1065+
await (ObjectDumper.WriteAsync(q));
1066+
1067+
var sql = sqlSpy.GetWholeLog();
1068+
Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1));
1069+
Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1));
1070+
}
1071+
}
1072+
10441073
[Category("JOIN")]
10451074
[Test(Description = "This sample explictly joins three tables and projects results from each of them.")]
10461075
public async Task DLinqJoin6Async()
@@ -1138,5 +1167,24 @@ group o by c into x
11381167

11391168
await (ObjectDumper.WriteAsync(q));
11401169
}
1170+
1171+
[Category("JOIN")]
1172+
[Test(Description = "This sample shows how to join multiple tables.")]
1173+
public async Task DLinqJoin10aAsync()
1174+
{
1175+
var q =
1176+
from e in db.Employees
1177+
join s in db.Employees on e.Superior.EmployeeId equals s.EmployeeId
1178+
join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId
1179+
select new { e.FirstName, SuperiorName = s.FirstName, Superior2Name = s2.FirstName };
1180+
1181+
using (var sqlSpy = new SqlLogSpy())
1182+
{
1183+
await (ObjectDumper.WriteAsync(q));
1184+
1185+
var sql = sqlSpy.GetWholeLog();
1186+
Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2));
1187+
}
1188+
}
11411189
}
11421190
}

src/NHibernate.Test/Linq/LinqQuerySamples.cs

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,7 +1509,13 @@ from c in db.Customers
15091509
join o in db.Orders on c.CustomerId equals o.Customer.CustomerId
15101510
select new { c.ContactName, o.OrderId };
15111511

1512-
ObjectDumper.Write(q);
1512+
using (var sqlSpy = new SqlLogSpy())
1513+
{
1514+
ObjectDumper.Write(q);
1515+
1516+
var sql = sqlSpy.GetWholeLog();
1517+
Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1));
1518+
}
15131519
}
15141520

15151521
[Category("JOIN")]
@@ -1549,7 +1555,9 @@ public void DLinqJoin5d(bool useCrossJoin)
15491555

15501556
var q =
15511557
from c in db.Customers
1552-
join o in db.Orders on new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null }
1558+
join o in db.Orders on
1559+
new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals
1560+
new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null }
15531561
select new { c.ContactName, o.OrderId };
15541562

15551563
using (var substitute = SubstituteDialect())
@@ -1585,6 +1593,27 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId
15851593
}
15861594
}
15871595

1596+
[Category("JOIN")]
1597+
[TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")]
1598+
public void DLinqJoin5f()
1599+
{
1600+
var q =
1601+
from o in db.Orders
1602+
join c in db.Customers on
1603+
new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } equals
1604+
new { c.CustomerId, HasContractTitle = c.ContactTitle != null }
1605+
select new { c.ContactName, o.OrderId };
1606+
1607+
using (var sqlSpy = new SqlLogSpy())
1608+
{
1609+
ObjectDumper.Write(q);
1610+
1611+
var sql = sqlSpy.GetWholeLog();
1612+
Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1));
1613+
Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1));
1614+
}
1615+
}
1616+
15881617
[Category("JOIN")]
15891618
[Test(Description = "This sample explictly joins three tables and projects results from each of them.")]
15901619
public void DLinqJoin6()
@@ -1706,6 +1735,25 @@ group o by c into x
17061735
ObjectDumper.Write(q);
17071736
}
17081737

1738+
[Category("JOIN")]
1739+
[Test(Description = "This sample shows how to join multiple tables.")]
1740+
public void DLinqJoin10a()
1741+
{
1742+
var q =
1743+
from e in db.Employees
1744+
join s in db.Employees on e.Superior.EmployeeId equals s.EmployeeId
1745+
join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId
1746+
select new { e.FirstName, SuperiorName = s.FirstName, Superior2Name = s2.FirstName };
1747+
1748+
using (var sqlSpy = new SqlLogSpy())
1749+
{
1750+
ObjectDumper.Write(q);
1751+
1752+
var sql = sqlSpy.GetWholeLog();
1753+
Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2));
1754+
}
1755+
}
1756+
17091757
[Category("WHERE")]
17101758
[Test(Description = "This sample uses WHERE to filter for orders with shipping date equals to null.")]
17111759
public void DLinq2B()

src/NHibernate/Linq/Clauses/NhJoinClause.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public NhJoinClause(string itemName, System.Type itemType, Expression fromExpres
5454

5555
public bool IsInner { get; private set; }
5656

57-
internal IBodyClause RelatedBodyClause { get; set; }
57+
internal JoinClause ParentJoinClause { get; set; }
5858

5959
public void TransformExpressions(Func<Expression, Expression> transformation)
6060
{

src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ public class AddJoinsReWriter : NhQueryModelVisitorBase, IIsEntityDecider
2020
private readonly ISessionFactoryImplementor _sessionFactory;
2121
private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector;
2222
private readonly WhereJoinDetector _whereJoinDetector;
23+
private int? _joinInsertIndex;
24+
private JoinClause _currentJoin;
2325

2426
private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel queryModel)
2527
{
2628
_sessionFactory = sessionFactory;
27-
var joiner = new Joiner(queryModel);
29+
var joiner = new Joiner(queryModel, AddJoin);
2830
_memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner);
2931
_whereJoinDetector = new WhereJoinDetector(this, joiner);
3032
}
@@ -62,20 +64,25 @@ public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel
6264

6365
public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index)
6466
{
65-
// When there are association navigations inside an on clause (e.g. c.ContactTitle equals o.Customer.ContactTitle),
67+
VisitJoinClause(joinClause, queryModel, joinClause);
68+
}
69+
70+
private void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, IBodyClause bodyClause)
71+
{
72+
joinClause.InnerSequence = _whereJoinDetector.Transform(joinClause.InnerSequence);
73+
74+
// When associations are located in the outer key (e.g. from a in A join b in B b on a.C.D.Id equals b.Id),
75+
// we have to insert the association join before the current join in order to produce a valid query.
76+
_joinInsertIndex = queryModel.BodyClauses.IndexOf(bodyClause);
77+
joinClause.OuterKeySelector = _whereJoinDetector.Transform(joinClause.OuterKeySelector);
78+
_joinInsertIndex = null;
79+
80+
// When associations are located in the inner key (e.g. from a in A join b in B b on a.Id equals b.C.D.Id),
6681
// we have to move the condition to the where statement, otherwise the query will be invalid.
6782
// Link newly created joins with the current join clause in order to later detect which join type to use.
68-
queryModel.BodyClauses.CollectionChanged += OnCollectionChange;
69-
_whereJoinDetector.Transform(joinClause);
70-
queryModel.BodyClauses.CollectionChanged -= OnCollectionChange;
71-
72-
void OnCollectionChange(object sender, NotifyCollectionChangedEventArgs e)
73-
{
74-
foreach (var nhJoinClause in e.NewItems.OfType<NhJoinClause>())
75-
{
76-
nhJoinClause.RelatedBodyClause = joinClause;
77-
}
78-
}
83+
_currentJoin = joinClause;
84+
joinClause.InnerKeySelector = _whereJoinDetector.Transform(joinClause.InnerKeySelector);
85+
_currentJoin = null;
7986
}
8087

8188
public bool IsEntity(System.Type type)
@@ -88,5 +95,19 @@ public bool IsIdentifier(System.Type type, string propertyName)
8895
var metadata = _sessionFactory.GetClassMetadata(type);
8996
return metadata != null && propertyName.Equals(metadata.IdentifierPropertyName);
9097
}
98+
99+
private void AddJoin(QueryModel queryModel, NhJoinClause joinClause)
100+
{
101+
joinClause.ParentJoinClause = _currentJoin;
102+
if (_joinInsertIndex.HasValue)
103+
{
104+
queryModel.BodyClauses.Insert(_joinInsertIndex.Value, joinClause);
105+
_joinInsertIndex++;
106+
}
107+
else
108+
{
109+
queryModel.BodyClauses.Add(joinClause);
110+
}
111+
}
91112
}
92113
}

src/NHibernate/Linq/Visitors/JoinBuilder.cs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,20 @@ public class Joiner : IJoiner
2121
private readonly NameGenerator _nameGenerator;
2222
private readonly QueryModel _queryModel;
2323

24+
internal Joiner(QueryModel queryModel, System.Action<QueryModel, NhJoinClause> addJoinMethod)
25+
: this(queryModel)
26+
{
27+
AddJoinMethod = addJoinMethod;
28+
}
29+
2430
internal Joiner(QueryModel queryModel)
2531
{
2632
_nameGenerator = new NameGenerator(queryModel);
2733
_queryModel = queryModel;
34+
AddJoinMethod = AddJoin;
2835
}
36+
37+
internal System.Action<QueryModel, NhJoinClause> AddJoinMethod { get; }
2938

3039
public IEnumerable<NhJoinClause> Joins
3140
{
@@ -39,7 +48,7 @@ public Expression AddJoin(Expression expression, string key)
3948
if (!_joins.TryGetValue(key, out join))
4049
{
4150
join = new NhJoinClause(_nameGenerator.GetNewName(), expression.Type, expression);
42-
_queryModel.BodyClauses.Add(join);
51+
AddJoinMethod(_queryModel, join);
4352
_joins.Add(key, join);
4453
}
4554

@@ -72,6 +81,11 @@ public bool CanAddJoin(Expression expression)
7281
return resultOperatorBase != null && _queryModel.ResultOperators.Contains(resultOperatorBase);
7382
}
7483

84+
private void AddJoin(QueryModel queryModel, NhJoinClause joinClause)
85+
{
86+
queryModel.BodyClauses.Add(joinClause);
87+
}
88+
7589
private class QuerySourceExtractor : RelinqExpressionVisitor
7690
{
7791
private IQuerySource _querySource;
@@ -90,4 +104,4 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr
90104
}
91105
}
92106
}
93-
}
107+
}

src/NHibernate/Linq/Visitors/QueryModelVisitor.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,10 +518,9 @@ public override void VisitJoinClause(JoinClause joinClause, QueryModel queryMode
518518
var alias = _hqlTree.TreeBuilder.Alias(VisitorParameters.QuerySourceNamer.GetName(joinClause));
519519
var joinExpression = HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters);
520520
HqlTreeNode join;
521-
// When there are association navigations inside an on clause:
522-
// from c in db.Customers join o in db.Orders on c.ContactTitle equals o.Customer.ContactTitle
523-
// we have to use a cross join instead of inner join and add the condition in the where statement.
524-
if (queryModel.BodyClauses.OfType<NhJoinClause>().Any(o => o.RelatedBodyClause == joinClause))
521+
// When associations are located inside the inner key selector we have to use a cross join instead of an inner
522+
// join and add the condition in the where statement.
523+
if (queryModel.BodyClauses.OfType<NhJoinClause>().Any(o => o.ParentJoinClause == joinClause))
525524
{
526525
_hqlTree.AddWhereClause(withClause);
527526
join = CreateCrossJoin(joinExpression, alias);

src/NHibernate/Linq/Visitors/WhereJoinDetector.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,21 @@ internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner)
7777
_joiner = joiner;
7878
}
7979

80+
public Expression Transform(Expression expression)
81+
{
82+
var result = Visit(expression);
83+
PostTransform();
84+
return result;
85+
}
86+
8087
public void Transform(IClause whereClause)
8188
{
8289
whereClause.TransformExpressions(Visit);
90+
PostTransform();
91+
}
8392

93+
private void PostTransform()
94+
{
8495
var values = _values.Pop();
8596

8697
foreach (var memberExpression in values.MemberExpressions)

0 commit comments

Comments
 (0)