1
1
using System ;
2
+ using System . Data ;
2
3
using System . Dynamic ;
3
4
using System . Linq ;
4
5
using System . Linq . Expressions ;
@@ -242,10 +243,13 @@ constant.Value is CallSite site &&
242
243
protected HqlTreeNode VisitNhAverage ( NhAverageExpression expression )
243
244
{
244
245
var hqlExpression = VisitExpression ( expression . Expression ) . AsExpression ( ) ;
245
- if ( expression . Type != expression . Expression . Type )
246
- hqlExpression = _hqlTreeBuilder . Cast ( hqlExpression , expression . Type ) ;
246
+ hqlExpression = IsCastRequired ( expression . Expression , expression . Type )
247
+ ? ( HqlExpression ) _hqlTreeBuilder . Cast ( hqlExpression , expression . Type )
248
+ : _hqlTreeBuilder . TransparentCast ( hqlExpression , expression . Type ) ;
247
249
248
- return _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type ) ;
250
+ return IsCastRequired ( expression . Type , "avg" )
251
+ ? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type )
252
+ : _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type ) ;
249
253
}
250
254
251
255
protected HqlTreeNode VisitNhCount ( NhCountExpression expression )
@@ -265,17 +269,9 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
265
269
266
270
protected HqlTreeNode VisitNhSum ( NhSumExpression expression )
267
271
{
268
- var type = expression . Type . UnwrapIfNullable ( ) ;
269
- var nhType = TypeFactory . GetDefaultTypeFor ( type ) ;
270
- if ( nhType != null && _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( "sum" )
271
- ? . ReturnType ( nhType , _parameters . SessionFactory ) ? . ReturnedClass == type )
272
- {
273
- return _hqlTreeBuilder . TransparentCast (
274
- _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) ,
275
- expression . Type ) ;
276
- }
277
-
278
- return _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
272
+ return IsCastRequired ( expression . Type , "sum" )
273
+ ? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type )
274
+ : _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
279
275
}
280
276
281
277
protected HqlTreeNode VisitNhDistinct ( NhDistinctExpression expression )
@@ -489,15 +485,9 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
489
485
case ExpressionType . Convert :
490
486
case ExpressionType . ConvertChecked :
491
487
case ExpressionType . TypeAs :
492
- var operandType = expression . Operand . Type . UnwrapIfNullable ( ) ;
493
- if ( ( operandType . IsPrimitive || operandType == typeof ( decimal ) ) &&
494
- ( expression . Type . IsPrimitive || expression . Type == typeof ( decimal ) ) &&
495
- expression . Type != operandType )
496
- {
497
- return _hqlTreeBuilder . Cast ( VisitExpression ( expression . Operand ) . AsExpression ( ) , expression . Type ) ;
498
- }
499
-
500
- return VisitExpression ( expression . Operand ) ;
488
+ return IsCastRequired ( expression . Operand , expression . Type )
489
+ ? _hqlTreeBuilder . Cast ( VisitExpression ( expression . Operand ) . AsExpression ( ) , expression . Type )
490
+ : VisitExpression ( expression . Operand ) ;
501
491
}
502
492
503
493
throw new NotSupportedException ( expression . ToString ( ) ) ;
@@ -598,5 +588,96 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
598
588
var expressionSubTree = expression . Expressions . Select ( exp => VisitExpression ( exp ) ) . ToArray ( ) ;
599
589
return _hqlTreeBuilder . ExpressionSubTreeHolder ( expressionSubTree ) ;
600
590
}
591
+
592
+ private bool IsCastRequired ( Expression expression , System . Type toType )
593
+ {
594
+ return toType != typeof ( object ) && IsCastRequired ( GetType ( expression ) , TypeFactory . GetDefaultTypeFor ( toType ) ) ;
595
+ }
596
+
597
+ private bool IsCastRequired ( IType type , IType toType )
598
+ {
599
+ // A type can be null when casting an entity into a base class, in that case we should not cast
600
+ if ( type == null || toType == null || Equals ( type , toType ) )
601
+ {
602
+ return false ;
603
+ }
604
+
605
+ var sqlTypes = type . SqlTypes ( _parameters . SessionFactory ) ;
606
+ var toSqlTypes = toType . SqlTypes ( _parameters . SessionFactory ) ;
607
+ if ( sqlTypes . Length != 1 || toSqlTypes . Length != 1 )
608
+ {
609
+ return false ; // Casting a multi-column type is not possible
610
+ }
611
+
612
+ if ( type . ReturnedClass . IsEnum && sqlTypes [ 0 ] . DbType == DbType . String )
613
+ {
614
+ return false ; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value
615
+ }
616
+
617
+ return sqlTypes [ 0 ] . DbType != toSqlTypes [ 0 ] . DbType ;
618
+ }
619
+
620
+ private bool IsCastRequired ( System . Type type , string sqlFunctionName )
621
+ {
622
+ if ( type == typeof ( object ) )
623
+ {
624
+ return false ;
625
+ }
626
+
627
+ var toType = TypeFactory . GetDefaultTypeFor ( type ) ;
628
+ if ( toType == null )
629
+ {
630
+ return true ; // Fallback to the old behavior
631
+ }
632
+
633
+ var sqlFunction = _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( sqlFunctionName ) ;
634
+ if ( sqlFunction == null )
635
+ {
636
+ return true ; // Fallback to the old behavior
637
+ }
638
+
639
+ var fnReturnType = sqlFunction . ReturnType ( toType , _parameters . SessionFactory ) ;
640
+ return fnReturnType == null || IsCastRequired ( fnReturnType , toType ) ;
641
+ }
642
+
643
+ private IType GetType ( Expression expression )
644
+ {
645
+ if ( ! ( expression is MemberExpression memberExpression ) )
646
+ {
647
+ return expression . Type != typeof ( object )
648
+ ? TypeFactory . GetDefaultTypeFor ( expression . Type )
649
+ : null ;
650
+ }
651
+
652
+ // Try to get the mapped type for the member as it may be a non default one
653
+ var entityName = TryGetEntityName ( memberExpression ) ;
654
+ if ( entityName == null )
655
+ {
656
+ return TypeFactory . GetDefaultTypeFor ( expression . Type ) ; // Not mapped
657
+ }
658
+
659
+ var persister = _parameters . SessionFactory . GetEntityPersister ( entityName ) ;
660
+ var index = persister . EntityMetamodel . GetPropertyIndexOrNull ( memberExpression . Member . Name ) ;
661
+ return ! index . HasValue
662
+ ? TypeFactory . GetDefaultTypeFor ( expression . Type ) // Not mapped
663
+ : persister . EntityMetamodel . PropertyTypes [ index . Value ] ;
664
+ }
665
+
666
+ private string TryGetEntityName ( MemberExpression memberExpression )
667
+ {
668
+ System . Type entityType ;
669
+ // Try to get the actual entity type from the query source if possbile as member can be declared
670
+ // in a base type
671
+ if ( memberExpression . Expression is QuerySourceReferenceExpression querySourceReferenceExpression )
672
+ {
673
+ entityType = querySourceReferenceExpression . Type ;
674
+ }
675
+ else
676
+ {
677
+ entityType = memberExpression . Member . ReflectedType ;
678
+ }
679
+
680
+ return _parameters . SessionFactory . TryGetGuessEntityName ( entityType ) ;
681
+ }
601
682
}
602
683
}
0 commit comments