Skip to content

Commit 9dd77ca

Browse files
authored
Fix invalid join on subclass columns (#2680)
1 parent bad7e99 commit 9dd77ca

File tree

11 files changed

+135
-43
lines changed

11 files changed

+135
-43
lines changed

src/NHibernate.Test/Async/NHSpecificTest/NH1747/Fixture.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
//------------------------------------------------------------------------------
99

1010

11+
using System.Linq;
12+
using NHibernate.Linq;
1113
using NUnit.Framework;
1214

1315
namespace NHibernate.Test.NHSpecificTest.NH1747
@@ -51,5 +53,25 @@ public async Task TraversingBagToJoinChildElementShouldWorkAsync()
5153
Assert.AreEqual(1, paymentBatch.Payments.Count);
5254
}
5355
}
56+
57+
[Test]
58+
public async Task TraversingBagToJoinChildElementShouldWorkLinqFetchAsync()
59+
{
60+
using (ISession session = OpenSession())
61+
{
62+
var paymentBatch = await (session.Query<PaymentBatch>().Fetch(x => x.Payments).SingleOrDefaultAsync());
63+
Assert.AreEqual(1, paymentBatch.Payments.Count);
64+
}
65+
}
66+
67+
[Test]
68+
public async Task TraversingBagToJoinChildElementShouldWorkQueryOverFetchAsync()
69+
{
70+
using (ISession session = OpenSession())
71+
{
72+
var paymentBatch = await (session.Query<PaymentBatch>().Fetch(x => x.Payments).SingleOrDefaultAsync());
73+
Assert.AreEqual(1, paymentBatch.Payments.Count);
74+
}
75+
}
5476
}
5577
}

src/NHibernate.Test/Async/NHSpecificTest/NH2174/Fixture.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ protected override void OnTearDown()
4141
}
4242
}
4343

44-
[KnownBug("Not fixed yet")]
4544
[Test]
4645
public async Task LinqFetchAsync()
4746
{
@@ -54,7 +53,6 @@ public async Task LinqFetchAsync()
5453
}
5554
}
5655

57-
[KnownBug("Not fixed yet")]
5856
[Test]
5957
public async Task QueryOverFetchAsync()
6058
{

src/NHibernate.Test/NHSpecificTest/NH1747/Fixture.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using NUnit.Framework;
1+
using System.Linq;
2+
using NHibernate.Linq;
3+
using NUnit.Framework;
24

35
namespace NHibernate.Test.NHSpecificTest.NH1747
46
{
@@ -52,5 +54,25 @@ public void TraversingBagToJoinChildElementShouldWork()
5254
Assert.AreEqual(1, paymentBatch.Payments.Count);
5355
}
5456
}
57+
58+
[Test]
59+
public void TraversingBagToJoinChildElementShouldWorkLinqFetch()
60+
{
61+
using (ISession session = OpenSession())
62+
{
63+
var paymentBatch = session.Query<PaymentBatch>().Fetch(x => x.Payments).SingleOrDefault();
64+
Assert.AreEqual(1, paymentBatch.Payments.Count);
65+
}
66+
}
67+
68+
[Test]
69+
public void TraversingBagToJoinChildElementShouldWorkQueryOverFetch()
70+
{
71+
using (ISession session = OpenSession())
72+
{
73+
var paymentBatch = session.Query<PaymentBatch>().Fetch(x => x.Payments).SingleOrDefault();
74+
Assert.AreEqual(1, paymentBatch.Payments.Count);
75+
}
76+
}
5577
}
5678
}

src/NHibernate.Test/NHSpecificTest/NH2174/Fixture.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ protected override void OnTearDown()
3030
}
3131
}
3232

33-
[KnownBug("Not fixed yet")]
3433
[Test]
3534
public void LinqFetch()
3635
{
@@ -43,7 +42,6 @@ public void LinqFetch()
4342
}
4443
}
4544

