Skip to content

Commit 9c4befc

Browse files
authored
Treat transparentcast as guessed type in hql (#3259)
Fixes #3256
1 parent d419ab3 commit 9c4befc

File tree

11 files changed

+143
-47
lines changed

11 files changed

+143
-47
lines changed

src/NHibernate.Test/Async/Linq/EnumTests.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,11 @@ public async Task ConditionalNavigationPropertyAsync()
174174
}
175175
}
176176

177-
[Test]
178-
public async Task CanQueryComplexExpressionOnTestEnumAsync()
177+
[TestCase(null)]
178+
[TestCase(TestEnum.Unspecified)]
179+
public async Task CanQueryComplexExpressionOnTestEnumAsync(TestEnum? type)
179180
{
180-
//TODO: Fix issue on SQLite with type set to TestEnum.Unspecified
181-
TestEnum? type = null;
182181
using (var session = OpenSession())
183-
using (var trans = session.BeginTransaction())
184182
{
185183
var entities = session.Query<EnumEntity>();
186184

@@ -197,7 +195,7 @@ public async Task CanQueryComplexExpressionOnTestEnumAsync()
197195
coalesce = user.NullableEnum1 ?? TestEnum.Medium
198196
}).ToListAsync());
199197

200-
Assert.That(query.Count, Is.EqualTo(0));
198+
Assert.That(query.Count, Is.EqualTo(type == TestEnum.Unspecified ? 1 : 0));
201199
}
202200
}
203201

src/NHibernate.Test/Async/Linq/WhereTests.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,32 @@ public async Task WhereWithConstantExpressionAsync()
6767
Assert.That(query.Count, Is.EqualTo(1));
6868
}
6969

70+
[Test(Description = "GH-3256")]
71+
public async Task CanUseStringEnumInConditionalAsync()
72+
{
73+
var query = db.Users
74+
.Where(
75+
user => (user.Enum1 == EnumStoredAsString.Small
76+
? EnumStoredAsString.Small
77+
: EnumStoredAsString.Large) == user.Enum1)
78+
.Select(x => x.Enum1);
79+
80+
Assert.That(await (query.CountAsync()), Is.GreaterThan(0));
81+
}
82+
83+
[Test(Description = "GH-3256")]
84+
public async Task CanUseStringEnumInConditional2Async()
85+
{
86+
var query = db.Users
87+
.Where(
88+
user => (user.Enum1 == EnumStoredAsString.Small
89+
? user.Enum1
90+
: EnumStoredAsString.Large) == user.Enum1)
91+
.Select(x => x.Enum1);
92+
93+
Assert.That(await (query.CountAsync()), Is.GreaterThan(0));
94+
}
95+
7096
[Test]
7197
public async Task FirstElementWithWhereAsync()
7298
{

src/NHibernate.Test/Linq/EnumTests.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,11 @@ public void ConditionalNavigationProperty()
161161
}
162162
}
163163

164-
[Test]
165-
public void CanQueryComplexExpressionOnTestEnum()
164+
[TestCase(null)]
165+
[TestCase(TestEnum.Unspecified)]
166+
public void CanQueryComplexExpressionOnTestEnum(TestEnum? type)
166167
{
167-
//TODO: Fix issue on SQLite with type set to TestEnum.Unspecified
168-
TestEnum? type = null;
169168
using (var session = OpenSession())
170-
using (var trans = session.BeginTransaction())
171169
{
172170
var entities = session.Query<EnumEntity>();
173171

@@ -184,7 +182,7 @@ public void CanQueryComplexExpressionOnTestEnum()
184182
coalesce = user.NullableEnum1 ?? TestEnum.Medium
185183
}).ToList();
186184

187-
Assert.That(query.Count, Is.EqualTo(0));
185+
Assert.That(query.Count, Is.EqualTo(type == TestEnum.Unspecified ? 1 : 0));
188186
}
189187
}
190188

src/NHibernate.Test/Linq/WhereTests.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,32 @@ public void WhereWithConstantExpression()
5555
Assert.That(query.Count, Is.EqualTo(1));
5656
}
5757

