diff --git a/src/NHibernate.Test/Async/Linq/EnumTests.cs b/src/NHibernate.Test/Async/Linq/EnumTests.cs index c869012844c..2fb552412f5 100644 --- a/src/NHibernate.Test/Async/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Async/Linq/EnumTests.cs @@ -174,13 +174,11 @@ public async Task ConditionalNavigationPropertyAsync() } } - [Test] - public async Task CanQueryComplexExpressionOnTestEnumAsync() + [TestCase(null)] + [TestCase(TestEnum.Unspecified)] + public async Task CanQueryComplexExpressionOnTestEnumAsync(TestEnum? type) { - //TODO: Fix issue on SQLite with type set to TestEnum.Unspecified - TestEnum? type = null; using (var session = OpenSession()) - using (var trans = session.BeginTransaction()) { var entities = session.Query(); @@ -197,7 +195,7 @@ public async Task CanQueryComplexExpressionOnTestEnumAsync() coalesce = user.NullableEnum1 ?? TestEnum.Medium }).ToListAsync()); - Assert.That(query.Count, Is.EqualTo(0)); + Assert.That(query.Count, Is.EqualTo(type == TestEnum.Unspecified ? 1 : 0)); } } diff --git a/src/NHibernate.Test/Async/Linq/WhereTests.cs b/src/NHibernate.Test/Async/Linq/WhereTests.cs index 56d183e49c9..3be794730b5 100644 --- a/src/NHibernate.Test/Async/Linq/WhereTests.cs +++ b/src/NHibernate.Test/Async/Linq/WhereTests.cs @@ -67,6 +67,32 @@ public async Task WhereWithConstantExpressionAsync() Assert.That(query.Count, Is.EqualTo(1)); } + [Test(Description = "GH-3256")] + public async Task CanUseStringEnumInConditionalAsync() + { + var query = db.Users + .Where( + user => (user.Enum1 == EnumStoredAsString.Small + ? EnumStoredAsString.Small + : EnumStoredAsString.Large) == user.Enum1) + .Select(x => x.Enum1); + + Assert.That(await (query.CountAsync()), Is.GreaterThan(0)); + } + + [Test(Description = "GH-3256")] + public async Task CanUseStringEnumInConditional2Async() + { + var query = db.Users + .Where( + user => (user.Enum1 == EnumStoredAsString.Small + ? user.Enum1 + : EnumStoredAsString.Large) == user.Enum1) + .Select(x => x.Enum1); + + Assert.That(await (query.CountAsync()), Is.GreaterThan(0)); + } + [Test] public async Task FirstElementWithWhereAsync() { diff --git a/src/NHibernate.Test/Linq/EnumTests.cs b/src/NHibernate.Test/Linq/EnumTests.cs index 21adf84bf3a..3a4a8711fef 100644 --- a/src/NHibernate.Test/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Linq/EnumTests.cs @@ -161,13 +161,11 @@ public void ConditionalNavigationProperty() } } - [Test] - public void CanQueryComplexExpressionOnTestEnum() + [TestCase(null)] + [TestCase(TestEnum.Unspecified)] + public void CanQueryComplexExpressionOnTestEnum(TestEnum? type) { - //TODO: Fix issue on SQLite with type set to TestEnum.Unspecified - TestEnum? type = null; using (var session = OpenSession()) - using (var trans = session.BeginTransaction()) { var entities = session.Query(); @@ -184,7 +182,7 @@ public void CanQueryComplexExpressionOnTestEnum() coalesce = user.NullableEnum1 ?? TestEnum.Medium }).ToList(); - Assert.That(query.Count, Is.EqualTo(0)); + Assert.That(query.Count, Is.EqualTo(type == TestEnum.Unspecified ? 1 : 0)); } } diff --git a/src/NHibernate.Test/Linq/WhereTests.cs b/src/NHibernate.Test/Linq/WhereTests.cs index 02dc58b34b7..ab6d6ec8763 100644 --- a/src/NHibernate.Test/Linq/WhereTests.cs +++ b/src/NHibernate.Test/Linq/WhereTests.cs @@ -55,6 +55,32 @@ public void WhereWithConstantExpression() Assert.That(query.Count, Is.EqualTo(1)); } + [Test(Description = "GH-3256")] + public void CanUseStringEnumInConditional() + { + var query = db.Users + .Where( + user => (user.Enum1 == EnumStoredAsString.Small + ? EnumStoredAsString.Small + : EnumStoredAsString.Large) == user.Enum1) + .Select(x => x.Enum1); + + Assert.That(query.Count(), Is.GreaterThan(0)); + } + + [Test(Description = "GH-3256")] + public void CanUseStringEnumInConditional2() + { + var query = db.Users + .Where( + user => (user.Enum1 == EnumStoredAsString.Small + ? user.Enum1 + : EnumStoredAsString.Large) == user.Enum1) + .Select(x => x.Enum1); + + Assert.That(query.Count(), Is.GreaterThan(0)); + } + [Test] public void FirstElementWithWhere() { diff --git a/src/NHibernate/Dialect/Function/CastFunction.cs b/src/NHibernate/Dialect/Function/CastFunction.cs index 7e1e38db9bc..eab47b50d62 100644 --- a/src/NHibernate/Dialect/Function/CastFunction.cs +++ b/src/NHibernate/Dialect/Function/CastFunction.cs @@ -61,31 +61,12 @@ public virtual SqlString Render(IList args, ISessionFactoryImplementor factory) throw new QueryException("cast() requires two arguments"); } string typeName = args[1].ToString(); - string sqlType; - IType hqlType = TypeFactory.HeuristicType(typeName); - if (hqlType != null) - { - SqlType[] sqlTypeCodes = hqlType.SqlTypes(factory); - if (sqlTypeCodes.Length != 1) - { - throw new QueryException("invalid NHibernate type for cast(), was:" + typeName); - } - - sqlType = factory.Dialect.GetCastTypeName(sqlTypeCodes[0]); - //{ - // //trim off the length/precision/scale - // int loc = sqlType.IndexOf('('); - // if (loc>-1) - // { - // sqlType = sqlType.Substring(0, loc); - // } - //} - } - else - { - throw new QueryException(string.Format("invalid Hibernate type for cast(): type {0} not found", typeName)); - } + IType hqlType = + TypeFactory.HeuristicType(typeName) + ?? throw new QueryException(string.Format("invalid Hibernate type for cast(): type {0} not found", typeName)); + + string sqlType = GetCastTypeName(factory, hqlType, typeName); // TODO 6.0: Remove pragma block with its content #pragma warning disable 618 @@ -117,6 +98,20 @@ protected virtual SqlString Render(object expression, string sqlType, ISessionFa return new SqlString("cast(", expression, " as ", sqlType, ")"); } + internal SqlString Render(IList args, IType expectedType, ISessionFactoryImplementor factory) + { + return Render(args[0], GetCastTypeName(factory, expectedType, expectedType.Name), factory); + } + + private static string GetCastTypeName(ISessionFactoryImplementor factory, IType hqlType, string typeName) + { + SqlType[] sqlTypeCodes = hqlType.SqlTypes(factory); + if (sqlTypeCodes.Length != 1) + throw new QueryException("invalid NHibernate type for cast(), was:" + typeName); + + return factory.Dialect.GetCastTypeName(sqlTypeCodes[0]); + } + #region IFunctionGrammar Members bool IFunctionGrammar.IsSeparator(string token) diff --git a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs index 1b96ede99ad..487bb246b23 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs @@ -285,7 +285,7 @@ private void EndFunctionTemplate(IASTNode m) // this function has a template -> restore output, apply the template and write the result out var functionArguments = (FunctionArguments)writer; // TODO: Downcast to avoid using an interface? Yuck. writer = outputStack.Pop(); - Out(template.Render(functionArguments.Args, sessionFactory)); + Out(methodNode.Render(functionArguments.Args)); } } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs index cae4b920ec8..8d1f7efce38 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs @@ -56,14 +56,11 @@ public virtual void Initialize() IType lhsType = ExtractDataType( lhs ); IType rhsType = ExtractDataType( rhs ); - if ( lhsType == null ) - { + if (lhsType == null || (IsGuessedType(lhs) && rhsType != null)) lhsType = rhsType; - } - if ( rhsType == null ) - { + + if (rhsType == null || (IsGuessedType(rhs) && lhsType != null)) rhsType = lhsType; - } if (lhs is IExpectedTypeAwareNode lshTypeAwareNode && lshTypeAwareNode.ExpectedType == null) { @@ -248,6 +245,8 @@ private protected static string[] ExtractMutationTexts(IASTNode operand, int cou throw new HibernateException( "dont know how to extract row value elements from node : " + operand ); } + private static bool IsGuessedType(IASTNode operand) => TransparentCastNode.IsTransparentCast(operand); + protected static IType ExtractDataType(IASTNode operand) { IType type = null; diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/HqlSqlWalkerTreeAdapter.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/HqlSqlWalkerTreeAdapter.cs index aa08d2fba0c..fcbe8dd4558 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/HqlSqlWalkerTreeAdapter.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/HqlSqlWalkerTreeAdapter.cs @@ -69,7 +69,9 @@ public override object Create(IToken payload) ret = new SqlFragment(payload); break; case HqlSqlWalker.METHOD_CALL: - ret = new MethodNode(payload); + ret = payload.Text == TransparentCastNode.Name + ? new TransparentCastNode(payload) + : new MethodNode(payload); break; case HqlSqlWalker.ELEMENTS: case HqlSqlWalker.INDICES: @@ -206,4 +208,4 @@ private void Initialise(object node) } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs index 8638a81ce94..3a36e9e635f 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs @@ -1,10 +1,12 @@ using System; +using System.Collections; using System.Collections.Generic; using Antlr.Runtime; using NHibernate.Dialect.Function; using NHibernate.Hql.Ast.ANTLR.Util; using NHibernate.Persister.Collection; +using NHibernate.SqlCommand; using NHibernate.Type; namespace NHibernate.Hql.Ast.ANTLR.Tree @@ -212,5 +214,10 @@ private void DialectFunction(IASTNode exprList) methodName = (String) getWalker().getTokenReplacements().get( methodName ); }*/ } + + public virtual SqlString Render(IList args) + { + return _function.Render(args, SessionFactoryHelper.Factory); + } } } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/TransparentCastNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/TransparentCastNode.cs new file mode 100644 index 00000000000..9f69e69bdcd --- /dev/null +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/TransparentCastNode.cs @@ -0,0 +1,45 @@ +using System.Collections; +using Antlr.Runtime; +using NHibernate.Dialect.Function; +using NHibernate.SqlCommand; +using NHibernate.Type; + +namespace NHibernate.Hql.Ast.ANTLR.Tree +{ + class TransparentCastNode : MethodNode, IExpectedTypeAwareNode + { + private IType _expectedType; + + public const string Name = "transparentcast"; + + public static bool IsTransparentCast(IASTNode node) + { + return node.Type == HqlSqlWalker.METHOD_CALL && node.Text == Name; + } + + public TransparentCastNode(IToken token) : base(token) + { + } + + public IType ExpectedType + { + get => _expectedType; + set + { + _expectedType = value; + var node = GetChild(0).NextSibling.GetChild(0); + // A transparent cast on parameters is a special use case - skip it. + if (node.Type != HqlSqlWalker.NAMED_PARAM && node is IExpectedTypeAwareNode typeNode && typeNode.ExpectedType == null) + typeNode.ExpectedType = value; + } + } + + public override SqlString Render(IList args) + { + return ExpectedType != null + // Provide the expected type in case the transparent cast is transformed to an actual cast. + ? ((CastFunction) SQLFunction).Render(args, ExpectedType, SessionFactoryHelper.Factory) + : base.Render(args); + } + } +} diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 7f6156f7dd6..98e323602ea 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -724,9 +724,9 @@ public HqlCast(IASTFactory factory, HqlExpression expression, System.Type type) public class HqlTransparentCast : HqlExpression { public HqlTransparentCast(IASTFactory factory, HqlExpression expression, System.Type type) - : base(HqlSqlWalker.METHOD_CALL, "method", factory) + : base(HqlSqlWalker.METHOD_CALL, TransparentCastNode.Name, factory) { - AddChild(new HqlIdent(factory, "transparentcast")); + AddChild(new HqlIdent(factory, TransparentCastNode.Name)); AddChild(new HqlExpressionList(factory, expression, new HqlIdent(factory, type))); } }