Skip to content

Commit 0b2b83a

Browse files
committed
Add more left join support
1 parent d2599c6 commit 0b2b83a

File tree

4 files changed

+152
-6
lines changed

4 files changed

+152
-6
lines changed

src/NHibernate.Test/Linq/ByMethod/JoinTests.cs

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
using System;
2-
using System.Linq;
3-
using System.Reflection;
4-
using NHibernate.Cfg;
5-
using NHibernate.Engine.Query;
1+
using System.Linq;
62
using NHibernate.Linq;
73
using NHibernate.Util;
84
using NSubstitute;
@@ -103,6 +99,99 @@ public void LeftJoinExtensionMethodWithMultipleKeyProperties()
10399
}
104100
}
105101

102+
[Test]
103+
public void LeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly()
104+
{
105+
using (var sqlSpy = new SqlLogSpy())
106+
{
107+
var animals = db.Animals
108+
.LeftJoin(
109+
db.Mammals,
110+
x => x.Id,
111+
x => x.Id,
112+
(animal, mammal) => new { animal, mammal })
113+
.Where(x => x.mammal.SerialNumber.StartsWith("9"))
114+
.Select(x => new { SerialNumber = x.animal.SerialNumber })
115+
.ToList();
116+
117+
var sql = sqlSpy.GetWholeLog();
118+
Assert.That(animals.Count, Is.EqualTo(1));
119+
Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1));
120+
}
121+
}
122+
123+
[KnownBug("GH-XXXX")]
124+
public void NestedLeftJoinExtensionMethodWithOuterReferenceInWhereClauseOnly()
125+
{
126+
using (var sqlSpy = new SqlLogSpy())
127+
{
128+
var innerAnimals = db.Animals
129+
.LeftJoin(
130+
db.Mammals,
131+
x => x.Id,
132+
x => x.Id,
133+
(animal, mammal) => new { animal, mammal })
134+
.Where(x => x.mammal.SerialNumber.StartsWith("9"))
135+
.Select(x=>x.animal);
136+
137+
var animals = db.Animals
138+
.LeftJoin(
139+
innerAnimals,
140+
x => x.Id,
141+
x => x.Id,
142+
(animal, animal2) => new { animal, animal2 })
143+
.Select(x => new { SerialNumber = x.animal2.SerialNumber })
144+
.ToList();
145+
146+
147+
var sql = sqlSpy.GetWholeLog();
148+
Assert.That(animals.Count, Is.EqualTo(1));
149+
Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1));
150+
}
151+
}
152+
153+
[KnownBug("GH-XXXX")]
154+
public void LeftJoinExtensionMethodWithNoUseOfOuterReference()
155+
{
156+
using (var sqlSpy = new SqlLogSpy())
157+
{
158+
var animals = db.Animals
159+
.LeftJoin(
160+
db.Mammals,
161+
x => x.Id,
162+
x => x.Id,
163+
(animal, mammal) => new { animal, mammal })
164+
.Select(x => x.animal)
165+
.ToList();
166+
167+
var sql = sqlSpy.GetWholeLog();
168+
Assert.That(animals.Count, Is.EqualTo(1));
169+
Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(6));
170+
}
171+
}
172+
173+
[Test]
174+
public void LeftJoinExtensionMethodWithOuterReferenceInOrderByClauseOnly()
175+
{
176+
using (var sqlSpy = new SqlLogSpy())
177+
{
178+
var animals = db.Animals
179+
.LeftJoin(
180+
db.Mammals,
181+
x => x.Id,
182+
x => x.Id,
183+
(animal, mammal) => new { animal, mammal })
184+
.OrderBy(x => x.mammal.SerialNumber ?? "z")
185+
.Select(x => new { SerialNumber = x.animal.SerialNumber })
186+
.ToList();
187+
188+
var sql = sqlSpy.GetWholeLog();
189+
Assert.That(animals.Count, Is.EqualTo(6));
190+
Assert.That(animals[0].SerialNumber, Is.EqualTo("1121"));
191+
Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1));
192+
}
193+
}
194+
106195
[TestCase(false)]
107196
[TestCase(true)]
108197
public void CrossJoinWithPredicateInWhereStatement(bool useCrossJoin)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using System.Collections.Generic;
2+
using System.Linq.Expressions;
3+
using NHibernate.Linq.Visitors;
4+
using Remotion.Linq;
5+
using Remotion.Linq.Clauses;
6+
7+
namespace NHibernate.Linq.GroupJoin
8+
{
9+
internal class GroupJoinAggregateDetectionQueryModelVisitor : NhQueryModelVisitorBase
10+
{
11+
private HashSet<GroupJoinClause> _groupJoinClauses;
12+
private readonly List<Expression> _nonAggregatingExpressions = new List<Expression>();
13+
private readonly List<GroupJoinClause> _nonAggregatingGroupJoins = new List<GroupJoinClause>();
14+
private readonly List<GroupJoinClause> _aggregatingGroupJoins = new List<GroupJoinClause>();
15+
private GroupJoinAggregateDetectionQueryModelVisitor(IEnumerable<GroupJoinClause> groupJoinClauses)
16+
{
17+
_groupJoinClauses = new HashSet<GroupJoinClause>(groupJoinClauses);
18+
}
19+
20+
public static IsAggregatingResults Visit(IEnumerable<GroupJoinClause> groupJoinClause, QueryModel queryModel)
21+
{
22+
var visitor = new GroupJoinAggregateDetectionQueryModelVisitor(groupJoinClause);
23+
24+
visitor.VisitQueryModel(queryModel);
25+
26+
return new IsAggregatingResults { NonAggregatingClauses = visitor._nonAggregatingGroupJoins, AggregatingClauses = visitor._aggregatingGroupJoins, NonAggregatingExpressions = visitor._nonAggregatingExpressions };
27+
}
28+
public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index)
29+
{
30+
var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, whereClause.Predicate);
31+
AddResults(results);
32+
}
33+
34+
public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel)
35+
{
36+
var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, selectClause.Selector);
37+
AddResults(results);
38+
}
39+
40+
public override void VisitOrdering(Ordering ordering, QueryModel queryModel, OrderByClause orderByClause, int index)
41+
{
42+
var results = GroupJoinAggregateDetectionVisitor.Visit(_groupJoinClauses, ordering.Expression);
43+
AddResults(results);
44+
}
45+
46+
47+
private void AddResults(IsAggregatingResults results)
48+
{
49+
_nonAggregatingExpressions.AddRange(results.NonAggregatingExpressions);
50+
_nonAggregatingGroupJoins.AddRange(results.NonAggregatingClauses);
51+
_aggregatingGroupJoins.AddRange(results.AggregatingClauses);
52+
}
53+
}
54+
}

src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Collections.ObjectModel;
34
using System.Linq.Expressions;
45
using NHibernate.Linq.Expressions;
56
using NHibernate.Linq.Visitors;
7+
using Remotion.Linq;
68
using Remotion.Linq.Clauses;
79
using Remotion.Linq.Clauses.Expressions;
810

@@ -34,6 +36,7 @@ public static IsAggregatingResults Visit(IEnumerable<GroupJoinClause> groupJoinC
3436

3537
protected override Expression VisitSubQuery(SubQueryExpression expression)
3638
{
39+
//Visit the entire query model?
3740
Visit(expression.QueryModel.SelectClause.Selector);
3841
return expression;
3942
}

src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ private bool IsHierarchicalJoin(GroupJoinClause nonAggregatingJoin, QuerySourceU
159159
// TODO - rename this and share with the AggregatingGroupJoinRewriter
160160
private IsAggregatingResults GetGroupJoinInformation(IEnumerable<GroupJoinClause> clause)
161161
{
162-
return GroupJoinAggregateDetectionVisitor.Visit(clause, _model.SelectClause.Selector);
162+
return GroupJoinAggregateDetectionQueryModelVisitor.Visit(clause, _model);
163163
}
164164
}
165165

0 commit comments

Comments
 (0)