46-
[KnownBug("Not fixed yet")]
4745
[Test]
4846
public void QueryOverFetch()
4947
{

src/NHibernate/Engine/IJoin.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ internal interface IJoin
88
{
99
IJoinable Joinable { get; }
1010
string[] LHSColumns { get; }
11+
string[] RHSColumns { get; }
1112
string Alias { get; }
1213
IAssociationType AssociationType { get; }
1314
JoinType JoinType { get; }

src/NHibernate/Engine/JoinHelper.cs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,20 @@ public static ILhsAssociationTypeSqlInfo GetIdLhsSqlInfo(string alias, IOuterJoi
2222
/// be used in the join
2323
/// </summary>
2424
public static string[] GetRHSColumnNames(IAssociationType type, ISessionFactoryImplementor factory)
25+
{
26+
return GetRHSColumnNames(type.GetAssociatedJoinable(factory), type);
27+
}
28+
29+
/// <summary>
30+
/// Get the columns of the associated table which are to
31+
/// be used in the join
32+
/// </summary>
33+
public static string[] GetRHSColumnNames(IJoinable joinable, IAssociationType type)
2534
{
2635
string uniqueKeyPropertyName = type.RHSUniqueKeyPropertyName;
27-
IJoinable joinable = type.GetAssociatedJoinable(factory);
28-
if (uniqueKeyPropertyName == null)
29-
{
30-
return joinable.KeyColumnNames;
31-
}
32-
else
33-
{
34-
return ((IOuterJoinLoadable)joinable).GetPropertyColumnNames(uniqueKeyPropertyName);
35-
}
36+
return uniqueKeyPropertyName == null
37+
? joinable.KeyColumnNames
38+
: ((IOuterJoinLoadable) joinable).GetPropertyColumnNames(uniqueKeyPropertyName);
3639
}
3740
}
3841

src/NHibernate/Engine/JoinSequence.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ private sealed class Join : IJoin
4747
private readonly JoinType joinType;
4848
private readonly string alias;
4949
private readonly string[] lhsColumns;
50+
private readonly string[] rhsColumns;
5051

5152
public Join(ISessionFactoryImplementor factory, IAssociationType associationType, string alias, JoinType joinType,
5253
string[] lhsColumns)
@@ -56,6 +57,9 @@ public Join(ISessionFactoryImplementor factory, IAssociationType associationType
5657
this.alias = alias;
5758
this.joinType = joinType;
5859
this.lhsColumns = lhsColumns;
60+
this.rhsColumns = lhsColumns.Length > 0
61+
? JoinHelper.GetRHSColumnNames(joinable, associationType)
62+
: Array.Empty<string>();
5963
}
6064

6165
public string Alias
@@ -83,6 +87,11 @@ public string[] LHSColumns
8387
get { return lhsColumns; }
8488
}
8589

90+
public string[] RHSColumns
91+
{
92+
get { return rhsColumns; }
93+
}
94+
8695
public override string ToString()
8796
{
8897
return joinable.ToString() + '[' + alias + ']';
@@ -195,7 +204,7 @@ internal JoinFragment ToJoinFragment(
195204
join.Joinable.TableName,
196205
join.Alias,
197206
join.LHSColumns,
198-
JoinHelper.GetRHSColumnNames(join.AssociationType, factory),
207+
join.RHSColumns,
199208
join.JoinType,
200209
withClauses[i]
201210
);

src/NHibernate/Engine/TableGroupJoinHelper.cs

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Linq;
4+
using NHibernate.Persister.Collection;
45
using NHibernate.Persister.Entity;
56
using NHibernate.SqlCommand;
67

@@ -38,7 +39,7 @@ internal static bool ProcessAsTableGroupJoin(IReadOnlyList<IJoin> tableGroupJoin
3839
join.Joinable.TableName,
3940
join.Alias,
4041
join.LHSColumns,
41-
JoinHelper.GetRHSColumnNames(join.AssociationType, sessionFactoryImplementor),
42+
join.RHSColumns,
4243
join.JoinType,
4344
SqlString.Empty);
4445

@@ -51,32 +52,37 @@ internal static bool ProcessAsTableGroupJoin(IReadOnlyList<IJoin> tableGroupJoin
5152
join.Joinable.WhereJoinFragment(join.Alias, innerJoin, include));
5253
}
5354

54-
var withClause = GetTableGroupJoinWithClause(withClauseFragments, first, sessionFactoryImplementor);
55+
var withClause = GetTableGroupJoinWithClause(withClauseFragments, first);
5556
joinFragment.AddFromFragmentString(withClause);
5657
return true;
5758
}
5859

