Skip to content

Treat transparentcast as guessed type in hql #3259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/NHibernate.Test/Async/Linq/EnumTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,11 @@ public async Task ConditionalNavigationPropertyAsync()
}
}

[Test]
public async Task CanQueryComplexExpressionOnTestEnumAsync()
[TestCase(null)]
[TestCase(TestEnum.Unspecified)]
public async Task CanQueryComplexExpressionOnTestEnumAsync(TestEnum? type)
{
//TODO: Fix issue on SQLite with type set to TestEnum.Unspecified
TestEnum? type = null;
using (var session = OpenSession())
using (var trans = session.BeginTransaction())
{
var entities = session.Query<EnumEntity>();

Expand All @@ -197,7 +195,7 @@ public async Task CanQueryComplexExpressionOnTestEnumAsync()
coalesce = user.NullableEnum1 ?? TestEnum.Medium
}).ToListAsync());

Assert.That(query.Count, Is.EqualTo(0));
Assert.That(query.Count, Is.EqualTo(type == TestEnum.Unspecified ? 1 : 0));
}
}

Expand Down
26 changes: 26 additions & 0 deletions src/NHibernate.Test/Async/Linq/WhereTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,32 @@ public async Task WhereWithConstantExpressionAsync()
Assert.That(query.Count, Is.EqualTo(1));
}

[Test(Description = "GH-3256")]
public async Task CanUseStringEnumInConditionalAsync()
{
var query = db.Users
.Where(
user => (user.Enum1 == EnumStoredAsString.Small
? EnumStoredAsString.Small
: EnumStoredAsString.Large) == user.Enum1)
.Select(x => x.Enum1);

Assert.That(await (query.CountAsync()), Is.GreaterThan(0));
}

[Test(Description = "GH-3256")]
public async Task CanUseStringEnumInConditional2Async()
{
var query = db.Users
.Where(
user => (user.Enum1 == EnumStoredAsString.Small
? user.Enum1
: EnumStoredAsString.Large) == user.Enum1)
.Select(x => x.Enum1);

Assert.That(await (query.CountAsync()), Is.GreaterThan(0));
}

[Test]
public async Task FirstElementWithWhereAsync()
{
Expand Down
10 changes: 4 additions & 6 deletions src/NHibernate.Test/Linq/EnumTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,11 @@ public void ConditionalNavigationProperty()
}
}

[Test]
public void CanQueryComplexExpressionOnTestEnum()
[TestCase(null)]
[TestCase(TestEnum.Unspecified)]
public void CanQueryComplexExpressionOnTestEnum(TestEnum? type)
{
//TODO: Fix issue on SQLite with type set to TestEnum.Unspecified
TestEnum? type = null;
using (var session = OpenSession())
using (var trans = session.BeginTransaction())
{
var entities = session.Query<EnumEntity>();

Expand All @@ -184,7 +182,7 @@ public void CanQueryComplexExpressionOnTestEnum()
coalesce = user.NullableEnum1 ?? TestEnum.Medium
}).ToList();

Assert.That(query.Count, Is.EqualTo(0));
Assert.That(query.Count, Is.EqualTo(type == TestEnum.Unspecified ? 1 : 0));
}
}

Expand Down
26 changes: 26 additions & 0 deletions src/NHibernate.Test/Linq/WhereTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,32 @@ public void WhereWithConstantExpression()
Assert.That(query.Count, Is.EqualTo(1));
}

[Test(Description = "GH-3256")]
public void CanUseStringEnumInConditional()
{
var query = db.Users
.Where(
user => (user.Enum1 == EnumStoredAsString.Small
? EnumStoredAsString.Small
: EnumStoredAsString.Large) == user.Enum1)
.Select(x => x.Enum1);

Assert.That(query.Count(), Is.GreaterThan(0));
}

[Test(Description = "GH-3256")]
public void CanUseStringEnumInConditional2()
{
var query = db.Users
.Where(
user => (user.Enum1 == EnumStoredAsString.Small
? user.Enum1
: EnumStoredAsString.Large) == user.Enum1)
.Select(x => x.Enum1);

Assert.That(query.Count(), Is.GreaterThan(0));
}

[Test]
public void FirstElementWithWhere()
{
Expand Down
43 changes: 19 additions & 24 deletions src/NHibernate/Dialect/Function/CastFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,31 +61,12 @@ public virtual SqlString Render(IList args, ISessionFactoryImplementor factory)
throw new QueryException("cast() requires two arguments");
}
string typeName = args[1].ToString();
string sqlType;
IType hqlType = TypeFactory.HeuristicType(typeName);

if (hqlType != null)
{
SqlType[] sqlTypeCodes = hqlType.SqlTypes(factory);
if (sqlTypeCodes.Length != 1)
{
throw new QueryException("invalid NHibernate type for cast(), was:" + typeName);
}

sqlType = factory.Dialect.GetCastTypeName(sqlTypeCodes[0]);
//{
// //trim off the length/precision/scale
// int loc = sqlType.IndexOf('(');
// if (loc>-1)
// {
// sqlType = sqlType.Substring(0, loc);
// }
//}
}
else
{
throw new QueryException(string.Format("invalid Hibernate type for cast(): type {0} not found", typeName));
}
IType hqlType =
TypeFactory.HeuristicType(typeName)
?? throw new QueryException(string.Format("invalid Hibernate type for cast(): type {0} not found", typeName));

string sqlType = GetCastTypeName(factory, hqlType, typeName);

// TODO 6.0: Remove pragma block with its content
#pragma warning disable 618
Expand Down Expand Up @@ -117,6 +98,20 @@ protected virtual SqlString Render(object expression, string sqlType, ISessionFa
return new SqlString("cast(", expression, " as ", sqlType, ")");
}

