Skip to content

Commit 0a59292

Browse files
committed
Fix oracle and mysql issues
1 parent 95d1dc1 commit 0a59292

File tree

5 files changed

+112
-27
lines changed

5 files changed

+112
-27
lines changed

src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ protected override void OnTearDown()
8484
[Test]
8585
public void NullableIntOverflow()
8686
{
87+
var hasCast = Dialect.GetCastTypeName(NHibernateUtil.Int32.SqlType) !=
88+
Dialect.GetCastTypeName(NHibernateUtil.Int64.SqlType);
89+
8790
using (var session = OpenSession())
8891
using (session.BeginTransaction())
8992
using (var sqlLog = new SqlLogSpy())
@@ -96,7 +99,7 @@ public void NullableIntOverflow()
9699
})
97100
.ToList();
98101

99-
Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1));
102+
Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(hasCast ? 1 : 0));
100103
Assert.That(groups, Has.Count.EqualTo(1));
101104
Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2));
102105
}
@@ -105,6 +108,9 @@ public void NullableIntOverflow()
105108
[Test]
106109
public void IntOverflow()
107110
{
111+
var hasCast = Dialect.GetCastTypeName(NHibernateUtil.Int32.SqlType) !=
112+
Dialect.GetCastTypeName(NHibernateUtil.Int64.SqlType);
113+
108114
using (var session = OpenSession())
109115
using (session.BeginTransaction())
110116
using (var sqlLog = new SqlLogSpy())
@@ -117,7 +123,7 @@ public void IntOverflow()
117123
})
118124
.ToList();
119125

120-
Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1));
126+
Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(hasCast ? 1 : 0));
121127
Assert.That(groups, Has.Count.EqualTo(1));
122128
Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3));
123129
}

src/NHibernate/Dialect/Dialect.cs

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,18 @@ public virtual string GetLongestTypeName(DbType dbType)
266266
public virtual string GetCastTypeName(SqlType sqlType) =>
267267
GetCastTypeName(sqlType, _typeNames);
268268

269+
/// <summary>
270+
/// Get the name of the database type appropriate for casting operations
271+
/// (via the CAST() SQL function) for the given <see cref="SqlType"/> typecode.
272+
/// </summary>
273+
/// <param name="sqlType">The <see cref="SqlType"/> typecode.</param>
274+
/// <param name="typeName">The database type name that will be set in case it was found.</param>
275+
/// <returns>Whether the type name was found.</returns>
276+
public virtual bool TryGetCastTypeName(SqlType sqlType, out string typeName)
277+
{
278+
return TryGetCastTypeName(sqlType, _typeNames, out typeName);
279+
}
280+
269281
/// <summary>
270282
/// Get the name of the database type appropriate for casting operations
271283
/// (via the CAST() SQL function) for the given <see cref="SqlType"/> typecode.
@@ -274,28 +286,46 @@ public virtual string GetCastTypeName(SqlType sqlType) =>
274286
/// <param name="castTypeNames">The source for type names.</param>
275287
/// <returns>The database type name.</returns>
276288
protected virtual string GetCastTypeName(SqlType sqlType, TypeNames castTypeNames)
289+
{
290+
if (!TryGetCastTypeName(sqlType, castTypeNames, out var result))
291+
{
292+
throw new ArgumentException("Dialect does not support DbType." + sqlType.DbType, nameof(sqlType));
293+
}
294+
295+
return result;
296+
}
297+
298+
/// <summary>
299+
/// Get the name of the database type appropriate for casting operations
300+
/// (via the CAST() SQL function) for the given <see cref="SqlType"/> typecode.
301+
/// </summary>
302+
/// <param name="sqlType">The <see cref="SqlType"/> typecode.</param>
303+
/// <param name="castTypeNames">The source for type names.</param>
304+
/// <param name="typeName">The database type name that will be set in case it was found.</param>
305+
/// <returns>Whether the type name was found.</returns>
306+
protected virtual bool TryGetCastTypeName(SqlType sqlType, TypeNames castTypeNames, out string typeName)
277307
{
278308
if (sqlType.LengthDefined || sqlType.PrecisionDefined || sqlType.ScaleDefined)
279-
return castTypeNames.Get(sqlType.DbType, sqlType.Length, sqlType.Precision, sqlType.Scale);
309+
return castTypeNames.TryGet(sqlType.DbType, sqlType.Length, sqlType.Precision, sqlType.Scale, out typeName);
280310
switch (sqlType.DbType)
281311
{
282312
case DbType.Decimal:
283313
// Oracle dialect defines precision and scale for double, because it uses number instead of binary_double.
284314
case DbType.Double:
285315
// We cannot know if the user needs its digit after or before the dot, so use a configurable
286316
// default.
287-
return castTypeNames.Get(sqlType.DbType, 0, DefaultCastPrecision, DefaultCastScale);
317+
return castTypeNames.TryGet(sqlType.DbType, 0, DefaultCastPrecision, DefaultCastScale, out typeName);
288318
case DbType.DateTime:
289319
case DbType.DateTime2:
290320
case DbType.DateTimeOffset:
291321
case DbType.Time:
292322
case DbType.Currency:
293323
// Use default for these, dialects are supposed to map them to max capacity
294-
return castTypeNames.Get(sqlType.DbType);
324+
return castTypeNames.TryGet(sqlType.DbType, out typeName);
295325
default:
296326
// Other types are either length bound or not length/precision/scale bound. Otherwise they need to be
297327
// handled previously.
298-
return castTypeNames.Get(sqlType.DbType, DefaultCastLength, 0, 0);
328+
return castTypeNames.TryGet(sqlType.DbType, DefaultCastLength, 0, 0, out typeName);
299329
}
300330
}
301331