58+
[Test(Description = "GH-3256")]
59+
public void CanUseStringEnumInConditional()
60+
{
61+
var query = db.Users
62+
.Where(
63+
user => (user.Enum1 == EnumStoredAsString.Small
64+
? EnumStoredAsString.Small
65+
: EnumStoredAsString.Large) == user.Enum1)
66+
.Select(x => x.Enum1);
67+
68+
Assert.That(query.Count(), Is.GreaterThan(0));
69+
}
70+
71+
[Test(Description = "GH-3256")]
72+
public void CanUseStringEnumInConditional2()
73+
{
74+
var query = db.Users
75+
.Where(
76+
user => (user.Enum1 == EnumStoredAsString.Small
77+
? user.Enum1
78+
: EnumStoredAsString.Large) == user.Enum1)
79+
.Select(x => x.Enum1);
80+
81+
Assert.That(query.Count(), Is.GreaterThan(0));
82+
}
83+
5884
[Test]
5985
public void FirstElementWithWhere()
6086
{

src/NHibernate/Dialect/Function/CastFunction.cs

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -61,31 +61,12 @@ public virtual SqlString Render(IList args, ISessionFactoryImplementor factory)
6161
throw new QueryException("cast() requires two arguments");
6262
}
6363
string typeName = args[1].ToString();
64-
string sqlType;
65-
IType hqlType = TypeFactory.HeuristicType(typeName);
6664

67-
if (hqlType != null)
68-
{
69-
SqlType[] sqlTypeCodes = hqlType.SqlTypes(factory);
70-
if (sqlTypeCodes.Length != 1)
71-
{
72-
throw new QueryException("invalid NHibernate type for cast(), was:" + typeName);
73-
}
74-
75-
sqlType = factory.Dialect.GetCastTypeName(sqlTypeCodes[0]);
76-
//{
77-
// //trim off the length/precision/scale
78-
// int loc = sqlType.IndexOf('(');
79-
// if (loc>-1)
80-
// {
81-
// sqlType = sqlType.Substring(0, loc);
82-
// }
83-
//}
84-
}
85-
else
86-
{
87-
throw new QueryException(string.Format("invalid Hibernate type for cast(): type {0} not found", typeName));
88-
}
65+
IType hqlType =
66+
TypeFactory.HeuristicType(typeName)
67+
?? throw new QueryException(string.Format("invalid Hibernate type for cast(): type {0} not found", typeName));
68+
69+
string sqlType = GetCastTypeName(factory, hqlType, typeName);
8970

9071
// TODO 6.0: Remove pragma block with its content
9172
#pragma warning disable 618
@@ -117,6 +98,20 @@ protected virtual SqlString Render(object expression, string sqlType, ISessionFa
11798
return new SqlString("cast(", expression, " as ", sqlType, ")");
11899
}
119100

101+
internal SqlString Render(IList args, IType expectedType, ISessionFactoryImplementor factory)
102+
{
103+
return Render(args[0], GetCastTypeName(factory, expectedType, expectedType.Name), factory);
104+
}
105+
106+
private static string GetCastTypeName(ISessionFactoryImplementor factory, IType hqlType, string typeName)
107+
{
108+
SqlType[] sqlTypeCodes = hqlType.SqlTypes(factory);
109+
if (sqlTypeCodes.Length != 1)
110+
throw new QueryException("invalid NHibernate type for cast(), was:" + typeName);
111+
112+
return factory.Dialect.GetCastTypeName(sqlTypeCodes[0]);
113+
}
114+
120115
#region IFunctionGrammar Members
121116

122117
bool IFunctionGrammar.IsSeparator(string token)

src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ private void EndFunctionTemplate(IASTNode m)
285285
// this function has a template -> restore output, apply the template and write the result out
286286
var functionArguments = (FunctionArguments)writer; // TODO: Downcast to avoid using an interface? Yuck.
287287
writer = outputStack.Pop();
288-
Out(template.Render(functionArguments.Args, sessionFactory));
288+
Out(methodNode.Render(functionArguments.Args));
289289
}
290290
}
291291

src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,11 @@ public virtual void Initialize()
5656
IType lhsType = ExtractDataType( lhs );
5757
IType rhsType = ExtractDataType( rhs );
5858

59-
if ( lhsType == null )
60-
{
59+
if (lhsType == null || (IsGuessedType(lhs) && rhsType != null))
6160
lhsType = rhsType;
62-
}
63-
if ( rhsType == null )
64-
{
61+
62+
if (rhsType == null || (IsGuessedType(rhs) && lhsType != null))
6563
rhsType = lhsType;
66-
}
6764

