Skip to content

Commit 171dd10

Browse files
committed
Extend the logic to be used for other aggregate functions
1 parent 9ce5cb2 commit 171dd10

File tree

1 file changed

+104
-23
lines changed

1 file changed

+104
-23
lines changed

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 104 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Data;
23
using System.Dynamic;
34
using System.Linq;
45
using System.Linq.Expressions;
@@ -242,10 +243,13 @@ constant.Value is CallSite site &&
242243
protected HqlTreeNode VisitNhAverage(NhAverageExpression expression)
243244
{
244245
var hqlExpression = VisitExpression(expression.Expression).AsExpression();
245-
if (expression.Type != expression.Expression.Type)
246-
hqlExpression = _hqlTreeBuilder.Cast(hqlExpression, expression.Type);
246+
hqlExpression = IsCastRequired(expression.Expression, expression.Type)
247+
? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type)
248+
: _hqlTreeBuilder.TransparentCast(hqlExpression, expression.Type);
247249

248-
return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type);
250+
return IsCastRequired(expression.Type, "avg")
251+
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type)
252+
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Average(hqlExpression), expression.Type);
249253
}
250254

251255
protected HqlTreeNode VisitNhCount(NhCountExpression expression)
@@ -265,17 +269,9 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
265269

266270
protected HqlTreeNode VisitNhSum(NhSumExpression expression)
267271
{
268-
var type = expression.Type.UnwrapIfNullable();
269-
var nhType = TypeFactory.GetDefaultTypeFor(type);
270-
if (nhType != null && _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction("sum")
271-
?.ReturnType(nhType, _parameters.SessionFactory)?.ReturnedClass == type)
272-
{
273-
return _hqlTreeBuilder.TransparentCast(
274-
_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()),
275-
expression.Type);
276-
}
277-
278-
return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type);
272+
return IsCastRequired(expression.Type, "sum")
273+
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type)
274+
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type);
279275
}
280276

281277
protected HqlTreeNode VisitNhDistinct(NhDistinctExpression expression)
@@ -489,15 +485,9 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
489485
case ExpressionType.Convert:
490486
case ExpressionType.ConvertChecked:
491487
case ExpressionType.TypeAs:
492-
var operandType = expression.Operand.Type.UnwrapIfNullable();
493-
if ((operandType.IsPrimitive || operandType == typeof(decimal)) &&
494-
(expression.Type.IsPrimitive || expression.Type == typeof(decimal)) &&
495-
expression.Type != operandType)
496-
{
497-
return _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type);
498-
}
499-
500-
return VisitExpression(expression.Operand);
488+
return IsCastRequired(expression.Operand, expression.Type)
489+
? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type)
490+
: VisitExpression(expression.Operand);
501491
}
502492

503493
throw new NotSupportedException(expression.ToString());
@@ -598,5 +588,96 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
598588
var expressionSubTree = expression.Expressions.Select(exp => VisitExpression(exp)).ToArray();
599589
return _hqlTreeBuilder.ExpressionSubTreeHolder(expressionSubTree);
600590
}
591+
592+
private bool IsCastRequired(Expression expression, System.Type toType)
593+
{
594+
return toType != typeof(object) && IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType));
595+
}
596+
597+
private bool IsCastRequired(IType type, IType toType)
598+
{
599+
// A type can be null when casting an entity into a base class, in that case we should not cast
600+
if (type == null || toType == null || Equals(type, toType))
601+
{
602+
return false;
603+
}
604+
605+
var sqlTypes = type.SqlTypes(_parameters.SessionFactory);
606+
var toSqlTypes = toType.SqlTypes(_parameters.SessionFactory);
607+
if (sqlTypes.Length != 1 || toSqlTypes.Length != 1)
608+
{
609+
return false; // Casting a multi-column type is not possible
610+
}
611+
612+
if (type.ReturnedClass.IsEnum && sqlTypes[0].DbType == DbType.String)
613+
{
614+
return false; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value
615+
}
616+
617+
return sqlTypes[0].DbType != toSqlTypes[0].DbType;
618+
}
619+
620+
private bool IsCastRequired(System.Type type, string sqlFunctionName)
621+
{
622+
if (type == typeof(object))
623+
{
624+
return false;
625+
}
626+
627+
var toType = TypeFactory.GetDefaultTypeFor(type);
628+
if (toType == null)
629+
{
630+
return true; // Fallback to the old behavior
631+
}
632+
633+
var sqlFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction(sqlFunctionName);
634+
if (sqlFunction == null)
635+
{
636+
return true; // Fallback to the old behavior
637+
}
638+
639+
var fnReturnType = sqlFunction.ReturnType(toType, _parameters.SessionFactory);
640+
return fnReturnType == null || IsCastRequired(fnReturnType, toType);
641+
}
642+
643+
private IType GetType(Expression expression)
644+
{
645+
if (!(expression is MemberExpression memberExpression))
646+
{
647+
return expression.Type != typeof(object)
648+
? TypeFactory.GetDefaultTypeFor(expression.Type)
649+
: null;
650+
}
651+
652+
// Try to get the mapped type for the member as it may be a non default one
653+
var entityName = TryGetEntityName(memberExpression);
654+
if (entityName == null)
655+
{
656+
return TypeFactory.GetDefaultTypeFor(expression.Type); // Not mapped
657+
}
658+
659+
var persister = _parameters.SessionFactory.GetEntityPersister(entityName);
660+
var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name);
661+
return !index.HasValue
662+
? TypeFactory.GetDefaultTypeFor(expression.Type) // Not mapped
663+
: persister.EntityMetamodel.PropertyTypes[index.Value];
664+
}
665+
666+
private string TryGetEntityName(MemberExpression memberExpression)
667+
{
668+
System.Type entityType;
669+
// Try to get the actual entity type from the query source if possbile as member can be declared
670+
// in a base type
671+
if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression)
672+
{
673+
entityType = querySourceReferenceExpression.Type;
674+
}
675+
else
676+
{
677+
entityType = memberExpression.Member.ReflectedType;
678+
}
679+
680+
return _parameters.SessionFactory.TryGetGuessEntityName(entityType);
681+
}
601682
}
602683
}

0 commit comments

Comments
 (0)