Skip to content

Commit 4decb54

Browse files
committed
Do not set guessed type as expected parameter type in LINQ
1 parent e0384a3 commit 4decb54

File tree

7 files changed

+277
-33
lines changed

7 files changed

+277
-33
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
//------------------------------------------------------------------------------
2+
// <auto-generated>
3+
// This code was generated by AsyncGenerator.
4+
//
5+
// Changes to this file may cause incorrect behavior and will be lost if
6+
// the code is regenerated.
7+
// </auto-generated>
8+
//------------------------------------------------------------------------------
9+
10+
11+
using System.Data;
12+
using System.Linq;
13+
using NHibernate.Cfg.MappingSchema;
14+
using NHibernate.Mapping.ByCode;
15+
using NHibernate.SqlTypes;
16+
using NUnit.Framework;
17+
using NHibernate.Linq;
18+
19+
namespace NHibernate.Test.NHSpecificTest.NH3565
20+
{
21+
using System.Threading.Tasks;
22+
[TestFixture]
23+
public class ByCodeFixtureAsync : TestCaseMappingByCode
24+
{
25+
protected override HbmMapping GetMappings()
26+
{
27+
var mapper = new ModelMapper();
28+
mapper.Class<Entity>(rc =>
29+
{
30+
rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb));
31+
rc.Property(x => x.Name, m =>
32+
{
33+
m.Type(NHibernateUtil.AnsiString);
34+
m.Length(10);
35+
});
36+
});
37+
38+
return mapper.CompileMappingForAllExplicitlyAddedEntities();
39+
}
40+
41+
protected override bool AppliesTo(Dialect.Dialect dialect)
42+
{
43+
return base.AppliesTo(dialect)
44+
//Dialects like SQL Server CE, Firebird don't distinguish AnsiString from String
45+
&& Dialect.GetTypeName(new SqlType(DbType.AnsiString)) != Dialect.GetTypeName(new SqlType(DbType.String));
46+
}
47+
48+
protected override void OnSetUp()
49+
{
50+
using (var session = OpenSession())
51+
using (var transaction = session.BeginTransaction())
52+
{
53+
var e1 = new Entity {Name = "Bob"};
54+
session.Save(e1);
55+
56+
var e2 = new Entity {Name = "Sally"};
57+
session.Save(e2);
58+
59+
transaction.Commit();
60+
}
61+
}
62+
63+
protected override void OnTearDown()
64+
{
65+
using (var session = OpenSession())
66+
using (var transaction = session.BeginTransaction())
67+
{
68+
session.CreateQuery("delete from System.Object").ExecuteUpdate();
69+
70+
transaction.Commit();
71+
}
72+
}
73+
74+
[Test]
75+
public async Task ParameterTypeForLikeIsProperlyDetectedAsync()
76+
{
77+
using (var logSpy = new SqlLogSpy())
78+
using (var session = OpenSession())
79+
{
80+
var result = from e in session.Query<Entity>()
81+
where NHibernate.Linq.SqlMethods.Like(e.Name, "Bob")
82+
select e;
83+
84+
Assert.That(await (result.ToListAsync()), Has.Count.EqualTo(1));
85+
Assert.That(logSpy.GetWholeLog(), Does.Contain("Type: AnsiString"));
86+
}
87+
}
88+
89+
[KnownBug("Not fixed yet")]
90+
[Test]
91+
public async Task ParameterTypeForContainsIsProperlyDetectedAsync()
92+
{
93+
using (var logSpy = new SqlLogSpy())
94+
using (var session = OpenSession())
95+
{
96+
var result = from e in session.Query<Entity>()
97+
where e.Name.Contains("Bob")
98+
select e;
99+
100+
Assert.That(await (result.ToListAsync()), Has.Count.EqualTo(1));
101+
Assert.That(logSpy.GetWholeLog(), Does.Contain("Type: AnsiString"));
102+
}
103+
}
104+
105+
[KnownBug("Not fixed yet")]
106+
[Test]
107+
public async Task ParameterTypeForStartsWithIsProperlyDetectedAsync()
108+
{
109+
using (var logSpy = new SqlLogSpy())
110+
using (var session = OpenSession())
111+
{
112+
var result = from e in session.Query<Entity>()
113+
where e.Name.StartsWith("Bob")
114+
select e;
115+
116+
Assert.That(await (result.ToListAsync()), Has.Count.EqualTo(1));
117+
Assert.That(logSpy.GetWholeLog(), Does.Contain("Type: AnsiString"));
118+
}
119+
}
120+
}
121+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
3+
namespace NHibernate.Test.NHSpecificTest.NH3565
4+
{
5+
class Entity
6+
{
7+
public virtual Guid Id { get; set; }
8+
public virtual string Name { get; set; }
9+
}
10+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
using System.Data;
2+
using System.Linq;
3+
using NHibernate.Cfg.MappingSchema;
4+
using NHibernate.Mapping.ByCode;
5+
using NHibernate.SqlTypes;
6+
using NUnit.Framework;
7+
8+
namespace NHibernate.Test.NHSpecificTest.NH3565
9+
{
10+
[TestFixture]
11+
public class ByCodeFixture : TestCaseMappingByCode
12+
{
13+
protected override HbmMapping GetMappings()
14+
{
15+
var mapper = new ModelMapper();
16+
mapper.Class<Entity>(rc =>
17+
{
18+
rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb));
19+
rc.Property(x => x.Name, m =>
20+
{
21+
m.Type(NHibernateUtil.AnsiString);
22+
m.Length(10);
23+
});
24+
});
25+
26+
return mapper.CompileMappingForAllExplicitlyAddedEntities();
27+
}
28+
29+
protected override bool AppliesTo(Dialect.Dialect dialect)
30+
{
31+
return base.AppliesTo(dialect)
32+
//Dialects like SQL Server CE, Firebird don't distinguish AnsiString from String
33+
&& Dialect.GetTypeName(new SqlType(DbType.AnsiString)) != Dialect.GetTypeName(new SqlType(DbType.String));
34+
}
35+
36+
protected override void OnSetUp()
37+
{
38+
using (var session = OpenSession())
39+
using (var transaction = session.BeginTransaction())
40+
{
41+
var e1 = new Entity {Name = "Bob"};
42+
session.Save(e1);
43+
44+
var e2 = new Entity {Name = "Sally"};
45+
session.Save(e2);
46+
47+
transaction.Commit();
48+
}
49+
}
50+
51+
protected override void OnTearDown()
52+
{
53+
using (var session = OpenSession())
54+
using (var transaction = session.BeginTransaction())
55+
{
56+
session.CreateQuery("delete from System.Object").ExecuteUpdate();
57+
58+
transaction.Commit();
59+
}
60+
}
61+
62+
[Test]
63+
public void ParameterTypeForLikeIsProperlyDetected()
64+
{
65+
using (var logSpy = new SqlLogSpy())
66+
using (var session = OpenSession())
67+
{
68+
var result = from e in session.Query<Entity>()
69+
where NHibernate.Linq.SqlMethods.Like(e.Name, "Bob")
70+
select e;
71+
72+
Assert.That(result.ToList(), Has.Count.EqualTo(1));
73+
Assert.That(logSpy.GetWholeLog(), Does.Contain("Type: AnsiString"));
74+
}
75+
}
76+
77+
[KnownBug("Not fixed yet")]
78+
[Test]
79+
public void ParameterTypeForContainsIsProperlyDetected()
80+
{
81+
using (var logSpy = new SqlLogSpy())
82+
using (var session = OpenSession())
83+
{
84+
var result = from e in session.Query<Entity>()
85+
where e.Name.Contains("Bob")
86+
select e;
87+
88+
Assert.That(result.ToList(), Has.Count.EqualTo(1));
89+
Assert.That(logSpy.GetWholeLog(), Does.Contain("Type: AnsiString"));
90+
}
91+
}
92+
93+
[KnownBug("Not fixed yet")]
94+
[Test]
95+
public void ParameterTypeForStartsWithIsProperlyDetected()
96+
{
97+
using (var logSpy = new SqlLogSpy())
98+
using (var session = OpenSession())
99+
{
100+
var result = from e in session.Query<Entity>()
101+
where e.Name.StartsWith("Bob")
102+
select e;
103+
104+
Assert.That(result.ToList(), Has.Count.EqualTo(1));
105+
Assert.That(logSpy.GetWholeLog(), Does.Contain("Type: AnsiString"));
106+
}
107+
}
108+
}
109+
}