60+
// detect cases when withClause is used on multiple tables or when join keys depend on subclass columns
5961
private static bool NeedsTableGroupJoin(IReadOnlyList<IJoin> joins, SqlString[] withClauseFragments, bool includeSubclasses)
6062
{
61-
// If we don't have a with clause, we don't need a table group join
62-
if (withClauseFragments.All(x => SqlStringHelper.IsEmpty(x)))
63-
{
64-
return false;
65-
}
63+
bool hasWithClause = withClauseFragments.Any(x => SqlStringHelper.IsNotEmpty(x));
6664

67-
// If we only have one join, a table group join is only necessary if subclass columns are used in the with clause
68-
if (joins.Count == 1)
65+
//NH Specific: No alias processing (see hibernate JoinSequence.NeedsTableGroupJoin)
66+
if (joins.Count > 1 && hasWithClause)
67+
return true;
68+
69+
foreach (var join in joins)
6970
{
70-
return joins[0].Joinable is AbstractEntityPersister persister && persister.HasSubclassJoins(includeSubclasses);
71-
//NH Specific: No alias processing
72-
//return isSubclassAliasDereferenced( joins[ 0], withClauseFragment );
71+
var entityPersister = GetEntityPersister(join.Joinable);
72+
if (entityPersister?.HasSubclassJoins(includeSubclasses) != true)
73+
continue;
74+
75+
if (hasWithClause)
76+
return true;
77+
78+
if (entityPersister.ColumnsDependOnSubclassJoins(join.RHSColumns))
79+
return true;
7380
}
7481

75-
//NH Specific: No alias processing (see hibernate JoinSequence.NeedsTableGroupJoin)
76-
return true;
82+
return false;
7783
}
7884

