diff --git a/src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs b/src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs new file mode 100644 index 00000000000..9b6492d4cb0 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs @@ -0,0 +1,33 @@ +using System.Collections.Generic; + +namespace NHibernate.DomainModel.Northwind.Entities +{ + public class AnotherEntityRequired + { + public virtual int Id { get; set; } + + public virtual string Output { get; set; } + + public virtual string Input { get; set; } + + public virtual Address Address { get; set; } + + public virtual AnotherEntityNullability InputNullability { get; set; } + + public virtual string NullableOutput { get; set; } + + public virtual AnotherEntityRequired NullableAnotherEntityRequired { get; set; } + + public virtual int? NullableAnotherEntityRequiredId { get; set; } + + public virtual ISet RelatedItems { get; set; } = new HashSet(); + + public virtual bool? NullableBool { get; set; } + } + + public enum AnotherEntityNullability + { + False = 0, + True = 1 + } +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/User.cs b/src/NHibernate.DomainModel/Northwind/Entities/User.cs index c3f220ffda5..c23e667be9b 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/User.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/User.cs @@ -26,6 +26,8 @@ public interface IUser Role Role { get; set; } EnumStoredAsString Enum1 { get; set; } EnumStoredAsInt32 Enum2 { get; set; } + IUser CreatedBy { get; set; } + IUser ModifiedBy { get; set; } } public class User : IUser, IEntity @@ -50,6 +52,10 @@ public class User : IUser, IEntity public virtual EnumStoredAsInt32 Enum2 { get; set; } + public virtual IUser CreatedBy { get; set; } + + public virtual IUser ModifiedBy { get; set; } + public virtual int NotMapped { get; set; } public virtual Role NotMappedRole { get; set; } diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml new file mode 100644 index 00000000000..0d9efe4136f --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml index f59dc2956c6..2764cb70898 100644 --- a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml +++ b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml @@ -7,11 +7,19 @@ - + + + + + + + + + diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index cecc6778c74..6a5c75091c7 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -15,10 +15,12 @@ using NHibernate.Linq; using NHibernate.DomainModel.Northwind.Entities; using NUnit.Framework; +using NUnit.Framework.Constraints; namespace NHibernate.Test.Linq { using System.Threading.Tasks; + using System.Threading; [TestFixture] public class NullComparisonTestsAsync : LinqTestCase { @@ -28,6 +30,350 @@ public class NullComparisonTestsAsync : LinqTestCase private static readonly AnotherEntity BothNull = new AnotherEntity(); private static readonly AnotherEntity BothDifferent = new AnotherEntity {Input = "input", Output = "output"}; + [Test] + public async Task NullInequalityWithNotNullAsync() + { + var q = session.Query().Where(o => o.Input != null); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothSame, BothDifferent)); + + q = session.Query().Where(o => null != o.Input); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothSame, BothDifferent)); + + q = session.Query().Where(o => o.InputNullability != AnotherEntityNullability.True); + await (ExpectAsync(q, Does.Not.Contain("end is null").IgnoreCase, InputSet, BothSame, BothDifferent)); + + q = session.Query().Where(o => AnotherEntityNullability.True != o.InputNullability); + await (ExpectAsync(q, Does.Not.Contain("end is null").IgnoreCase, InputSet, BothSame, BothDifferent)); + + q = session.Query().Where(o => "input" != o.Input); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Input != "input"); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Input != o.Output); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => o.Output != o.Input); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => o.Input != o.NullableOutput); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull)); + + q = session.Query().Where(o => o.NullableOutput != o.Input); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Output != o.Input); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull)); + + q = session.Query().Where(o => o.Input != o.NullableAnotherEntityRequired.Output); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Input != o.Output); + await (ExpectAsync(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull)); + + q = session.Query().Where(o => o.Output != o.NullableAnotherEntityRequired.Input); + await (ExpectAsync(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull)); + + q = session.Query().Where(o => 3 != o.NullableOutput.Length); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase, InputSet, BothDifferent, BothNull, OutputSet)); + + q = session.Query().Where(o => o.NullableOutput.Length != 3); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase, InputSet, BothDifferent, BothNull, OutputSet)); + + q = session.Query().Where(o => 3 != o.Input.Length); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent)); + + q = session.Query().Where(o => o.Input.Length != 3); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent)); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequiredId ?? 0) != (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0) != (o.NullableAnotherEntityRequiredId ?? 0)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.GetValueOrDefault() != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault()); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault() != o.NullableAnotherEntityRequiredId.GetValueOrDefault()); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value && o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != 0); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != 0 && o.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value != 0); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != 0 || o.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableOutput != null && o.NullableOutput != "test"); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet)); + + q = session.Query().Where(o => o.NullableOutput != "test" && o.NullableOutput != null); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet)); + + q = session.Query().Where(o => o.NullableOutput != null || o.NullableOutput != "test"); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableOutput != "test" || o.NullableOutput != null); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableOutput != "test" && (o.NullableAnotherEntityRequiredId > 0 && o.NullableOutput != null)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet)); + + q = session.Query().Where(o => o.NullableOutput != null && (o.NullableAnotherEntityRequiredId > 0 && o.NullableOutput != "test")); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequiredId.Value); + await (ExpectAsync(q, Does.Contain("or case").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Any(r => r.Output != o.Input)); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase.And.Contain("Output is null").IgnoreCase, BothDifferent, InputSet, BothNull)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != o.Input)); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase.And.Contain("Output is null").IgnoreCase, InputSet, OutputSet, BothDifferent, BothNull)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null && r.Output != o.Input)); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothDifferent, OutputSet)); + + q = session.Query().Where(o => (o.NullableOutput + o.Output) != o.Output); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => (o.Input + o.Output) != o.Output); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame, BothDifferent)); + + q = session.Query().Where(o => o.Address.Street != o.Output); + await (ExpectAsync(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull)); + + q = session.Query().Where(o => o.Address.City != o.Output); + await (ExpectAsync(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothNull)); + + q = session.Query().Where(o => o.Address.City != null && o.Address.City != o.Output); + await (ExpectAsync(q, Does.Not.Contain("Output is null").IgnoreCase)); + + q = session.Query().Where(o => o.Address.Street != null && o.Address.Street != o.NullableOutput); + await (ExpectAsync(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothDifferent)); + + await (ExpectAsync(session.Query().Where(o => o.CustomerId != null), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => null != o.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CompanyName != "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.CreatedBy.Name != "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.CreatedBy.CreatedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.Id != 5), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 != o.CreatedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase)); + } + + [Test] + public async Task NullInequalityWithNotNullSubSelectAsync() + { + if (!Dialect.SupportsScalarSubSelects) + { + Assert.Ignore("Dialect does not support scalar subselects"); + } + + var q = session.Query().Where(o => o.RelatedItems.Count != 1); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) != 0); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null) != o.NullableBool); + await (ExpectAllAsync(q, Does.Not.Contain("or case").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Where(r => r.Id == 0).Sum(r => r.Input.Length) != 5); + await (ExpectAllAsync(q, Does.Contain("or (").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null) != (o.NullableOutput.Length > 0)); + await (ExpectAsync(q, Does.Not.Contain("or case").IgnoreCase)); + } + + [Test] + public async Task NullEqualityWithNotNullAsync() + { + var q = session.Query().Where(o => o.Input == null); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull)); + + q = session.Query().Where(o => null == o.Input); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull)); + + q = session.Query().Where(o => o.InputNullability == AnotherEntityNullability.True); + await (ExpectAsync(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet)); + + q = session.Query().Where(o => AnotherEntityNullability.True == o.InputNullability); + await (ExpectAsync(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet)); + + q = session.Query().Where(o => "input" == o.Input); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent)); + + q = session.Query().Where(o => o.Input == "input"); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent)); + + q = session.Query().Where(o => o.Input == o.Output); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Output == o.Input); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Input == o.NullableOutput); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.NullableOutput == o.Input); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Output == o.Input); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Input == o.NullableAnotherEntityRequired.Output); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Input == o.Output); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Output == o.NullableAnotherEntityRequired.Input); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => 3 == o.Input.Length); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Input.Length == 3); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequiredId ?? 0) == (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0)); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0) == (o.NullableAnotherEntityRequiredId ?? 0)); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.GetValueOrDefault() == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault()); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault() == o.NullableAnotherEntityRequiredId.GetValueOrDefault()); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value && o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == 0); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == 0 && o.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value == 0); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == 0 || o.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableOutput == "test"); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableOutput != null || o.NullableOutput == "test"); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, OutputSet, BothDifferent, BothSame)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequiredId.Value); + await (ExpectAllAsync(q, Does.Contain("Id is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Any(r => r.Output == o.Input)); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output == o.Input)); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothSame, BothNull, InputSet, OutputSet)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output == o.NullableOutput)); + await (ExpectAllAsync(q, Does.Contain("Output is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null && o.NullableOutput != null && r.Output == o.NullableOutput)); + await (ExpectAsync(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame, BothDifferent, OutputSet)); + + q = session.Query().Where(o => (o.NullableOutput + o.Output) == o.Output); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => (o.Output + o.Output) == o.Output); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => !o.Input.Equals(o.Output)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => !o.Output.Equals(o.Input)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => !o.Input.Equals(o.NullableOutput)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.Input)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.NullableOutput)); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.NullableOutput)); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.Address.City == o.NullableOutput); + await (ExpectAllAsync(q, Does.Contain("Output is null").IgnoreCase)); + + q = session.Query().Where(o => o.Address.Street != null && o.Address.Street == o.NullableOutput); + await (ExpectAsync(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame)); + + await (ExpectAsync(session.Query().Where(o => o.CustomerId == null), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => null == o.CustomerId), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CompanyName == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.CreatedBy.Name == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.CreatedBy.CreatedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 == o.CreatedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase)); + } + + [Test] + public async Task NullEqualityWithNotNullSubSelectAsync() + { + if (!Dialect.SupportsScalarSubSelects) + { + Assert.Ignore("Dialect does not support scalar subselects"); + } + + var q = session.Query().Where(o => o.RelatedItems.Count == 1); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) == 0); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + } + [Test] public async Task NullEqualityAsync() { @@ -96,6 +442,36 @@ public async Task NullEqualityAsync() // Columns against columns q = from x in session.Query() where x.Input == x.Output select x; await (ExpectAsync(q, BothSame, BothNull)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.ContactName == null), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => null == o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.ContactName == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => null == o.Component.Property1), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.Property1 == null), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => "test" == o.Component.Property1), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.Property1 == "test"), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => null == o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 == null), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => "test" == o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.CreatedBy.Name == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.CreatedBy.ModifiedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.ModifiedBy.CreatedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 == o.ModifiedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase)); } [Test] @@ -154,6 +530,36 @@ public async Task NullInequalityAsync() // Columns against columns q = from x in session.Query() where x.Input != x.Output select x; await (ExpectAsync(q, BothDifferent, InputSet, OutputSet)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.ContactName != null), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => null != o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.ContactName != "test"), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => null != o.Component.Property1), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.Property1 != null), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => "test" != o.Component.Property1), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.Property1 != "test"), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => null != o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 != null), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => "test" != o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.CreatedBy.Name != "test"), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.CreatedBy.ModifiedBy.CreatedBy.Name), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.ModifiedBy.CreatedBy.Id != 5), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 != o.ModifiedBy.CreatedBy.Id), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase)); } [Test] @@ -316,5 +722,54 @@ private string Key(AnotherEntity e) { return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL"); } + + private Task ExpectAllAsync(IQueryable q, IResolveConstraint sqlConstraint) + { + return ExpectAsync(q, sqlConstraint, BothNull, BothSame, BothDifferent, InputSet, OutputSet); + } + + private async Task ExpectAsync(IQueryable q, IResolveConstraint sqlConstraint, params AnotherEntity[] entities) + { + IList results; + if (sqlConstraint == null) + { + results = await (GetResultsAsync(q)); + } + else + { + using (var sqlLog = new SqlLogSpy()) + { + results = await (GetResultsAsync(q)); + Assert.That(sqlLog.GetWholeLog(), sqlConstraint); + } + } + + IList check = entities.OrderBy(Key).ToList(); + + Assert.That(results.Count, Is.EqualTo(check.Count)); + for (var i = 0; i < check.Count; i++) + { + Assert.That(Key(results[i]), Is.EqualTo(Key(check[i]))); + } + } + + private async Task> GetResultsAsync(IQueryable q, CancellationToken cancellationToken = default(CancellationToken)) + { + return (await (q.ToListAsync(cancellationToken))).OrderBy(Key).ToList(); + } + + private static async Task ExpectAsync(IQueryable query, IResolveConstraint sqlConstraint, CancellationToken cancellationToken = default(CancellationToken)) + { + using (var sqlLog = new SqlLogSpy()) + { + var list = await (query.ToListAsync(cancellationToken)); + Assert.That(sqlLog.GetWholeLog(), sqlConstraint); + } + } + + private static string Key(AnotherEntityRequired e) + { + return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL"); + } } } diff --git a/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql b/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql index b9dbec1e05d..12235095e0d 100644 Binary files a/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql and b/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql differ diff --git a/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyCreateScript.sql b/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyCreateScript.sql index 43e78e33fa0..4d15d18a6e1 100644 --- a/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyCreateScript.sql +++ b/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyCreateScript.sql @@ -3774,6 +3774,8 @@ CREATE TABLE [dbo].[Users]( [Property1] [varchar](255) NULL, [Property2] [varchar](255) NULL, [OtherProperty1] [varchar](255) NULL, + [CreatedById] [int] NOT NULL, + [ModifiedById] [int] NULL PRIMARY KEY CLUSTERED ( [UserId] ASC @@ -3783,9 +3785,9 @@ GO SET ANSI_PADDING OFF GO SET IDENTITY_INSERT [dbo].[Users] ON -INSERT [dbo].[Users] ([UserId], [Name], [InvalidLoginAttempts], [RegisteredAt], [LastLoginDate], [Enum1], [Enum2], [RoleId], [Property1], [Property2], [OtherProperty1]) VALUES (1, N'ayende', 4, CAST(0x00009D9800000000 AS DateTime), NULL, N'Medium', 1, 1, N'test1', N'test2', N'othertest1') -INSERT [dbo].[Users] ([UserId], [Name], [InvalidLoginAttempts], [RegisteredAt], [LastLoginDate], [Enum1], [Enum2], [RoleId], [Property1], [Property2], [OtherProperty1]) VALUES (2, N'rahien', 5, CAST(0x00008D3E00000000 AS DateTime), NULL, N'Small', 0, 2, NULL, N'test2', NULL) -INSERT [dbo].[Users] ([UserId], [Name], [InvalidLoginAttempts], [RegisteredAt], [LastLoginDate], [Enum1], [Enum2], [Features], [RoleId], [Property1], [Property2], [OtherProperty1]) VALUES (3, N'nhibernate', 6, CAST(0x00008EAC00000000 AS DateTime), CAST(0x00009D970110B41C AS DateTime), N'Medium', 0, 8, NULL, NULL, NULL, NULL) +INSERT [dbo].[Users] ([UserId], [Name], [InvalidLoginAttempts], [RegisteredAt], [LastLoginDate], [Enum1], [Enum2], [RoleId], [Property1], [Property2], [OtherProperty1], [CreatedById], [ModifiedById]) VALUES (1, N'ayende', 4, CAST(0x00009D9800000000 AS DateTime), NULL, N'Medium', 1, 1, N'test1', N'test2', N'othertest1', 1, NULL) +INSERT [dbo].[Users] ([UserId], [Name], [InvalidLoginAttempts], [RegisteredAt], [LastLoginDate], [Enum1], [Enum2], [RoleId], [Property1], [Property2], [OtherProperty1], [CreatedById], [ModifiedById]) VALUES (2, N'rahien', 5, CAST(0x00008D3E00000000 AS DateTime), NULL, N'Small', 0, 2, NULL, N'test2', NULL, 1, NULL) +INSERT [dbo].[Users] ([UserId], [Name], [InvalidLoginAttempts], [RegisteredAt], [LastLoginDate], [Enum1], [Enum2], [Features], [RoleId], [Property1], [Property2], [OtherProperty1], [CreatedById], [ModifiedById]) VALUES (3, N'nhibernate', 6, CAST(0x00008EAC00000000 AS DateTime), CAST(0x00009D970110B41C AS DateTime), N'Medium', 0, 8, NULL, NULL, NULL, NULL, 1, NULL) SET IDENTITY_INSERT [dbo].[Users] OFF /****** Object: Table [dbo].[TimeSheetUsers] Script Date: 06/17/2010 13:08:54 ******/ SET ANSI_NULLS ON diff --git a/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyCreateScript.sql b/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyCreateScript.sql index fa32b85f55b..dd9d8716a71 100644 Binary files a/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyCreateScript.sql and b/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyCreateScript.sql differ diff --git a/src/NHibernate.Test/Linq/LinqTestCase.cs b/src/NHibernate.Test/Linq/LinqTestCase.cs index 26503786bb0..e047732d7ad 100755 --- a/src/NHibernate.Test/Linq/LinqTestCase.cs +++ b/src/NHibernate.Test/Linq/LinqTestCase.cs @@ -29,6 +29,7 @@ protected override string[] Mappings "Northwind.Mappings.Supplier.hbm.xml", "Northwind.Mappings.Territory.hbm.xml", "Northwind.Mappings.AnotherEntity.hbm.xml", + "Northwind.Mappings.AnotherEntityRequired.hbm.xml", "Northwind.Mappings.Role.hbm.xml", "Northwind.Mappings.User.hbm.xml", "Northwind.Mappings.TimeSheet.hbm.xml", @@ -69,4 +70,4 @@ public static void AssertByIds(IEnumerable entities, TId[ Assert.That(entities.Select(x => entityIdGetter(x)), Is.EquivalentTo(expectedIds)); } } -} \ No newline at end of file +} diff --git a/src/NHibernate.Test/Linq/NorthwindDbCreator.cs b/src/NHibernate.Test/Linq/NorthwindDbCreator.cs index dbb34f93bd2..fa484be3a92 100644 --- a/src/NHibernate.Test/Linq/NorthwindDbCreator.cs +++ b/src/NHibernate.Test/Linq/NorthwindDbCreator.cs @@ -70,6 +70,11 @@ public static void CreateMiscTestData(ISession session) } }; + foreach (var user in users) + { + user.CreatedBy = users[0]; + } + var timesheets = new[] { new Timesheet @@ -3707,4 +3712,4 @@ static void CreateOrderLines22(IStatelessSession session, IDictionary().Where(o => o.Input != null); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothSame, BothDifferent); + + q = session.Query().Where(o => null != o.Input); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothSame, BothDifferent); + + q = session.Query().Where(o => o.InputNullability != AnotherEntityNullability.True); + Expect(q, Does.Not.Contain("end is null").IgnoreCase, InputSet, BothSame, BothDifferent); + + q = session.Query().Where(o => AnotherEntityNullability.True != o.InputNullability); + Expect(q, Does.Not.Contain("end is null").IgnoreCase, InputSet, BothSame, BothDifferent); + + q = session.Query().Where(o => "input" != o.Input); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Input != "input"); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Input != o.Output); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => o.Output != o.Input); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => o.Input != o.NullableOutput); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull); + + q = session.Query().Where(o => o.NullableOutput != o.Input); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Output != o.Input); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull); + + q = session.Query().Where(o => o.Input != o.NullableAnotherEntityRequired.Output); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Input != o.Output); + Expect(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull); + + q = session.Query().Where(o => o.Output != o.NullableAnotherEntityRequired.Input); + Expect(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull); + + q = session.Query().Where(o => 3 != o.NullableOutput.Length); + Expect(q, Does.Contain("is null").IgnoreCase, InputSet, BothDifferent, BothNull, OutputSet); + + q = session.Query().Where(o => o.NullableOutput.Length != 3); + Expect(q, Does.Contain("is null").IgnoreCase, InputSet, BothDifferent, BothNull, OutputSet); + + q = session.Query().Where(o => 3 != o.Input.Length); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent); + + q = session.Query().Where(o => o.Input.Length != 3); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequiredId ?? 0) != (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0)); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0) != (o.NullableAnotherEntityRequiredId ?? 0)); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.GetValueOrDefault() != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault()); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault() != o.NullableAnotherEntityRequiredId.GetValueOrDefault()); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value && o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != 0); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != 0 && o.NullableAnotherEntityRequiredId.HasValue); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value != 0); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != 0 || o.NullableAnotherEntityRequiredId.HasValue); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableOutput != null && o.NullableOutput != "test"); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet); + + q = session.Query().Where(o => o.NullableOutput != "test" && o.NullableOutput != null); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet); + + q = session.Query().Where(o => o.NullableOutput != null || o.NullableOutput != "test"); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableOutput != "test" || o.NullableOutput != null); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableOutput != "test" && (o.NullableAnotherEntityRequiredId > 0 && o.NullableOutput != null)); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet); + + q = session.Query().Where(o => o.NullableOutput != null && (o.NullableAnotherEntityRequiredId > 0 && o.NullableOutput != "test")); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequiredId.Value); + Expect(q, Does.Contain("or case").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Any(r => r.Output != o.Input)); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase.And.Contain("Output is null").IgnoreCase, BothDifferent, InputSet, BothNull); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != o.Input)); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase.And.Contain("Output is null").IgnoreCase, InputSet, OutputSet, BothDifferent, BothNull); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null && r.Output != o.Input)); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothDifferent, OutputSet); + + q = session.Query().Where(o => (o.NullableOutput + o.Output) != o.Output); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => (o.Input + o.Output) != o.Output); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame, BothDifferent); + + q = session.Query().Where(o => o.Address.Street != o.Output); + Expect(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull); + + q = session.Query().Where(o => o.Address.City != o.Output); + Expect(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothNull); + + q = session.Query().Where(o => o.Address.City != null && o.Address.City != o.Output); + Expect(q, Does.Not.Contain("Output is null").IgnoreCase); + + q = session.Query().Where(o => o.Address.Street != null && o.Address.Street != o.NullableOutput); + Expect(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothDifferent); + + Expect(session.Query().Where(o => o.CustomerId != null), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => null != o.CustomerId), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.CustomerId), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.Order.Customer.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.Order.Customer.CompanyName != "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.CreatedBy.Name != "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.CreatedBy.CreatedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.Id != 5), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 != o.CreatedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase); + } + + [Test] + public void NullInequalityWithNotNullSubSelect() + { + if (!Dialect.SupportsScalarSubSelects) + { + Assert.Ignore("Dialect does not support scalar subselects"); + } + + var q = session.Query().Where(o => o.RelatedItems.Count != 1); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) != 0); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null) != o.NullableBool); + ExpectAll(q, Does.Not.Contain("or case").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Where(r => r.Id == 0).Sum(r => r.Input.Length) != 5); + ExpectAll(q, Does.Contain("or (").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null) != (o.NullableOutput.Length > 0)); + Expect(q, Does.Not.Contain("or case").IgnoreCase); + } + + [Test] + public void NullEqualityWithNotNull() + { + var q = session.Query().Where(o => o.Input == null); + Expect(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull); + + q = session.Query().Where(o => null == o.Input); + Expect(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull); + + q = session.Query().Where(o => o.InputNullability == AnotherEntityNullability.True); + Expect(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet); + + q = session.Query().Where(o => AnotherEntityNullability.True == o.InputNullability); + Expect(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet); + + q = session.Query().Where(o => "input" == o.Input); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent); + + q = session.Query().Where(o => o.Input == "input"); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent); + + q = session.Query().Where(o => o.Input == o.Output); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Output == o.Input); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Input == o.NullableOutput); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.NullableOutput == o.Input); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Output == o.Input); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Input == o.NullableAnotherEntityRequired.Output); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Input == o.Output); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Output == o.NullableAnotherEntityRequired.Input); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => 3 == o.Input.Length); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Input.Length == 3); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequiredId ?? 0) == (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0)); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0) == (o.NullableAnotherEntityRequiredId ?? 0)); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.GetValueOrDefault() == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault()); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault() == o.NullableAnotherEntityRequiredId.GetValueOrDefault()); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value && o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == 0); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == 0 && o.NullableAnotherEntityRequiredId.HasValue); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value == 0); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == 0 || o.NullableAnotherEntityRequiredId.HasValue); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableOutput == "test"); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableOutput != null || o.NullableOutput == "test"); + Expect(q, Does.Not.Contain("is null").IgnoreCase, OutputSet, BothDifferent, BothSame); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequiredId.Value); + ExpectAll(q, Does.Contain("Id is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Any(r => r.Output == o.Input)); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output == o.Input)); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothSame, BothNull, InputSet, OutputSet); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output == o.NullableOutput)); + ExpectAll(q, Does.Contain("Output is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null && o.NullableOutput != null && r.Output == o.NullableOutput)); + Expect(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame, BothDifferent, OutputSet); + + q = session.Query().Where(o => (o.NullableOutput + o.Output) == o.Output); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => (o.Output + o.Output) == o.Output); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => !o.Input.Equals(o.Output)); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => !o.Output.Equals(o.Input)); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => !o.Input.Equals(o.NullableOutput)); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.Input)); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.NullableOutput)); + Expect(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.NullableOutput)); + Expect(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.Address.City == o.NullableOutput); + ExpectAll(q, Does.Contain("Output is null").IgnoreCase); + + q = session.Query().Where(o => o.Address.Street != null && o.Address.Street == o.NullableOutput); + Expect(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame); + + Expect(session.Query().Where(o => o.CustomerId == null), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => null == o.CustomerId), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.CustomerId), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.Order.Customer.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.Order.Customer.CompanyName == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.CreatedBy.Name == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.CreatedBy.CreatedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 == o.CreatedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase); + } + + [Test] + public void NullEqualityWithNotNullSubSelect() + { + if (!Dialect.SupportsScalarSubSelects) + { + Assert.Ignore("Dialect does not support scalar subselects"); + } + + var q = session.Query().Where(o => o.RelatedItems.Count == 1); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) == 0); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + } + [Test] public void NullEquality() { @@ -85,6 +430,36 @@ public void NullEquality() // Columns against columns q = from x in session.Query() where x.Input == x.Output select x; Expect(q, BothSame, BothNull); + + Expect(session.Query().Where(o => o.Order.Customer.ContactName == null), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => null == o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.Order.Customer.ContactName == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => null == o.Component.Property1), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.Property1 == null), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => "test" == o.Component.Property1), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.Property1 == "test"), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => null == o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 == null), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => "test" == o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.CreatedBy.Name == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.CreatedBy.ModifiedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.ModifiedBy.CreatedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 == o.ModifiedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase); } [Test] @@ -143,6 +518,36 @@ public void NullInequality() // Columns against columns q = from x in session.Query() where x.Input != x.Output select x; Expect(q, BothDifferent, InputSet, OutputSet); + + Expect(session.Query().Where(o => o.Order.Customer.ContactName != null), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => null != o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.Order.Customer.ContactName != "test"), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => null != o.Component.Property1), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.Property1 != null), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => "test" != o.Component.Property1), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.Property1 != "test"), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => null != o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 != null), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => "test" != o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.CreatedBy.Name != "test"), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.CreatedBy.ModifiedBy.CreatedBy.Name), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.ModifiedBy.CreatedBy.Id != 5), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 != o.ModifiedBy.CreatedBy.Id), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase); } [Test] @@ -305,5 +710,54 @@ private string Key(AnotherEntity e) { return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL"); } + + private void ExpectAll(IQueryable q, IResolveConstraint sqlConstraint) + { + Expect(q, sqlConstraint, BothNull, BothSame, BothDifferent, InputSet, OutputSet); + } + + private void Expect(IQueryable q, IResolveConstraint sqlConstraint, params AnotherEntity[] entities) + { + IList results; + if (sqlConstraint == null) + { + results = GetResults(q); + } + else + { + using (var sqlLog = new SqlLogSpy()) + { + results = GetResults(q); + Assert.That(sqlLog.GetWholeLog(), sqlConstraint); + } + } + + IList check = entities.OrderBy(Key).ToList(); + + Assert.That(results.Count, Is.EqualTo(check.Count)); + for (var i = 0; i < check.Count; i++) + { + Assert.That(Key(results[i]), Is.EqualTo(Key(check[i]))); + } + } + + private IList GetResults(IQueryable q) + { + return q.ToList().OrderBy(Key).ToList(); + } + + private static void Expect(IQueryable query, IResolveConstraint sqlConstraint) + { + using (var sqlLog = new SqlLogSpy()) + { + var list = query.ToList(); + Assert.That(sqlLog.GetWholeLog(), sqlConstraint); + } + } + + private static string Key(AnotherEntityRequired e) + { + return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL"); + } } } diff --git a/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs b/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs index 6b705d3d92e..ab60d1a3d29 100644 --- a/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs @@ -16,6 +16,8 @@ protected NhAggregatedExpression(Expression expression, System.Type type) Type = type; } + public virtual bool AllowsNullableReturnType => true; + public sealed override System.Type Type { get; } public Expression Expression { get; } diff --git a/src/NHibernate/Linq/Expressions/NhCountExpression.cs b/src/NHibernate/Linq/Expressions/NhCountExpression.cs index 6dc698add5c..e41ed926410 100644 --- a/src/NHibernate/Linq/Expressions/NhCountExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhCountExpression.cs @@ -10,6 +10,8 @@ protected NhCountExpression(Expression expression, System.Type type) { } + public override bool AllowsNullableReturnType => false; + protected override Expression Accept(NhExpressionVisitor visitor) { return visitor.VisitNhCount(this); diff --git a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs index 3a9462d49ef..53955d136a6 100644 --- a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs @@ -7,10 +7,12 @@ namespace NHibernate.Linq.Functions { - public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod - { - public IEnumerable SupportedMethods { get; protected set; } + public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod, IHqlGeneratorForMethodExtended + { + public IEnumerable SupportedMethods { get; protected set; } - public abstract HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor); - } -} \ No newline at end of file + public abstract HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor); + + public virtual bool AllowsNullableReturnType(MethodInfo method) => true; + } +} diff --git a/src/NHibernate/Linq/Functions/CompareGenerator.cs b/src/NHibernate/Linq/Functions/CompareGenerator.cs index a819bcd8151..e703e6c62e7 100644 --- a/src/NHibernate/Linq/Functions/CompareGenerator.cs +++ b/src/NHibernate/Linq/Functions/CompareGenerator.cs @@ -51,6 +51,7 @@ internal static bool IsCompareMethod(MethodInfo methodInfo) methodInfo.DeclaringType.FullName == "System.Data.Services.Providers.DataServiceProviderMethods"; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; public CompareGenerator() { SupportedMethods = ActingMethods.ToArray(); diff --git a/src/NHibernate/Linq/Functions/DictionaryGenerator.cs b/src/NHibernate/Linq/Functions/DictionaryGenerator.cs index 6131c4d7ab8..eb583d0cd2f 100644 --- a/src/NHibernate/Linq/Functions/DictionaryGenerator.cs +++ b/src/NHibernate/Linq/Functions/DictionaryGenerator.cs @@ -25,6 +25,8 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, public class DictionaryContainsKeyGenerator : BaseHqlGeneratorForMethod { + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { return treeBuilder.In(visitor.Visit(arguments[0]).AsExpression(), treeBuilder.Indices(visitor.Visit(targetObject).AsExpression())); @@ -98,4 +100,4 @@ protected override string MethodName get { return "get_Item"; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Functions/EqualsGenerator.cs b/src/NHibernate/Linq/Functions/EqualsGenerator.cs index 27165978b34..55ec39e40ee 100644 --- a/src/NHibernate/Linq/Functions/EqualsGenerator.cs +++ b/src/NHibernate/Linq/Functions/EqualsGenerator.cs @@ -63,14 +63,14 @@ public EqualsGenerator() }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { Expression lhs = arguments.Count == 1 ? targetObject : arguments[0]; Expression rhs = arguments.Count == 1 ? arguments[0] : arguments[1]; - return treeBuilder.Equality( - visitor.Visit(lhs).ToArithmeticExpression(), - visitor.Visit(rhs).ToArithmeticExpression()); + return visitor.Visit(Expression.Equal(lhs, rhs)); } } } diff --git a/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs b/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs index 87c8efb01a9..33cb12c2c6c 100644 --- a/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs +++ b/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs @@ -9,7 +9,7 @@ namespace NHibernate.Linq.Functions { - internal class GetValueOrDefaultGenerator : IHqlGeneratorForMethod, IRuntimeMethodHqlGenerator + internal class GetValueOrDefaultGenerator : IHqlGeneratorForMethod, IRuntimeMethodHqlGenerator, IHqlGeneratorForMethodExtended { public bool SupportsMethod(MethodInfo method) { @@ -40,5 +40,7 @@ private static HqlExpression GetRhs(MethodInfo method, ReadOnlyCollection !method.ReturnType.IsValueType; } } diff --git a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs index 06b1545ef23..fde4ffd45f0 100644 --- a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs @@ -12,4 +12,24 @@ public interface IHqlGeneratorForMethod IEnumerable SupportedMethods { get; } HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor); } -} \ No newline at end of file + + // 6.0 TODO: Merge into IHqlGeneratorForMethod + internal interface IHqlGeneratorForMethodExtended + { + bool AllowsNullableReturnType(MethodInfo method); + } + + internal static class HqlGeneratorForMethodExtensions + { + // 6.0 TODO: Remove + public static bool AllowsNullableReturnType(this IHqlGeneratorForMethod generator, MethodInfo method) + { + if (generator is IHqlGeneratorForMethodExtended extendedGenerator) + { + return extendedGenerator.AllowsNullableReturnType(method); + } + + return true; + } + } +} diff --git a/src/NHibernate/Linq/Functions/QueryableGenerator.cs b/src/NHibernate/Linq/Functions/QueryableGenerator.cs index 4f3a9568c69..3da75fd0c99 100644 --- a/src/NHibernate/Linq/Functions/QueryableGenerator.cs +++ b/src/NHibernate/Linq/Functions/QueryableGenerator.cs @@ -22,6 +22,8 @@ public AnyHqlGenerator() }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { HqlAlias alias = null; @@ -59,6 +61,8 @@ public AllHqlGenerator() }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { // All has two arguments. Arg 1 is the source and arg 2 is the predicate @@ -148,6 +152,8 @@ public CollectionContainsGenerator() }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { // TODO - alias generator @@ -170,4 +176,4 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, where)); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Functions/StringGenerator.cs b/src/NHibernate/Linq/Functions/StringGenerator.cs index 7ec127b1f17..31edf4a42d1 100644 --- a/src/NHibernate/Linq/Functions/StringGenerator.cs +++ b/src/NHibernate/Linq/Functions/StringGenerator.cs @@ -10,7 +10,7 @@ namespace NHibernate.Linq.Functions { - public class LikeGenerator : IHqlGeneratorForMethod, IRuntimeMethodHqlGenerator + public class LikeGenerator : IHqlGeneratorForMethod, IRuntimeMethodHqlGenerator, IHqlGeneratorForMethodExtended { public IEnumerable SupportedMethods { @@ -57,6 +57,8 @@ public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) { return this; } + + public bool AllowsNullableReturnType(MethodInfo method) => false; } public class LengthGenerator : BaseHqlGeneratorForProperty @@ -79,6 +81,8 @@ public StartsWithGenerator() SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(x => x.StartsWith(null)) }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { return treeBuilder.Like( @@ -96,6 +100,8 @@ public EndsWithGenerator() SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(x => x.EndsWith(null)) }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { return treeBuilder.Like( @@ -113,6 +119,8 @@ public ContainsGenerator() SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(x => x.Contains(null)) }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { return treeBuilder.Like( diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 4e6df61afba..224af655a24 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -21,16 +21,18 @@ public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder(); private readonly VisitorParameters _parameters; private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; + private readonly NullableExpressionDetector _nullableExpressionDetector; public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters) { - return new HqlGeneratorExpressionVisitor(parameters).VisitExpression(expression); + return new HqlGeneratorExpressionVisitor(parameters).Visit(expression); } public HqlGeneratorExpressionVisitor(VisitorParameters parameters) { _functionRegistry = parameters.SessionFactory.Settings.LinqToHqlGeneratorsRegistry; _parameters = parameters; + _nullableExpressionDetector = new NullableExpressionDetector(_parameters.SessionFactory, _functionRegistry); } public ISessionFactory SessionFactory { get { return _parameters.SessionFactory; } } @@ -303,6 +305,8 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) return TranslateInequalityComparison(expression); } + _nullableExpressionDetector.SearchForNotNullMemberChecks(expression); + var lhs = VisitExpression(expression.Left).AsExpression(); var rhs = VisitExpression(expression.Right).AsExpression(); @@ -384,8 +388,8 @@ private HqlTreeNode TranslateInequalityComparison(BinaryExpression expression) return _hqlTreeBuilder.IsNotNull(lhs); } - var lhsNullable = IsNullable(lhs); - var rhsNullable = IsNullable(rhs); + var lhsNullable = _nullableExpressionDetector.IsNullable(expression.Left, expression); + var rhsNullable = _nullableExpressionDetector.IsNullable(expression.Right, expression); var inequality = _hqlTreeBuilder.Inequality(lhs, rhs); @@ -447,8 +451,8 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression) return _hqlTreeBuilder.IsNull((lhs)); } - var lhsNullable = IsNullable(lhs); - var rhsNullable = IsNullable(rhs); + var lhsNullable = _nullableExpressionDetector.IsNullable(expression.Left, expression); + var rhsNullable = _nullableExpressionDetector.IsNullable(expression.Right, expression); var equality = _hqlTreeBuilder.Equality(lhs, rhs); @@ -467,12 +471,6 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression) _hqlTreeBuilder.IsNull(rhs2))); } - static bool IsNullable(HqlExpression original) - { - var hqlDot = original as HqlDot; - return hqlDot != null && hqlDot.Children.Last() is HqlIdent; - } - protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) { switch (expression.NodeType) diff --git a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs new file mode 100644 index 00000000000..d9e2d6c06f5 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs @@ -0,0 +1,296 @@ +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using NHibernate.Engine; +using NHibernate.Linq.Clauses; +using NHibernate.Linq.Expressions; +using NHibernate.Linq.Functions; +using NHibernate.Util; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.ResultOperators; + +namespace NHibernate.Linq.Visitors +{ + internal class NullableExpressionDetector + { + private static readonly HashSet NotNullOperators = new HashSet + { + typeof(AllResultOperator), + typeof(AnyResultOperator), + typeof(ContainsResultOperator), + typeof(CountResultOperator), + typeof(LongCountResultOperator) + }; + + private readonly Dictionary> _equalityNotNullMembers = + new Dictionary>(); + private readonly ISessionFactoryImplementor _sessionFactory; + private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; + + public NullableExpressionDetector(ISessionFactoryImplementor sessionFactory, ILinqToHqlGeneratorsRegistry functionRegistry) + { + _sessionFactory = sessionFactory; + _functionRegistry = functionRegistry; + } + + public void SearchForNotNullMemberChecks(BinaryExpression expression) + { + // Check for a member not null check that has a not equals expression + // Example: o.Status != null && o.Status != "New" + // Example: (o.Status != null && o.OldStatus != null) && (o.Status != o.OldStatus) + // Example: (o.Status != null && o.OldStatus != null) && (o.Status == o.OldStatus) + // Example: o.Status != null && (o.OldStatus != null && o.Status == o.OldStatus) + if ( + _equalityNotNullMembers.ContainsKey(expression) || + !IsAndOrAndAlso(expression) || + ( + !IsAndOrAndAlso(expression.Right) && + !IsEqualOrNotEqual(expression.Right) + ) || + ( + !IsAndOrAndAlso(expression.Left) && + !IsEqualOrNotEqual(expression.Left) + )) + { + return; + } + + // Find all not null members and cache them for each binary expression that is found, + // in order to verify whether the member in a binary expression is nullable or not + FindAllNotNullMembers(expression, new List()); + } + + public bool IsNullable(Expression expression, BinaryExpression equalityExpression) + { + switch (expression.NodeType) + { + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + case ExpressionType.TypeAs: + // a cast will not return null if the operand is not null (as long as TypeAs is not translated to + // try_convert in SQL). + return IsNullable(((UnaryExpression) expression).Operand, equalityExpression); + case ExpressionType.Not: + case ExpressionType.And: + case ExpressionType.Or: + case ExpressionType.ExclusiveOr: + case ExpressionType.LeftShift: + case ExpressionType.RightShift: + case ExpressionType.AndAlso: + case ExpressionType.OrElse: + case ExpressionType.Equal: + case ExpressionType.NotEqual: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.GreaterThan: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + return false; + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Power: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + var binaryExpression = (BinaryExpression) expression; + return IsNullable(binaryExpression.Left, equalityExpression) || IsNullable(binaryExpression.Right, equalityExpression); + case ExpressionType.ArrayIndex: + return true; // for indexed lists we cannot determine whether the item will be null or not + case ExpressionType.Coalesce: + return IsNullable(((BinaryExpression) expression).Right, equalityExpression); + case ExpressionType.Conditional: + var conditionalExpression = (ConditionalExpression) expression; + return IsNullable(conditionalExpression.IfTrue, equalityExpression) || + IsNullable(conditionalExpression.IfFalse, equalityExpression); + case ExpressionType.Call: + var methodInfo = ((MethodCallExpression) expression).Method; + return !_functionRegistry.TryGetGenerator(methodInfo, out var method) || method.AllowsNullableReturnType(methodInfo); + case ExpressionType.MemberAccess: + return IsNullable((MemberExpression) expression, equalityExpression); + case ExpressionType.Extension: + return IsNullableExtension(expression, equalityExpression); + case ExpressionType.TypeIs: // an equal or in operator will be generated and those cannot return null + case ExpressionType.NewArrayInit: + return false; + case ExpressionType.Constant: + return VisitorUtil.IsNullConstant(expression); + case ExpressionType.Parameter: + return !expression.Type.IsValueType; + default: + return true; + } + } + + private bool IsNullable(MemberExpression memberExpression, BinaryExpression equalityExpression) + { + if (_functionRegistry.TryGetGenerator(memberExpression.Member, out _)) + { + // We have to skip the property as it will be converted to a function that can return null + // if the argument is null + return IsNullable(memberExpression.Expression, equalityExpression); + } + + var memberType = memberExpression.Member.GetPropertyOrFieldType(); + if (memberType?.IsValueType == true && !memberType.IsNullable()) + { + return IsNullable(memberExpression.Expression, equalityExpression); + } + + // Check if there was a not null check prior or after the equality expression + if (IsEqualOrNotEqual(equalityExpression) && + _equalityNotNullMembers.TryGetValue(equalityExpression, out var notNullMembers) && + notNullMembers.Any(o => AreEqual(o, memberExpression))) + { + return false; + } + + if (!ExpressionsHelper.TryGetMappedNullability(_sessionFactory, memberExpression, out var nullable) || nullable) + { + return true; // The expression contains one or many unsupported nodes or the member is nullable + } + + return IsNullable(memberExpression.Expression, equalityExpression); + } + + private bool IsNullableExtension(Expression extensionExpression, BinaryExpression equalityExpression) + { + switch (extensionExpression) + { + case QuerySourceReferenceExpression querySourceReferenceExpression: + switch (querySourceReferenceExpression.ReferencedQuerySource) + { + case MainFromClause _: + return false; // we reached to the root expression, there were no nullable expressions + case NhJoinClause joinClause: + return IsNullable(joinClause.FromExpression, equalityExpression); + default: + return true; // unknown query source + } + case SubQueryExpression subQueryExpression: + if (subQueryExpression.QueryModel.SelectClause.Selector is NhAggregatedExpression subQueryAggregatedExpression) + { + return subQueryAggregatedExpression.AllowsNullableReturnType; + } + else if (subQueryExpression.QueryModel.ResultOperators.Any(o => NotNullOperators.Contains(o.GetType()))) + { + return false; + } + + return true; + case NhAggregatedExpression aggregatedExpression: + return aggregatedExpression.AllowsNullableReturnType; + default: + return true; // a query can return null and we cannot calculate it as it is not yet executed + } + } + + private static bool TryGetMemberAccess(Expression expression, out MemberExpression memberExpression) + { + memberExpression = expression as MemberExpression; + if (memberExpression != null) + { + return true; + } + + // Nullable members can be wrapped in a convert expression + if (expression is UnaryExpression unaryExpression) + { + memberExpression = unaryExpression.Operand as MemberExpression; + } + + return memberExpression != null; + } + + private void FindAllNotNullMembers(Expression expression, List notNullMembers) + { + if (!(expression is BinaryExpression binaryExpression)) + { + return; + } + + // We may have multiple conditions + // Example: o.Status != null && o.OldStatus != null + // Example: o.Status != null && (o.OldStatus != null && o.Test > 0) + if (IsAndOrAndAlso(expression)) + { + FindAllNotNullMembers(binaryExpression, notNullMembers); + } + else if (IsEqualOrNotEqual(expression)) + { + FindNotNullMember(binaryExpression, notNullMembers); + } + } + + private void FindAllNotNullMembers(BinaryExpression binaryExpression, List notNullMembers) + { + _equalityNotNullMembers.Add(binaryExpression, notNullMembers); + FindAllNotNullMembers(binaryExpression.Left, notNullMembers); + FindAllNotNullMembers(binaryExpression.Right, notNullMembers); + } + + private void FindNotNullMember(BinaryExpression binaryExpression, List notNullMembers) + { + _equalityNotNullMembers[binaryExpression] = notNullMembers; + if (binaryExpression.NodeType != ExpressionType.NotEqual) + { + return; + } + + MemberExpression memberExpression; + if (VisitorUtil.IsNullConstant(binaryExpression.Right) && TryGetMemberAccess(binaryExpression.Left, out memberExpression)) + { + notNullMembers.Add(memberExpression); + } + else if (VisitorUtil.IsNullConstant(binaryExpression.Left) && TryGetMemberAccess(binaryExpression.Right, out memberExpression)) + { + notNullMembers.Add(memberExpression); + } + } + + private static bool AreEqual(MemberExpression memberExpression, MemberExpression otherMemberExpression) + { + if (memberExpression.Member != otherMemberExpression.Member || + memberExpression.Expression.NodeType != otherMemberExpression.Expression.NodeType) + { + return false; + } + + switch (memberExpression.Expression) + { + case QuerySourceReferenceExpression querySourceReferenceExpression: + if (otherMemberExpression.Expression is QuerySourceReferenceExpression otherQuerySourceReferenceExpression) + { + return querySourceReferenceExpression.ReferencedQuerySource == + otherQuerySourceReferenceExpression.ReferencedQuerySource; + } + + return false; + // Components have a nested member expression + case MemberExpression nestedMemberExpression: + if (otherMemberExpression.Expression is MemberExpression otherNestedMemberExpression) + { + return AreEqual(nestedMemberExpression, otherNestedMemberExpression); + } + + return false; + default: + return memberExpression.Expression == otherMemberExpression.Expression; + } + } + + private static bool IsAndOrAndAlso(Expression expression) + { + return expression.NodeType == ExpressionType.And || + expression.NodeType == ExpressionType.AndAlso; + } + + private static bool IsEqualOrNotEqual(Expression expression) + { + return expression.NodeType == ExpressionType.Equal || + expression.NodeType == ExpressionType.NotEqual; + } + } +}