diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH0829/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH0829/Fixture.cs new file mode 100644 index 00000000000..36015f3d1c6 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH0829/Fixture.cs @@ -0,0 +1,73 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.NHSpecificTest.GH0829 +{ + using System.Threading.Tasks; + [TestFixture] + public class FixtureAsync : BugTestCase + { + protected override void OnSetUp() + { + using var session = OpenSession(); + using var transaction = session.BeginTransaction(); + + var e1 = new Parent { Type = TestEnum.A | TestEnum.C }; + session.Save(e1); + + var e2 = new Child { Type = TestEnum.D, Parent = e1 }; + session.Save(e2); + + var e3 = new Child { Type = TestEnum.C, Parent = e1 }; + session.Save(e3); + + transaction.Commit(); + } + + [Test] + public async Task SelectClassAsync() + { + using var session = OpenSession(); + + var resultFound = await (session.Query().Where(x => x.Type.HasFlag(TestEnum.A)).FirstOrDefaultAsync()); + + var resultNotFound = await (session.Query().Where(x => x.Type.HasFlag(TestEnum.D)).FirstOrDefaultAsync()); + + Assert.That(resultFound, Is.Not.Null); + Assert.That(resultNotFound, Is.Null); + } + + [Test] + public async Task SelectChildClassContainedInParentAsync() + { + using var session = OpenSession(); + + var result = await (session.Query().Where(x => x.Parent.Type.HasFlag(x.Type)).FirstOrDefaultAsync()); + + Assert.That(result, Is.Not.Null); + } + + protected override void OnTearDown() + { + using var session = OpenSession(); + using var transaction = session.BeginTransaction(); + foreach (var entity in new[] { nameof(Child), nameof(Parent) }) + { + session.CreateQuery($"delete from {entity}").ExecuteUpdate(); + } + + transaction.Commit(); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH0829/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH0829/Fixture.cs new file mode 100644 index 00000000000..5726de6d4b2 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH0829/Fixture.cs @@ -0,0 +1,61 @@ +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH0829 +{ + [TestFixture] + public class Fixture : BugTestCase + { + protected override void OnSetUp() + { + using var session = OpenSession(); + using var transaction = session.BeginTransaction(); + + var e1 = new Parent { Type = TestEnum.A | TestEnum.C }; + session.Save(e1); + + var e2 = new Child { Type = TestEnum.D, Parent = e1 }; + session.Save(e2); + + var e3 = new Child { Type = TestEnum.C, Parent = e1 }; + session.Save(e3); + + transaction.Commit(); + } + + [Test] + public void SelectClass() + { + using var session = OpenSession(); + + var resultFound = session.Query().Where(x => x.Type.HasFlag(TestEnum.A)).FirstOrDefault(); + + var resultNotFound = session.Query().Where(x => x.Type.HasFlag(TestEnum.D)).FirstOrDefault(); + + Assert.That(resultFound, Is.Not.Null); + Assert.That(resultNotFound, Is.Null); + } + + [Test] + public void SelectChildClassContainedInParent() + { + using var session = OpenSession(); + + var result = session.Query().Where(x => x.Parent.Type.HasFlag(x.Type)).FirstOrDefault(); + + Assert.That(result, Is.Not.Null); + } + + protected override void OnTearDown() + { + using var session = OpenSession(); + using var transaction = session.BeginTransaction(); + foreach (var entity in new[] { nameof(Child), nameof(Parent) }) + { + session.CreateQuery($"delete from {entity}").ExecuteUpdate(); + } + + transaction.Commit(); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH0829/Mappings.hbm.xml b/src/NHibernate.Test/NHSpecificTest/GH0829/Mappings.hbm.xml new file mode 100644 index 00000000000..7527c9d74c7 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH0829/Mappings.hbm.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.Test/NHSpecificTest/GH0829/Model.cs b/src/NHibernate.Test/NHSpecificTest/GH0829/Model.cs new file mode 100644 index 00000000000..92af512fddb --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH0829/Model.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; + +namespace NHibernate.Test.NHSpecificTest.GH0829 +{ + public class Parent + { + public virtual Guid Id { get; set; } + + public virtual TestEnum Type { get; set; } + + public virtual IList Children { get; set; } = new List(); + } + + public class Child + { + public virtual Guid Id { get; set; } + + public virtual TestEnum Type { get; set; } + + public virtual Parent Parent { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH0829/TestEnum.cs b/src/NHibernate.Test/NHSpecificTest/GH0829/TestEnum.cs new file mode 100644 index 00000000000..5d3a919400e --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH0829/TestEnum.cs @@ -0,0 +1,13 @@ +using System; + +namespace NHibernate.Test.NHSpecificTest.GH0829 +{ + [Flags] + public enum TestEnum + { + A = 1 << 0, + B = 1 << 1, + C = 1 << 2, + D = 1 << 3 + } +} diff --git a/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs b/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs index 8feac321a1c..f8e152229f2 100644 --- a/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs +++ b/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs @@ -1,3 +1,4 @@ +using System; using System.Linq.Expressions; using NHibernate.Util; using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; @@ -20,6 +21,7 @@ public class RemoveRedundantCast : IExpressionTransformer public Expression Transform(UnaryExpression expression) { if (expression.Type != typeof(object) && + expression.Type != typeof(Enum) && expression.Type.IsAssignableFrom(expression.Operand.Type) && expression.Method == null && !expression.IsLiftedToNull) diff --git a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs index 29595877d9f..6800271f9f0 100644 --- a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs +++ b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs @@ -69,6 +69,7 @@ public DefaultLinqToHqlGeneratorsRegistry() this.Merge(new DecimalNegateGenerator()); this.Merge(new RoundGenerator()); this.Merge(new TruncateGenerator()); + this.Merge(new HasFlagGenerator()); var indexerGenerator = new ListIndexerGenerator(); RegisterGenerator(indexerGenerator); diff --git a/src/NHibernate/Linq/Functions/HasFlagGenerator.cs b/src/NHibernate/Linq/Functions/HasFlagGenerator.cs new file mode 100644 index 00000000000..99b0c0ced7a --- /dev/null +++ b/src/NHibernate/Linq/Functions/HasFlagGenerator.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Visitors; +using NHibernate.Util; + +namespace NHibernate.Linq.Functions +{ + internal class HasFlagGenerator : BaseHqlGeneratorForMethod + { + private const string _bitAndFunctionName = "band"; + + public HasFlagGenerator() + { + SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(x => x.HasFlag(default)) }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.Equality( + treeBuilder.MethodCall( + _bitAndFunctionName, + visitor.Visit(targetObject).AsExpression(), + visitor.Visit(arguments[0]).AsExpression()), + visitor.Visit(arguments[0]).AsExpression()); + } + } +}