internal SqlString Render(IList args, IType expectedType, ISessionFactoryImplementor factory)
{
return Render(args[0], GetCastTypeName(factory, expectedType, expectedType.Name), factory);
}

private static string GetCastTypeName(ISessionFactoryImplementor factory, IType hqlType, string typeName)
{
SqlType[] sqlTypeCodes = hqlType.SqlTypes(factory);
if (sqlTypeCodes.Length != 1)
throw new QueryException("invalid NHibernate type for cast(), was:" + typeName);

return factory.Dialect.GetCastTypeName(sqlTypeCodes[0]);
}

#region IFunctionGrammar Members

bool IFunctionGrammar.IsSeparator(string token)
Expand Down
2 changes: 1 addition & 1 deletion src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ private void EndFunctionTemplate(IASTNode m)
// this function has a template -> restore output, apply the template and write the result out
var functionArguments = (FunctionArguments)writer; // TODO: Downcast to avoid using an interface? Yuck.
writer = outputStack.Pop();
Out(template.Render(functionArguments.Args, sessionFactory));
Out(methodNode.Render(functionArguments.Args));
}
}

Expand Down
11 changes: 5 additions & 6 deletions src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,11 @@ public virtual void Initialize()
IType lhsType = ExtractDataType( lhs );
IType rhsType = ExtractDataType( rhs );

if ( lhsType == null )
{
if (lhsType == null || (IsGuessedType(lhs) && rhsType != null))
lhsType = rhsType;
}
if ( rhsType == null )
{

if (rhsType == null || (IsGuessedType(rhs) && lhsType != null))
rhsType = lhsType;
}

if (lhs is IExpectedTypeAwareNode lshTypeAwareNode && lshTypeAwareNode.ExpectedType == null)
{
Expand Down Expand Up @@ -248,6 +245,8 @@ private protected static string[] ExtractMutationTexts(IASTNode operand, int cou
throw new HibernateException( "dont know how to extract row value elements from node : " + operand );
}

private static bool IsGuessedType(IASTNode operand) => TransparentCastNode.IsTransparentCast(operand);

protected static IType ExtractDataType(IASTNode operand)
{
IType type = null;
Expand Down
6 changes: 4 additions & 2 deletions src/NHibernate/Hql/Ast/ANTLR/Tree/HqlSqlWalkerTreeAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ public override object Create(IToken payload)
ret = new SqlFragment(payload);
break;
case HqlSqlWalker.METHOD_CALL:
ret = new MethodNode(payload);
ret = payload.Text == TransparentCastNode.Name
? new TransparentCastNode(payload)
: new MethodNode(payload);
break;
case HqlSqlWalker.ELEMENTS:
case HqlSqlWalker.INDICES:
Expand Down Expand Up @@ -206,4 +208,4 @@ private void Initialise(object node)
}
}
}
}
}
7 changes: 7 additions & 0 deletions src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
using System;
using System.Collections;
using System.Collections.Generic;
using Antlr.Runtime;

using NHibernate.Dialect.Function;
using NHibernate.Hql.Ast.ANTLR.Util;
using NHibernate.Persister.Collection;
using NHibernate.SqlCommand;
using NHibernate.Type;

namespace NHibernate.Hql.Ast.ANTLR.Tree
Expand Down Expand Up @@ -212,5 +214,10 @@ private void DialectFunction(IASTNode exprList)
methodName = (String) getWalker().getTokenReplacements().get( methodName );
}*/
}

public virtual SqlString Render(IList args)
{
return _function.Render(args, SessionFactoryHelper.Factory);
}
}
}
45 changes: 45 additions & 0 deletions src/NHibernate/Hql/Ast/ANTLR/Tree/TransparentCastNode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using System.Collections;
using Antlr.Runtime;
using NHibernate.Dialect.Function;
using NHibernate.SqlCommand;
using NHibernate.Type;

namespace NHibernate.Hql.Ast.ANTLR.Tree
{
class TransparentCastNode : MethodNode, IExpectedTypeAwareNode
{
private IType _expectedType;

public const string Name = "transparentcast";

public static bool IsTransparentCast(IASTNode node)
{
return node.Type == HqlSqlWalker.METHOD_CALL && node.Text == Name;
}

public TransparentCastNode(IToken token) : base(token)
{
}

public IType ExpectedType
{
get => _expectedType;
set
{
_expectedType = value;
var node = GetChild(0).NextSibling.GetChild(0);
// A transparent cast on parameters is a special use case - skip it.
if (node.Type != HqlSqlWalker.NAMED_PARAM && node is IExpectedTypeAwareNode typeNode && typeNode.ExpectedType == null)
typeNode.ExpectedType = value;
}
}

public override SqlString Render(IList args)
{
return ExpectedType != null
// Provide the expected type in case the transparent cast is transformed to an actual cast.
? ((CastFunction) SQLFunction).Render(args, ExpectedType, SessionFactoryHelper.Factory)
: base.Render(args);
}
}
}
4 changes: 2 additions & 2 deletions src/NHibernate/Hql/Ast/HqlTreeNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -724,9 +724,9 @@ public HqlCast(IASTFactory factory, HqlExpression expression, System.Type type)
public class HqlTransparentCast : HqlExpression
{
public HqlTransparentCast(IASTFactory factory, HqlExpression expression, System.Type type)
: base(HqlSqlWalker.METHOD_CALL, "method", factory)
: base(HqlSqlWalker.METHOD_CALL, TransparentCastNode.Name, factory)
{
AddChild(new HqlIdent(factory, "transparentcast"));
AddChild(new HqlIdent(factory, TransparentCastNode.Name));
AddChild(new HqlExpressionList(factory, expression, new HqlIdent(factory, type)));
}
}
Expand Down