src/NHibernate/Dialect/Function/CastFunction.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,9 @@ public SqlString Render(IList args, ISessionFactoryImplementor factory)
5050
{
5151
throw new QueryException("invalid NHibernate type for cast(), was:" + typeName);
5252
}
53-
sqlType = factory.Dialect.GetCastTypeName(sqlTypeCodes[0]);
54-
if (sqlType == null)
53+
54+
if (!factory.Dialect.TryGetCastTypeName(sqlTypeCodes[0], out sqlType))
5555
{
56-
//TODO: never reached, since GetTypeName() actually throws an exception!
5756
sqlType = typeName;
5857
}
5958
//else

src/NHibernate/Dialect/TypeNames.cs

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,24 @@ public class TypeNames
5757
/// <returns>the default type name associated with the specified key</returns>
5858
public string Get(DbType typecode)
5959
{
60-
if (!defaults.TryGetValue(typecode, out var result))
60+
if (!TryGet(typecode, out var result))
6161
{
6262
throw new ArgumentException("Dialect does not support DbType." + typecode, nameof(typecode));
6363
}
6464
return result;
6565
}
6666

67+
/// <summary>
68+
/// Get default type name for specified type.
69+
/// </summary>
70+
/// <param name="typecode">The type key.</param>
71+
/// <param name="typeName">The default type name that will be set in case it was found.</param>
72+
/// <returns>Whether the default type name was found.</returns>
73+
public bool TryGet(DbType typecode, out string typeName)
74+
{
75+
return defaults.TryGetValue(typecode, out typeName);
76+
}
77+
6778
/// <summary>
6879
/// Get the type name specified type and size
6980
/// </summary>
@@ -76,6 +87,28 @@ public string Get(DbType typecode)
7687
/// if available, otherwise the default type name.
7788
/// </returns>
7889
public string Get(DbType typecode, int size, int precision, int scale)
90+
{
91+
if (!TryGet(typecode, size, precision, scale, out var result))
92+
{
93+
throw new ArgumentException("Dialect does not support DbType." + typecode, nameof(typecode));
94+
}
95+
96+
return result;
97+
}
98+
99+
/// <summary>
100+
/// Get the type name specified type and size.
101+
/// </summary>
102+
/// <param name="typecode">The type key.</param>
103+
/// <param name="size">The SQL length.</param>
104+
/// <param name="scale">The SQL scale.</param>
105+
/// <param name="precision">The SQL precision.</param>
106+
/// <param name="typeName">
107+
/// The associated name with smallest capacity >= size (or precision for decimal, or scale for date time types)
108+
/// if available, otherwise the default type name.
109+
/// </param>
110+
/// <returns>Whether the type name was found.</returns>
111+
public bool TryGet(DbType typecode, int size, int precision, int scale, out string typeName)
79112
{
80113
weighted.TryGetValue(typecode, out var map);
81114
if (map != null && map.Count > 0)
@@ -88,7 +121,8 @@ public string Get(DbType typecode, int size, int precision, int scale)
88121
{
89122
if (requiredCapacity <= entry.Key)
90123
{
91-
return Replace(entry.Value, size, precision, scale);
124+
typeName = Replace(entry.Value, size, precision, scale);
125+
return true;
92126
}
93127
}
94128
if (isPrecisionType && precision != 0)
@@ -102,11 +136,12 @@ public string Get(DbType typecode, int size, int precision, int scale)
102136
// But if the type is used for storing amounts, this may cause losing the ability to store cents...
103137
// So better just reduce as few as possible.
104138
var adjustedScale = Math.Min(scale, adjustedPrecision);
105-
return Replace(maxEntry.Value, size, adjustedPrecision, adjustedScale);
139+
typeName = Replace(maxEntry.Value, size, adjustedPrecision, adjustedScale);
140+
return true;
106141
}
107142
}
108143
//Could not find a specific type for the capacity, using the default
109-
return Get(typecode);
144+
return TryGet(typecode, out typeName);
110145
}
111146

