diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/AnotherEntity.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/AnotherEntity.cs new file mode 100644 index 00000000000..c7249716126 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/AnotherEntity.cs @@ -0,0 +1,17 @@ +using System.Collections.Generic; +using NHibernate.Test.NHSpecificTest.NH3845.Interfaces; + +namespace NHibernate.Test.NHSpecificTest.NH3845.Concrete +{ + public class AnotherEntity : IAnotherEntity + { + private ISet _childEntities = new HashSet(); + public virtual int AnotherEntityId { get; set; } + public virtual string Text { get; set; } + public virtual ISet ChildEntities + { + get { return _childEntities; } + protected set { _childEntities = value ?? new HashSet(); } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/MainEntity.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/MainEntity.cs new file mode 100644 index 00000000000..01cb7a15d24 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/MainEntity.cs @@ -0,0 +1,26 @@ +using System.Collections.Generic; +using NHibernate.Test.NHSpecificTest.NH3845.Interfaces; + +namespace NHibernate.Test.NHSpecificTest.NH3845.Concrete +{ + public class MainEntity : IMainEntity + { + private ISet _properties = new HashSet(); + private ISet _separateEntities = new HashSet(); + + public virtual int MainEntityId { get; set; } + public virtual string Text { get; set; } + + public virtual ISet Properties + { + get { return _properties; } + protected set { _properties = value ?? new HashSet(); } + } + + public virtual ISet SeparateEntities + { + get { return _separateEntities; } + protected set { _separateEntities = value ?? new HashSet(); } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityA.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityA.cs new file mode 100644 index 00000000000..ba4160d6ef5 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityA.cs @@ -0,0 +1,9 @@ +using NHibernate.Test.NHSpecificTest.NH3845.Interfaces; + +namespace NHibernate.Test.NHSpecificTest.NH3845.Concrete +{ + public class PropertyEntityA : PropertyEntityBase, IPropertyEntityA + { + public virtual int SerialNumber { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityB.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityB.cs new file mode 100644 index 00000000000..9fa2b2a85aa --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityB.cs @@ -0,0 +1,10 @@ +using NHibernate.Test.NHSpecificTest.NH3845.Interfaces; + +namespace NHibernate.Test.NHSpecificTest.NH3845.Concrete +{ + public class PropertyEntityB : PropertyEntityBase, IPropertyEntityB + { + public virtual string Description { get; set; } + public virtual string AnotherString { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityBase.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityBase.cs new file mode 100644 index 00000000000..889336c05ea --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityBase.cs @@ -0,0 +1,12 @@ +using NHibernate.Test.NHSpecificTest.NH3845.Interfaces; + +namespace NHibernate.Test.NHSpecificTest.NH3845.Concrete +{ + public abstract class PropertyEntityBase : IPropertyEntityBase + { + public virtual int PropertyEntityBaseId { get; set; } + public virtual string Name { get; set; } + public virtual string SharedValue { get; set; } + public virtual IMainEntity MainEntity { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityC.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityC.cs new file mode 100644 index 00000000000..b8fec384357 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/PropertyEntityC.cs @@ -0,0 +1,11 @@ +using NHibernate.Test.NHSpecificTest.NH3845.Interfaces; + +namespace NHibernate.Test.NHSpecificTest.NH3845.Concrete +{ + public class PropertyEntityC : PropertyEntityBase, IPropertyEntityC + { + public virtual string Description { get; set; } + public virtual int AnotherNumber { get; set; } + public virtual IAnotherEntity AnotherEntity { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/SeparateEntity.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/SeparateEntity.cs new file mode 100644 index 00000000000..38606c4cee7 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Concrete/SeparateEntity.cs @@ -0,0 +1,11 @@ +using NHibernate.Test.NHSpecificTest.NH3845.Interfaces; + +namespace NHibernate.Test.NHSpecificTest.NH3845.Concrete +{ + public class SeparateEntity : ISeparateEntity + { + public virtual int SeparateEntityId { get; set; } + public virtual IMainEntity MainEntity { get; set; } + public virtual int SeparateEntityValue { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Fixture.cs new file mode 100644 index 00000000000..c56e4f08541 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Fixture.cs @@ -0,0 +1,226 @@ +using System.Linq; +using NHibernate.Linq; +using NHibernate.Test.NHSpecificTest.NH3845.Concrete; +using NHibernate.Test.NHSpecificTest.NH3845.Interfaces; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.NH3845 +{ + [TestFixture] + public class Fixture : BugTestCase + { + protected override void OnSetUp() + { + using (ISession session = OpenSession()) + using (ITransaction transaction = session.BeginTransaction()) + { + var entityA = new PropertyEntityA() + { + Name = "Name A", + SerialNumber = 4321, + SharedValue = "Some Value" + }; + var entityB = new PropertyEntityB() + { + Name = "Name B", + Description = "Some Description", + SharedValue = "Another Value", + AnotherString = "Another String" + }; + var entityC = new PropertyEntityC() + { + Name = "Name C", + Description = "Has Description", + SharedValue = "Value", + AnotherNumber = 42 + }; + + var separateEntity = new SeparateEntity() + { + SeparateEntityValue = 5432 + }; + + var mainEntity = new MainEntity() + { + Text = "Main Entity Text" + }; + var secondMainEntity = new MainEntity() + { + Text = "Second Entity Text" + }; + + var anotherEntity = new AnotherEntity() + { + Text = "Another Entity Text" + }; + session.Save(mainEntity); + session.Save(secondMainEntity); + session.Save(anotherEntity); + entityA.MainEntity = mainEntity; + entityB.MainEntity = mainEntity; + entityC.MainEntity = secondMainEntity; + entityC.AnotherEntity = anotherEntity; + separateEntity.MainEntity = secondMainEntity; + session.Save(entityA); + session.Save(entityB); + session.Save(entityC); + session.Save(separateEntity); + + session.Flush(); + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (ISession session = OpenSession()) + using (ITransaction transaction = session.BeginTransaction()) + { + session.Delete("from System.Object"); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public void OfTypeWorksWithSingleEntityInterface() + { + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + var entityQuery = session.Query(); + var result = + entityQuery.Where(m => m.Properties.OfType().Any()).ToList(); + Assert.AreEqual(1, result.Count); + } + } + + [Test] + public void OfTypeWorksWithUnrelatedInterface() + { + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + var entityQuery = session.Query(); + var result = + entityQuery.Where(m => m.Properties.OfType().Any()).ToList(); + Assert.AreEqual(2, result.Count); + } + } + + [Test] + public void CanQueryOfTypePropertyWithUnrelatedInterface() + { + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + var entityQuery = session.Query(); + var result = + entityQuery.Where(m => m.Properties.OfType().Any(d => d.Description == "Has Description")) + .ToList(); + Assert.AreEqual(1, result.Count); + } + } + + [Test] + public void ImpossibleOfTypeReturnsNoResults() + { + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + var entityQuery = session.Query(); + var result = + entityQuery.Where(m => m.Properties.OfType().Any()).ToList(); + Assert.IsEmpty(result); + } + } + + [Test] + public void ImpossibleMappedOfTypeReturnsNoResults() + { + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + var entityQuery = session.Query(); + var result = + entityQuery.Where(m => m.Properties.OfType().Any(se => se.SeparateEntityValue == 5432)).ToList(); + Assert.IsEmpty(result); + } + } + + [Test] + public void OfTypeAppliedToNonSubclassEntityStillWorks() + { + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + var entityQuery = session.Query(); + var result = entityQuery.Where(m => m.SeparateEntities.OfType().Any()).ToList(); + Assert.AreEqual(1, result.Count); + } + } + + [Test] + public void SourceTypeOfNonPolymorphicEntityPropertyIsCorrect() + { + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + var entityQuery = session.Query(); + var result = + entityQuery.Where( + m => + m.Properties.OfType().Select(c => c.AnotherEntity).Any(ae => ae.Text == "Another Entity Text")) + .ToList(); + Assert.AreEqual(1, result.Count); + } + } + + [Test] + public void EnsureParameterValuesAreNotCached() + { + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + var firstEntityQuery = session.Query(); + var name = "Name B"; + var firstResult = + firstEntityQuery.Where(m => m.Properties.OfType().Any(p => p.Name == name)).ToList(); + Assert.AreEqual(1, firstResult.Count); + Assert.AreEqual("Main Entity Text", firstResult.First().Text); + var secondEntityQuery = session.Query(); + name = "Name C"; + var secondResult = + secondEntityQuery.Where(m => m.Properties.OfType().Any(p => p.Name == name)).ToList(); + Assert.AreEqual(1, secondResult.Count); + Assert.AreEqual("Second Entity Text", secondResult.First().Text); + } + } + + [Test] + public void EnsureParameterValuesInSeparateSessionsAreNotCached() + { + var name = "Name B"; + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + var firstEntityQuery = session.Query(); + var firstResult = + firstEntityQuery.Where(m => m.Properties.OfType().Any(p => p.Name == name)).ToList(); + Assert.AreEqual(1, firstResult.Count); + Assert.AreEqual("Main Entity Text", firstResult.First().Text); + } + name = "Name C"; + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + var secondEntityQuery = session.Query(); + var secondResult = + secondEntityQuery.Where(m => m.Properties.OfType().Any(p => p.Name == name)).ToList(); + Assert.AreEqual(1, secondResult.Count); + Assert.AreEqual("Second Entity Text", secondResult.First().Text); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IAnotherEntity.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IAnotherEntity.cs new file mode 100644 index 00000000000..10b0e444148 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IAnotherEntity.cs @@ -0,0 +1,11 @@ +using System.Collections.Generic; + +namespace NHibernate.Test.NHSpecificTest.NH3845.Interfaces +{ + public interface IAnotherEntity + { + int AnotherEntityId { get; set; } + string Text { get; set; } + ISet ChildEntities { get; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IHasDescription.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IHasDescription.cs new file mode 100644 index 00000000000..ff0b5f77562 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IHasDescription.cs @@ -0,0 +1,7 @@ +namespace NHibernate.Test.NHSpecificTest.NH3845.Interfaces +{ + public interface IHasDescription + { + string Description { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IMainEntity.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IMainEntity.cs new file mode 100644 index 00000000000..b1db9239d02 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IMainEntity.cs @@ -0,0 +1,12 @@ +using System.Collections.Generic; + +namespace NHibernate.Test.NHSpecificTest.NH3845.Interfaces +{ + public interface IMainEntity + { + int MainEntityId { get; set; } + string Text { get; set; } + ISet Properties { get; } + ISet SeparateEntities { get; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityA.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityA.cs new file mode 100644 index 00000000000..4a0356aa81f --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityA.cs @@ -0,0 +1,7 @@ +namespace NHibernate.Test.NHSpecificTest.NH3845.Interfaces +{ + public interface IPropertyEntityA : IPropertyEntityBase + { + int SerialNumber { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityB.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityB.cs new file mode 100644 index 00000000000..8f2df25a328 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityB.cs @@ -0,0 +1,7 @@ +namespace NHibernate.Test.NHSpecificTest.NH3845.Interfaces +{ + public interface IPropertyEntityB : IPropertyEntityBase, IHasDescription + { + string AnotherString { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityBase.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityBase.cs new file mode 100644 index 00000000000..5bef01561e5 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityBase.cs @@ -0,0 +1,10 @@ +namespace NHibernate.Test.NHSpecificTest.NH3845.Interfaces +{ + public interface IPropertyEntityBase + { + int PropertyEntityBaseId { get; set; } + string Name { get; set; } + string SharedValue { get; set; } + IMainEntity MainEntity { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityC.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityC.cs new file mode 100644 index 00000000000..a5aefdf5eab --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/IPropertyEntityC.cs @@ -0,0 +1,8 @@ +namespace NHibernate.Test.NHSpecificTest.NH3845.Interfaces +{ + public interface IPropertyEntityC : IPropertyEntityBase, IHasDescription + { + int AnotherNumber { get; set; } + IAnotherEntity AnotherEntity { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/ISeparateEntity.cs b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/ISeparateEntity.cs new file mode 100644 index 00000000000..13c2347a1c5 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Interfaces/ISeparateEntity.cs @@ -0,0 +1,9 @@ +namespace NHibernate.Test.NHSpecificTest.NH3845.Interfaces +{ + public interface ISeparateEntity + { + int SeparateEntityId { get; set; } + IMainEntity MainEntity { get; set; } + int SeparateEntityValue { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3845/Mappings.hbm.xml b/src/NHibernate.Test/NHSpecificTest/NH3845/Mappings.hbm.xml new file mode 100644 index 00000000000..c844946dab2 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3845/Mappings.hbm.xml @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index c2cc110206b..a83ae6b228c 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -720,6 +720,14 @@ + + + + + + + + @@ -886,6 +894,7 @@ + @@ -1280,6 +1289,13 @@ + + + + + + + @@ -3160,6 +3176,7 @@ + diff --git a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs index 2732866645f..b163f0868a0 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs @@ -416,7 +416,7 @@ public HqlFalse False() return new HqlFalse(_factory); } - public HqlIn In(HqlExpression itemExpression, HqlTreeNode source) + public HqlIn In(HqlExpression itemExpression, params HqlTreeNode[] source) { return new HqlIn(_factory, itemExpression, source); } diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 7496c382fcb..d4db6592865 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -885,7 +885,7 @@ public HqlStar(IASTFactory factory) : base(HqlSqlWalker.ROW_STAR, "*", factory) public class HqlIn : HqlBooleanExpression { - public HqlIn(IASTFactory factory, HqlExpression itemExpression, HqlTreeNode source) + public HqlIn(IASTFactory factory, HqlExpression itemExpression, params HqlTreeNode[] source) : base(HqlSqlWalker.IN, "in", factory, itemExpression) { AddChild(new HqlInList(factory, source)); @@ -894,7 +894,7 @@ public HqlIn(IASTFactory factory, HqlExpression itemExpression, HqlTreeNode sour public class HqlInList : HqlTreeNode { - public HqlInList(IASTFactory factory, HqlTreeNode source) + public HqlInList(IASTFactory factory, params HqlTreeNode[] source) : base(HqlSqlWalker.IN_LIST, "inlist", factory, source) { } diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs index 7c9613df7a8..44aeefdbfd2 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs @@ -1,5 +1,8 @@ -using System.Linq.Expressions; +using System; +using System.Collections.Generic; +using System.Linq; using NHibernate.Hql.Ast; +using NHibernate.Persister.Entity; using Remotion.Linq.Clauses.ResultOperators; namespace NHibernate.Linq.Visitors.ResultOperatorProcessors @@ -8,15 +11,112 @@ public class ProcessOfType : IResultOperatorProcessor { #region IResultOperatorProcessor Members - public void Process(OfTypeResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) + public void Process( + OfTypeResultOperator resultOperator, + QueryModelVisitor queryModelVisitor, + IntermediateHqlTree tree) { - Expression source = queryModelVisitor.Model.SelectClause.GetOutputDataInfo().ItemExpression; + var fromItemEnumerableType = queryModelVisitor.Model.MainFromClause.FromExpression.Type; + var fromItemType = typeof(object); + var asEnumerable = + fromItemEnumerableType.GetInterfaces() + .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)); + if (asEnumerable != null) + { + fromItemType = asEnumerable.GetGenericArguments()[0]; + } - tree.AddWhereClause(tree.TreeBuilder.Equality( + var fromItemImplementorQueue = + new Queue(queryModelVisitor.VisitorParameters.SessionFactory.GetImplementors(fromItemType.FullName)); + var fromItemTypeNamesProcessed = new HashSet(); + var fromItemTypePersisters = new List(); + while (fromItemImplementorQueue.Any()) + { + var currentImplementor = fromItemImplementorQueue.Dequeue(); + if (fromItemTypeNamesProcessed.Add(currentImplementor)) + { + var persister = queryModelVisitor.VisitorParameters.SessionFactory.TryGetEntityPersister(currentImplementor); + if (persister != null) + { + fromItemTypePersisters.Add(persister); + persister.EntityMetamodel.SubclassEntityNames.ToList().ForEach(fromItemImplementorQueue.Enqueue); + } + } + } + + // Which of the mapped types are assignable to both the source -- meaning that the property could actually + // be of the mapped type -- and to the searched item type? + var persistersToUse = + fromItemTypePersisters.Where( + p => + { + var mappedClass = p.GetMappedClass(EntityMode.Poco); + return fromItemType.IsAssignableFrom(mappedClass) && resultOperator.SearchedItemType.IsAssignableFrom(mappedClass); + }).ToList(); + + // If the persisters are not among the subclass persister types, there will be no class property, so the query would fail + // if we tried to include the class literal in the query anyway. Also, if there are no applicable persisters, no results + // can be returned, so add "WHERE 1 = 0" in those cases. + if (persistersToUse.Count == 0) + { + tree.AddWhereClause(tree.TreeBuilder.Equality(tree.TreeBuilder.Constant(1), tree.TreeBuilder.Constant(0))); + // Because the rest of the query may be invalid (e.g., by referencing properties that do not exist), + // delete any remaining body clauses. + queryModelVisitor.Model.BodyClauses.Clear(); + return; + } + if (!persistersToUse.Any(p => p.EntityMetamodel.HasSubclasses || p.EntityMetamodel.SuperclassType != null)) + { + // All results should be returned, so no point adding a where clause + return; + } + + var typesToUse = + fromItemTypePersisters.Select(p => p.GetMappedClass(EntityMode.Poco)) + .Where(t => fromItemType.IsAssignableFrom(t) && resultOperator.SearchedItemType.IsAssignableFrom(t)).ToList(); + var classesToUse = typesToUse.Select(t => t.FullName).ToList(); + + // It appears that ReLinq's QuerySourceLocator.FindQuerySource(QueryModel, Type) method may find the source + // only sometimes when specifying fromItemType and only sometimes when specifying the return type of the + // select clause. There may be other scenarios not yet considered here. For now, try both, in order, + // plus the actual output types. Give up if we can't find the correct query source. + var querySourceLocatorCandidateTypeList = new List() + { + fromItemType, + queryModelVisitor.Model.SelectClause.GetOutputDataInfo().ResultItemType + }; + querySourceLocatorCandidateTypeList.AddRange(typesToUse); + var querySource = + querySourceLocatorCandidateTypeList.Select(t => QuerySourceLocator.FindQuerySource(queryModelVisitor.Model, t)) + .FirstOrDefault(qs => qs != null); + + if (querySource == null) + { + throw new QueryException( + String.Format( + "Unable to find QuerySource for any of these types: {0}", + String.Join(", ", querySourceLocatorCandidateTypeList.Select(t => t.FullName)))); + } + + var querySourceName = queryModelVisitor.VisitorParameters.QuerySourceNamer.GetName(querySource); + + var dotNode = tree.TreeBuilder.Dot( - HqlGeneratorExpressionTreeVisitor.Visit(source, queryModelVisitor.VisitorParameters).AsExpression(), - tree.TreeBuilder.Class()), - tree.TreeBuilder.Ident(resultOperator.SearchedItemType.FullName))); + tree.TreeBuilder.Ident(querySourceName), + tree.TreeBuilder.Class()); + + // For now, use the name of the persisted class as a literal identifier. The string value containing + // the full class name is translated to a discriminator column value in + // NHibernate.Hql.Ast.ANTLR.Util.LiteralProcessor.ProcessConstant(SqlNode, bool). + if (classesToUse.Count == 1) + { + tree.AddWhereClause(tree.TreeBuilder.Equality(dotNode, tree.TreeBuilder.Ident(classesToUse[0]))); + } + else + { + var implementorNodes = classesToUse.Select(tree.TreeBuilder.Ident).OfType().ToArray(); + tree.AddWhereClause(tree.TreeBuilder.In(dotNode, implementorNodes)); + } } #endregion