src/NHibernate/Hql/Ast/ANTLR/Tree/CaseNode.cs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,38 @@ public override IType DataType
2929
if (ExpectedType != null)
3030
return ExpectedType;
3131

32-
foreach (var node in GetResultNodes())
32+
if (base.DataType != null)
33+
return base.DataType;
34+
35+
var dataType = GetTypeFromResultNodes();
36+
37+
foreach (var node in GetResultNodes().OfType<ISelectExpression>())
3338
{
34-
if (node is ISelectExpression select && !(node is ParameterNode))
35-
return select.DataType;
39+
if (node.DataType == null && node is IExpectedTypeAwareNode typeAwareNode)
40+
{
41+
typeAwareNode.ExpectedType = dataType;
42+
}
3643
}
3744

38-
throw new HibernateException("Unable to determine data type of CASE statement.");
45+
base.DataType = dataType;
46+
return dataType;
3947
}
4048
set { base.DataType = value; }
4149
}
4250

51+
private IType GetTypeFromResultNodes()
52+
{
53+
foreach (var node in GetResultNodes())
54+
{
55+
if (node is ISelectExpression select && select.DataType != null)
56+
{
57+
return select.DataType;
58+
}
59+
}
60+
61+
throw new HibernateException("Unable to determine data type of CASE statement.");
62+
}
63+
4364
public IEnumerable<IASTNode> GetResultNodes()
4465
{
4566
for (int i = 0; i < ChildCount; i++)

src/NHibernate/Linq/DefaultQueryProvider.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,11 @@ private static void SetParameters(IQuery query, IDictionary<string, NamedParamet
265265
}
266266
else
267267
{
268-
query.SetParameter(parameter.Name, parameter.Value);
268+
//Let HQL try to process guessed types (hql doesn't support type guessing for NULL)
269+
if (parameter.Type != null && (parameter.IsGuessedType == false || parameter.Value == null))
270+
query.SetParameter(parameter.Name, parameter.Value, parameter.Type);
271+
else
272+
query.SetParameter(parameter.Name, parameter.Value);
269273
}
270274
}
271275
}