6865
if (lhs is IExpectedTypeAwareNode lshTypeAwareNode && lshTypeAwareNode.ExpectedType == null)
6966
{
@@ -248,6 +245,8 @@ private protected static string[] ExtractMutationTexts(IASTNode operand, int cou
248245
throw new HibernateException( "dont know how to extract row value elements from node : " + operand );
249246
}
250247

248+
private static bool IsGuessedType(IASTNode operand) => TransparentCastNode.IsTransparentCast(operand);
249+
251250
protected static IType ExtractDataType(IASTNode operand)
252251
{
253252
IType type = null;

src/NHibernate/Hql/Ast/ANTLR/Tree/HqlSqlWalkerTreeAdapter.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ public override object Create(IToken payload)
6969
ret = new SqlFragment(payload);
7070
break;
7171
case HqlSqlWalker.METHOD_CALL:
72-
ret = new MethodNode(payload);
72+
ret = payload.Text == TransparentCastNode.Name
73+
? new TransparentCastNode(payload)
74+
: new MethodNode(payload);
7375
break;
7476
case HqlSqlWalker.ELEMENTS:
7577
case HqlSqlWalker.INDICES:
@@ -206,4 +208,4 @@ private void Initialise(object node)
206208
}
207209
}
208210
}
209-
}
211+
}

src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
using System;
2+
using System.Collections;
23
using System.Collections.Generic;
34
using Antlr.Runtime;
45

56
using NHibernate.Dialect.Function;
67
using NHibernate.Hql.Ast.ANTLR.Util;
78
using NHibernate.Persister.Collection;
9+
using NHibernate.SqlCommand;
810
using NHibernate.Type;
911

1012
namespace NHibernate.Hql.Ast.ANTLR.Tree
@@ -212,5 +214,10 @@ private void DialectFunction(IASTNode exprList)
212214
methodName = (String) getWalker().getTokenReplacements().get( methodName );
213215
}*/
214216
}
217+
218+
public virtual SqlString Render(IList args)
219+
{
220+
return _function.Render(args, SessionFactoryHelper.Factory);
221+
}
215222
}
216223
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using System.Collections;
2+
using Antlr.Runtime;
3+
using NHibernate.Dialect.Function;
4+
using NHibernate.SqlCommand;
5+
using NHibernate.Type;
6+
7+
namespace NHibernate.Hql.Ast.ANTLR.Tree
8+
{
9+
class TransparentCastNode : MethodNode, IExpectedTypeAwareNode
10+
{
11+
private IType _expectedType;
12+
13+
public const string Name = "transparentcast";
14+
15+
public static bool IsTransparentCast(IASTNode node)
16+
{
17+
return node.Type == HqlSqlWalker.METHOD_CALL && node.Text == Name;
18+
}
19+
20+
public TransparentCastNode(IToken token) : base(token)
21+
{
22+
}
23+
24+
public IType ExpectedType
25+
{
26+
get => _expectedType;
27+
set
28+
{
29+
_expectedType = value;
30+
var node = GetChild(0).NextSibling.GetChild(0);
31+
// A transparent cast on parameters is a special use case - skip it.
32+
if (node.Type != HqlSqlWalker.NAMED_PARAM && node is IExpectedTypeAwareNode typeNode && typeNode.ExpectedType == null)
33+
typeNode.ExpectedType = value;
34+
}
35+
}
36+
37+
public override SqlString Render(IList args)
38+
{
39+
return ExpectedType != null
40+
// Provide the expected type in case the transparent cast is transformed to an actual cast.
41+
? ((CastFunction) SQLFunction).Render(args, ExpectedType, SessionFactoryHelper.Factory)
42+
: base.Render(args);
43+
}
44+
}
45+
}

src/NHibernate/Hql/Ast/HqlTreeNode.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -724,9 +724,9 @@ public HqlCast(IASTFactory factory, HqlExpression expression, System.Type type)
724724
public class HqlTransparentCast : HqlExpression
725725
{
726726
public HqlTransparentCast(IASTFactory factory, HqlExpression expression, System.Type type)
727-
: base(HqlSqlWalker.METHOD_CALL, "method", factory)
727+
: base(HqlSqlWalker.METHOD_CALL, TransparentCastNode.Name, factory)
728728
{
729-
AddChild(new HqlIdent(factory, "transparentcast"));
729+
AddChild(new HqlIdent(factory, TransparentCastNode.Name));
730730
AddChild(new HqlExpressionList(factory, expression, new HqlIdent(factory, type)));
731731
}
732732
}

0 commit comments

Comments
 (0)