Skip to content

Commit 3c10311

Browse files
committed
Move number type checks to TypeExtensions and extract additional method to make code more readable
1 parent 9d0bbb5 commit 3c10311

File tree

2 files changed

+51
-24
lines changed

2 files changed

+51
-24
lines changed

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
using Remotion.Linq.Clauses.Expressions;
1414
using Remotion.Linq.Clauses.ResultOperators;
1515
using Remotion.Linq.Parsing;
16+
using TypeExtensions = NHibernate.Util.TypeExtensions;
1617

1718
namespace NHibernate.Linq.Visitors
1819
{
@@ -50,25 +51,6 @@ public static class ParameterTypeLocator
5051
};
5152

5253

53-
private static readonly HashSet<System.Type> IntegralNumericTypes = new HashSet<System.Type>
54-
{
55-
typeof(sbyte),
56-
typeof(short),
57-
typeof(int),
58-
typeof(long),
59-
typeof(byte),
60-
typeof(ushort),
61-
typeof(uint),
62-
typeof(ulong)
63-
};
64-
65-
private static readonly HashSet<System.Type> FloatingPointNumericTypes = new HashSet<System.Type>
66-
{
67-
typeof(decimal),
68-
typeof(float),
69-
typeof(double)
70-
};
71-
7254
/// <summary>
7355
/// Set query parameter types based on the given query model.
7456
/// </summary>
@@ -150,6 +132,29 @@ private static IType GetCandidateType(
150132
return candidateType;
151133
}
152134

135+
private static IType GetCandidateType(
136+
ISessionFactoryImplementor sessionFactory,
137+
HashSet<ConstantExpression> constantExpressions,
138+
ConstantTypeLocatorVisitor visitor,
139+
System.Type constantType)
140+
{
141+
var candidateType = GetCandidateType(sessionFactory, constantExpressions, visitor);
142+
143+
if (candidateType == null)
144+
{
145+
return null;
146+
}
147+
148+
// When comparing an integral column with a real parameter, the parameter type must remain real type
149+
// and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
150+
if (constantType.IsRealNumberType() && candidateType.ReturnedClass.IsIntegralNumberType())
151+
{
152+
return null;
153+
}
154+
155+
return candidateType;
156+
}
157+
153158
private static IType GetParameterType(
154159
ISessionFactoryImplementor sessionFactory,
155160
HashSet<ConstantExpression> constantExpressions,
@@ -160,11 +165,8 @@ private static IType GetParameterType(
160165
// All constant expressions have the same type/value
161166
var constantExpression = constantExpressions.First();
162167
var constantType = constantExpression.Type.UnwrapIfNullable();
163-
var candidateType = GetCandidateType(sessionFactory, constantExpressions, visitor);
164-
if (candidateType != null &&
165-
// When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
166-
// and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
167-
!(FloatingPointNumericTypes.Contains(constantType) && IntegralNumericTypes.Contains(candidateType.ReturnedClass)))
168+
var candidateType = GetCandidateType(sessionFactory, constantExpressions, visitor, constantType);
169+
if (candidateType != null)
168170
{
169171
return candidateType;
170172
}

src/NHibernate/Util/TypeExtensions.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,30 @@ internal static System.Type UnwrapIfNullable(this System.Type type)
4848

4949
return type;
5050
}
51+
52+
internal static bool IsIntegralNumberType(this System.Type type)
53+
{
54+
var code = System.Type.GetTypeCode(type);
55+
if (code == TypeCode.SByte || code == TypeCode.Byte ||
56+
code == TypeCode.Int16 || code == TypeCode.UInt16 ||
57+
code == TypeCode.Int32 || code == TypeCode.UInt32 ||
58+
code == TypeCode.Int64 || code == TypeCode.UInt64)
59+
{
60+
return true;
61+
}
62+
63+
return false;
64+
}
65+
66+
internal static bool IsRealNumberType(this System.Type type)
67+
{
68+
var code = System.Type.GetTypeCode(type);
69+
if (code == TypeCode.Decimal || code == TypeCode.Single || code == TypeCode.Double)
70+
{
71+
return true;
72+
}
73+
74+
return false;
75+
}
5176
}
5277
}

0 commit comments

Comments
 (0)