src/NHibernate/Linq/Functions/StringGenerator.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
217217
{
218218
var expression = visitor.Visit(targetObject).AsExpression();
219219
var index = treeBuilder.Add(visitor.Visit(arguments[0]).AsExpression(), treeBuilder.Constant(1));
220-
return treeBuilder.MethodCall("substring", expression, index, treeBuilder.Constant(1));
220+
221+
return treeBuilder.TransparentCast(treeBuilder.MethodCall("substring", expression, index, treeBuilder.Constant(1)), typeof(char));
221222
}
222223
}
223224

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ internal static void SetParameterTypes(
9393
continue;
9494
}
9595

96-
namedParameter.Type = GetParameterType(sessionFactory, constantExpressions, visitor, namedParameter, out var tryProcessInHql);
97-
namedParameter.IsGuessedType = tryProcessInHql;
96+
namedParameter.Type = GetParameterType(sessionFactory, constantExpressions, visitor, namedParameter, out var isGuessedType);
97+
namedParameter.IsGuessedType = isGuessedType;
9898
}
9999
}
100100

@@ -147,9 +147,9 @@ private static IType GetParameterType(
147147
HashSet<ConstantExpression> constantExpressions,
148148
ConstantTypeLocatorVisitor visitor,
149149
NamedParameter namedParameter,
150-
out bool tryProcessInHql)
150+
out bool isGuessedType)
151151
{
152-
tryProcessInHql = false;
152+
isGuessedType = false;
153153
// All constant expressions have the same type/value
154154
var constantExpression = constantExpressions.First();
155155
var constantType = constantExpression.Type.UnwrapIfNullable();
@@ -159,10 +159,7 @@ private static IType GetParameterType(
159159
return candidateType;
160160
}
161161

162-
if (visitor.NotGuessableConstants.Contains(constantExpression) && constantExpression.Value != null)
163-
{
164-
tryProcessInHql = true;
165-
}
162+
isGuessedType = true;
166163

167164
// No related MemberExpressions was found, guess the type by value or its type when null.
168165
// When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam))
@@ -174,13 +171,10 @@ private static IType GetParameterType(
174171

175172
private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
176173
{
177-
private bool _hqlGenerator;
178174
private readonly bool _removeMappedAsCalls;
179175
private readonly System.Type _targetType;
180176
private readonly IDictionary<ConstantExpression, NamedParameter> _parameters;
181177
private readonly ISessionFactoryImplementor _sessionFactory;
182-
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
183-
public readonly HashSet<ConstantExpression> NotGuessableConstants = new HashSet<ConstantExpression>();
184178
public readonly Dictionary<ConstantExpression, IType> ConstantExpressions =
185179
new Dictionary<ConstantExpression, IType>();
186180
public readonly Dictionary<NamedParameter, HashSet<ConstantExpression>> ParameterConstants =
@@ -198,7 +192,6 @@ public ConstantTypeLocatorVisitor(
198192
_targetType = targetType;
199193
_sessionFactory = sessionFactory;
200194
_parameters = parameters;
201-
_functionRegistry = sessionFactory.Settings.LinqToHqlGeneratorsRegistry;
202195
}
203196

204197
protected override Expression VisitBinary(BinaryExpression node)
@@ -269,16 +262,6 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
269262
return node;
270263
}
271264

272-
// For hql method generators we do not want to guess the parameter type here, let hql logic figure it out.
273-
if (_functionRegistry.TryGetGenerator(node.Method, out _))
274-
{
275-
var origHqlGenerator = _hqlGenerator;
276-
_hqlGenerator = true;
277-
var expression = base.VisitMethodCall(node);
278-
_hqlGenerator = origHqlGenerator;
279-
return expression;
280-
}
281-
282265
return base.VisitMethodCall(node);
283266
}
284267

@@ -289,11 +272,6 @@ protected override Expression VisitConstant(ConstantExpression node)
289272
return node;
290273
}
291274

292-
if (_hqlGenerator)
293-
{
294-
NotGuessableConstants.Add(node);
295-
}
296-
297275
RelatedExpressions.Add(node, new HashSet<Expression>());
298276
ConstantExpressions.Add(node, null);
299277
if (!ParameterConstants.TryGetValue(param, out var set))

0 commit comments

Comments
 (0)