79-
private static SqlString GetTableGroupJoinWithClause(SqlString[] withClauseFragments, IJoin first, ISessionFactoryImplementor factory)
85+
private static SqlString GetTableGroupJoinWithClause(SqlString[] withClauseFragments, IJoin first)
8086
{
8187
SqlStringBuilder fromFragment = new SqlStringBuilder();
8288
fromFragment.Add(")").Add(" on ");
@@ -85,12 +91,18 @@ private static SqlString GetTableGroupJoinWithClause(SqlString[] withClauseFragm
8591
var isAssociationJoin = lhsColumns.Length > 0;
8692
if (isAssociationJoin)
8793
{
94+
var entityPersister = GetEntityPersister(first.Joinable);
8895
string rhsAlias = first.Alias;
89-
string[] rhsColumns = JoinHelper.GetRHSColumnNames(first.AssociationType, factory);
90-
fromFragment.Add(lhsColumns[0]).Add("=").Add(rhsAlias).Add(".").Add(rhsColumns[0]);
91-
for (int j = 1; j < lhsColumns.Length; j++)
96+
string[] rhsColumns = first.RHSColumns;
97+
for (int j = 0; j < lhsColumns.Length; j++)
9298
{
93-
fromFragment.Add(" and ").Add(lhsColumns[j]).Add("=").Add(rhsAlias).Add(".").Add(rhsColumns[j]);
99+
fromFragment.Add(lhsColumns[j])
100+
.Add("=")
101+
.Add(entityPersister?.GenerateTableAliasForColumn(rhsAlias, rhsColumns[j]) ?? rhsAlias)
102+
.Add(".")
103+
.Add(rhsColumns[j]);
104+
if (j != lhsColumns.Length - 1)
105+
fromFragment.Add(" and ");
94106
}
95107
}
96108

@@ -99,6 +111,15 @@ private static SqlString GetTableGroupJoinWithClause(SqlString[] withClauseFragm
99111
return fromFragment.ToSqlString();
100112
}
101113

114+
private static AbstractEntityPersister GetEntityPersister(IJoinable joinable)
115+
{
116+
if (!joinable.IsCollection)
117+
return joinable as AbstractEntityPersister;
118+
119+
var collection = (IQueryableCollection) joinable;
120+
return collection.ElementType.IsEntityType ? collection.ElementPersister as AbstractEntityPersister : null;
121+
}
122+
102123
private static void AppendWithClause(SqlStringBuilder fromFragment, bool hasConditions, SqlString[] withClauseFragments)
103124
{
104125
for (var i = 0; i < withClauseFragments.Length; i++)

src/NHibernate/Loader/JoinWalker.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -751,8 +751,9 @@ protected virtual bool IsDuplicateAssociation(string lhsTable, string[] lhsColum
751751
}
752752
else
753753
{
754-
foreignKeyTable = type.GetAssociatedJoinable(Factory).TableName;
755-
foreignKeyColumns = JoinHelper.GetRHSColumnNames(type, Factory);
754+
var joinable = type.GetAssociatedJoinable(Factory);
755+
foreignKeyTable = joinable.TableName;
756+
foreignKeyColumns = JoinHelper.GetRHSColumnNames(joinable, type);
756757
}
757758

758759
return IsDuplicateAssociation(foreignKeyTable, foreignKeyColumns);

src/NHibernate/Loader/OuterJoinableAssociation.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public OuterJoinableAssociation(IAssociationType joinableType, String lhsAlias,
4646
this.rhsAlias = rhsAlias;
4747
this.joinType = joinType;
4848
joinable = joinableType.GetAssociatedJoinable(factory);
49-
rhsColumns = JoinHelper.GetRHSColumnNames(joinableType, factory);
49+
rhsColumns = JoinHelper.GetRHSColumnNames(joinable, joinableType);
5050
on = new SqlString(joinableType.GetOnCondition(rhsAlias, factory, enabledFilters));
5151
if (SqlStringHelper.IsNotEmpty(withClause))
5252
on = on.Append(" and ( ", withClause, " )");
@@ -105,6 +105,7 @@ public SelectMode SelectMode
105105
string[] IJoin.LHSColumns => lhsColumns;
106106
string IJoin.Alias => RHSAlias;
107107
IAssociationType IJoin.AssociationType => JoinableType;
108+
string[] IJoin.RHSColumns => rhsColumns;
108109

109110
public int GetOwner(IList<OuterJoinableAssociation> associations)
110111
{

src/NHibernate/Persister/Entity/AbstractEntityPersister.cs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,14 +2138,20 @@ public virtual Declarer GetSubclassPropertyDeclarer(string propertyPath)
21382138

21392139
public virtual string GenerateTableAliasForColumn(string rootAlias, string column)
21402140
{
2141-
int propertyIndex = Array.IndexOf(SubclassColumnClosure, column);
2141+
return GenerateTableAlias(rootAlias, GetColumnTableNumber(column));
2142+
}
2143+
2144+
private int GetColumnTableNumber(string column)
2145+
{
2146+
if (SubclassTableSpan == 1)
2147+
return 0;
2148+
2149+
int i = Array.IndexOf(SubclassColumnClosure, column);
21422150

21432151
// The check for KeyColumnNames was added to fix NH-2491
2144-
if (propertyIndex < 0 || Array.IndexOf(KeyColumnNames, column) >= 0)
2145-
{
2146-
return rootAlias;
2147-
}
2148-
return GenerateTableAlias(rootAlias, SubclassColumnTableNumberClosure[propertyIndex]);
2152+
return i < 0 || Array.IndexOf(KeyColumnNames, column) >= 0
2153+
? 0
2154+
: SubclassColumnTableNumberClosure[i];
21492155
}
21502156

21512157
public string GenerateTableAlias(string rootAlias, int tableNumber)
@@ -3796,6 +3802,16 @@ private JoinFragment CreateJoin(string name, bool innerjoin, bool includeSubclas
37963802
return join;
37973803
}
37983804

3805+
internal bool ColumnsDependOnSubclassJoins(string[] columns)
3806+
{
3807+
foreach (var column in columns)
3808+
{
3809+
if (GetColumnTableNumber(column) > 0)
3810+
return true;
3811+
}
3812+
return false;
3813+
}
3814+
37993815
internal bool HasSubclassJoins(bool includeSubclasses)
38003816
{
38013817
if (SubclassTableSpan == 1)

0 commit comments

Comments
 (0)