diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs index 987ed002f60..9fda2e873c9 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs @@ -148,5 +148,13 @@ public async Task CanJoinOnEntityWithSubclassesAsync() from o2 in db.Animals.Where(x => x.BodyWeight > 50) select new {o, o2}).Take(1).ToListAsync()); } + + [Test(Description = "GH-2580")] + public async Task CanInnerJoinOnSubclassWithBaseTableReferenceInOnClauseAsync() + { + var result = await ((from o in db.Animals + join o2 in db.Mammals on o.BodyWeight equals o2.BodyWeight + select new { o, o2 }).Take(1).ToListAsync()); + } } } diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index 8dce4d39223..a7a4750b87a 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -137,5 +137,13 @@ public void CanJoinOnEntityWithSubclasses() from o2 in db.Animals.Where(x => x.BodyWeight > 50) select new {o, o2}).Take(1).ToList(); } + + [Test(Description = "GH-2580")] + public void CanInnerJoinOnSubclassWithBaseTableReferenceInOnClause() + { + var result = (from o in db.Animals + join o2 in db.Mammals on o.BodyWeight equals o2.BodyWeight + select new { o, o2 }).Take(1).ToList(); + } } } diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 040e9b38932..62abaae77cb 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using NHibernate.Engine; using NHibernate.Hql.Ast; using NHibernate.Linq.Clauses; using NHibernate.Linq.Expressions; @@ -12,6 +13,7 @@ using NHibernate.Linq.ResultOperators; using NHibernate.Linq.ReWriters; using NHibernate.Linq.Visitors.ResultOperatorProcessors; +using NHibernate.Persister.Entity; using NHibernate.Util; using Remotion.Linq; using Remotion.Linq.Clauses; @@ -527,10 +529,13 @@ private void AddJoin(JoinClause joinClause, QueryModel queryModel, bool innerJoi var withClause = equalityVisitor.Visit(joinClause.InnerKeySelector, joinClause.OuterKeySelector); var alias = _hqlTree.TreeBuilder.Alias(VisitorParameters.QuerySourceNamer.GetName(joinClause)); var joinExpression = HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters); + var baseMemberCheker = new BaseMemberChecker(VisitorParameters.SessionFactory); + HqlTreeNode join; - // When associations are located inside the inner key selector we have to use a cross join instead of an inner - // join and add the condition in the where statement. - if (queryModel.BodyClauses.OfType().Any(o => o.ParentJoinClause == joinClause)) + // When associations or members from another table are located inside the inner key selector we have to use a cross join + // instead of an inner join and add the condition in the where statement. + if (queryModel.BodyClauses.OfType().Any(o => o.ParentJoinClause == joinClause) || + queryModel.BodyClauses.OfType().Any(baseMemberCheker.ContainsBaseMember)) { if (!innerJoin) { @@ -551,6 +556,54 @@ private void AddJoin(JoinClause joinClause, QueryModel queryModel, bool innerJoi _hqlTree.AddFromClause(join); } + private class BaseMemberChecker : NhExpressionVisitor + { + private readonly ISessionFactoryImplementor _sessionFactory; + private bool _result; + + public BaseMemberChecker(ISessionFactoryImplementor sessionFactory) + { + _sessionFactory = sessionFactory; + } + + public bool ContainsBaseMember(JoinClause joinClause) + { + // Visit the join inner key only for entities that have subclasses + if (joinClause.InnerSequence is ConstantExpression constantNode && + constantNode.Value is IEntityNameProvider entityNameProvider && + !_sessionFactory.GetEntityPersister(entityNameProvider.EntityName).EntityMetamodel.HasSubclasses) + { + return false; + } + + _result = false; + Visit(joinClause.InnerKeySelector); + + return _result; + } + + protected override Expression VisitMember(MemberExpression node) + { + if (ExpressionsHelper.TryGetMappedType( + _sessionFactory, + node, + out _, + out var persister, + out _, + out var propertyPath) && + persister is IOuterJoinLoadable joinLoadable && + joinLoadable.EntityMetamodel.GetIdentifierPropertyType(propertyPath) == null && + joinLoadable.GetPropertyTableName(propertyPath) != joinLoadable.TableName + ) + { + _result = true; + return node; + } + + return base.VisitMember(node); + } + } + public override void VisitGroupJoinClause(GroupJoinClause groupJoinClause, QueryModel queryModel, int index) { throw new NotImplementedException();