112147
/// <summary>

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,13 @@ constant.Value is CallSite site &&
243243
protected HqlTreeNode VisitNhAverage(NhAverageExpression expression)
244244
{
245245
var hqlExpression = VisitExpression(expression.Expression).AsExpression();
246-
hqlExpression = IsCastRequired(expression.Expression, expression.Type)
246+
hqlExpression = IsCastRequired(expression.Expression, expression.Type, out _)
247247
? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type)
248248
: _hqlTreeBuilder.TransparentCast(hqlExpression, expression.Type);
249249

250-
return IsCastRequired(expression.Type, "avg")
251-
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type)
252-
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Average(hqlExpression), expression.Type);
250+
// In Oracle the avg function can return a number with up to 40 digits which cannot be retrieved from the data reader due to the lack of such
251+
// numeric type in .NET. In order to avoid that we have to add a cast to trim the number so that it can be converted into a .NET numeric type.
252+
return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type);
253253
}
254254

255255
protected HqlTreeNode VisitNhCount(NhCountExpression expression)
@@ -269,7 +269,7 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
269269

270270
protected HqlTreeNode VisitNhSum(NhSumExpression expression)
271271
{
272-
return IsCastRequired(expression.Type, "sum")
272+
return IsCastRequired(expression.Type, "sum", out _)
273273
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type)
274274
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type);
275275
}
@@ -485,9 +485,12 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
485485
case ExpressionType.Convert:
486486
case ExpressionType.ConvertChecked:
487487
case ExpressionType.TypeAs:
488-
return IsCastRequired(expression.Operand, expression.Type)
488+
return IsCastRequired(expression.Operand, expression.Type, out var existType)
489489
? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type)
490-
: VisitExpression(expression.Operand);
490+
// Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader
491+
: existType
492+
? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), expression.Type)
493+
: VisitExpression(expression.Operand);
491494
}
492495

493496
throw new NotSupportedException(expression.ToString());
@@ -589,63 +592,75 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
589592
return _hqlTreeBuilder.ExpressionSubTreeHolder(expressionSubTree);
590593
}
591594

592-
private bool IsCastRequired(Expression expression, System.Type toType)
595+
private bool IsCastRequired(Expression expression, System.Type toType, out bool existType)
593596
{
594-
return toType != typeof(object) && IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType));
597+
existType = false;
598+
return toType != typeof(object) && IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
595599
}
596600

597-
private bool IsCastRequired(IType type, IType toType)
601+
private bool IsCastRequired(IType type, IType toType, out bool existType)
598602
{
599603
// A type can be null when casting an entity into a base class, in that case we should not cast
600604
if (type == null || toType == null || Equals(type, toType))
601605
{
606+
existType = false;
602607
return false;
603608
}
604609

605610
var sqlTypes = type.SqlTypes(_parameters.SessionFactory);
606611
var toSqlTypes = toType.SqlTypes(_parameters.SessionFactory);
607612
if (sqlTypes.Length != 1 || toSqlTypes.Length != 1)
608613
{
614+
existType = false;
609615
return false; // Casting a multi-column type is not possible
610616
}
611617

618+
existType = true;
612619
if (sqlTypes[0].DbType == toSqlTypes[0].DbType)
613620
{
614621
return false;
615622
}
616623

617624
if (type.ReturnedClass.IsEnum && sqlTypes[0].DbType == DbType.String)
618625
{
626+
existType = false;
619627
return false; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value
620628
}
621629

622630
// Some dialects can map several sql types into one, cast only if the dialect types are different
623-
var castTypeName = _parameters.SessionFactory.Dialect.GetCastTypeName(sqlTypes[0]);
624-
var toCastTypeName = _parameters.SessionFactory.Dialect.GetCastTypeName(toSqlTypes[0]);
631+
if (!_parameters.SessionFactory.Dialect.TryGetCastTypeName(sqlTypes[0], out var castTypeName) ||
632+
!_parameters.SessionFactory.Dialect.TryGetCastTypeName(toSqlTypes[0], out var toCastTypeName))
633+
{
634+
return false; // The dialect does not support such cast
635+
}
636+
625637
return castTypeName != toCastTypeName;
626638
}
627639

628-
private bool IsCastRequired(System.Type type, string sqlFunctionName)
640+
private bool IsCastRequired(System.Type type, string sqlFunctionName, out bool existType)
629641
{
630642
if (type == typeof(object))
631643
{
644+
existType = false;
632645
return false;
633646
}
634647

635648
var toType = TypeFactory.GetDefaultTypeFor(type);
636649
if (toType == null)
637650
{
651+
existType = false;
638652
return true; // Fallback to the old behavior
639653
}
640654

655+
existType = true;
641656
var sqlFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction(sqlFunctionName);
642657
if (sqlFunction == null)
643658
{
644659
return true; // Fallback to the old behavior
645660
}
646661

647662
var fnReturnType = sqlFunction.ReturnType(toType, _parameters.SessionFactory);
648-
return fnReturnType == null || IsCastRequired(fnReturnType, toType);
663+
return fnReturnType == null || IsCastRequired(fnReturnType, toType, out existType);
649664
}
650665

651666
private IType GetType(Expression expression)

0 commit comments

Comments
 (0)