diff --git a/src/NHibernate.Test/Async/Criteria/Lambda/ProjectIntegrationFixture.cs b/src/NHibernate.Test/Async/Criteria/Lambda/ProjectIntegrationFixture.cs index 446a0b1833a..e1c2cd5c089 100644 --- a/src/NHibernate.Test/Async/Criteria/Lambda/ProjectIntegrationFixture.cs +++ b/src/NHibernate.Test/Async/Criteria/Lambda/ProjectIntegrationFixture.cs @@ -134,5 +134,43 @@ var actual Assert.That((int) actual[1], Is.EqualTo(2), "distinct count by name"); } } + + [Test] + public async Task ProjectionCanCoalesceInSelectAsync() + { + using (var s = OpenSession()) + using (s.BeginTransaction()) + { + var actual + = (await (s.QueryOver() + .Select(x => x.Age.Coalesce(0)) + .Where(x => x.Age == 20) + .ListAsync())).FirstOrDefault(); + + Assert.That(actual, Is.EqualTo(20)); + } + } + + //NH-2983 + [Test] + public async Task ProjectionSelectSumOnCoalesceAsync() + { + using (var s = OpenSession()) + using (s.BeginTransaction()) + { + var actual + = (await (s.QueryOver() + .SelectList( + l => + l + .SelectSum(xx => xx.Age.Coalesce(0)) + .SelectSum(xx => xx.Age.Coalesce(1))) + .Where(x => x.Age == 20) + .ListAsync())).FirstOrDefault(); + + Assert.That(actual[0], Is.EqualTo(20)); + Assert.That(actual[1], Is.EqualTo(20)); + } + } } } diff --git a/src/NHibernate.Test/Criteria/Lambda/ProjectIntegrationFixture.cs b/src/NHibernate.Test/Criteria/Lambda/ProjectIntegrationFixture.cs index afecce3b193..b62e6ac98d0 100644 --- a/src/NHibernate.Test/Criteria/Lambda/ProjectIntegrationFixture.cs +++ b/src/NHibernate.Test/Criteria/Lambda/ProjectIntegrationFixture.cs @@ -123,5 +123,43 @@ var actual Assert.That((int) actual[1], Is.EqualTo(2), "distinct count by name"); } } + + [Test] + public void ProjectionCanCoalesceInSelect() + { + using (var s = OpenSession()) + using (s.BeginTransaction()) + { + var actual + = s.QueryOver() + .Select(x => x.Age.Coalesce(0)) + .Where(x => x.Age == 20) + .List().FirstOrDefault(); + + Assert.That(actual, Is.EqualTo(20)); + } + } + + //NH-2983 + [Test] + public void ProjectionSelectSumOnCoalesce() + { + using (var s = OpenSession()) + using (s.BeginTransaction()) + { + var actual + = s.QueryOver() + .SelectList( + l => + l + .SelectSum(xx => xx.Age.Coalesce(0)) + .SelectSum(xx => xx.Age.Coalesce(1))) + .Where(x => x.Age == 20) + .List().FirstOrDefault(); + + Assert.That(actual[0], Is.EqualTo(20)); + Assert.That(actual[1], Is.EqualTo(20)); + } + } } } diff --git a/src/NHibernate.Test/Criteria/Lambda/RestrictionsFixture.cs b/src/NHibernate.Test/Criteria/Lambda/RestrictionsFixture.cs index 101f360315d..9d59d86d77a 100644 --- a/src/NHibernate.Test/Criteria/Lambda/RestrictionsFixture.cs +++ b/src/NHibernate.Test/Criteria/Lambda/RestrictionsFixture.cs @@ -287,9 +287,9 @@ public void FunctionExtensions() .Add(Restrictions.Eq(Projections.SqlFunction("substring", NHibernateUtil.String, Projections.Property("Name"), Projections.Property("Age"), Projections.Constant(2)), "te")) .Add(Restrictions.Eq(Projections.SqlFunction("locate", NHibernateUtil.String, Projections.Constant("e"), Projections.Property("Name"), Projections.Constant(1)), 2)) .Add(Restrictions.Eq(Projections.SqlFunction("locate", NHibernateUtil.String, Projections.Constant("e"), Projections.Property("Name"), Projections.Property("Age")), 2)) - .Add(Restrictions.Eq(Projections.SqlFunction("coalesce", NHibernateUtil.Object, Projections.Property("Name"), Projections.Constant("not-null-val")), "test")) - .Add(Restrictions.Eq(Projections.SqlFunction("coalesce", NHibernateUtil.Object, Projections.Property("Name"), Projections.Property("Nickname")), "test")) - .Add(Restrictions.Eq(Projections.SqlFunction("coalesce", NHibernateUtil.Object, Projections.Property("NullableIsParent"), Projections.Constant(true)), true)) + .Add(Restrictions.Eq(new SqlFunctionProjection("coalesce", Projections.Property("Name"), Projections.Property("Name"), Projections.Constant("not-null-val")), "test")) + .Add(Restrictions.Eq(new SqlFunctionProjection("coalesce", Projections.Property("Name"), Projections.Property("Name"), Projections.Property("Nickname")), "test")) + .Add(Restrictions.Eq(new SqlFunctionProjection("coalesce", Projections.Property("NullableIsParent"), Projections.Property("NullableIsParent"), Projections.Constant(true)), true)) .Add(Restrictions.Eq(Projections.SqlFunction("concat", NHibernateUtil.String, Projections.Property("Name"), Projections.Constant(", "), Projections.Property("Name")), "test, test")) .Add(Restrictions.Eq(Projections.SqlFunction("mod", NHibernateUtil.Int32, Projections.Property("Height"), Projections.Constant(10)), 0)) .Add(Restrictions.Eq(Projections.SqlFunction("mod", NHibernateUtil.Int32, Projections.Property("Height"), Projections.Property("Age")), 0)); diff --git a/src/NHibernate/Criterion/ProjectionsExtensions.cs b/src/NHibernate/Criterion/ProjectionsExtensions.cs index 5ef9157e847..7d7067ea98b 100644 --- a/src/NHibernate/Criterion/ProjectionsExtensions.cs +++ b/src/NHibernate/Criterion/ProjectionsExtensions.cs @@ -316,7 +316,7 @@ internal static IProjection ProcessCoalesce(MethodCallExpression methodCallExpre { IProjection property = ExpressionProcessor.FindMemberProjection(methodCallExpression.Arguments[0]).AsProjection(); var replaceValueIfIsNull = ExpressionProcessor.FindMemberProjection(methodCallExpression.Arguments[1]); - return Projections.SqlFunction("coalesce", NHibernateUtil.Object, property, replaceValueIfIsNull.AsProjection()); + return new SqlFunctionProjection("coalesce", returnTypeProjection: property, property, replaceValueIfIsNull.AsProjection()); } /// diff --git a/src/NHibernate/Criterion/SqlFunctionProjection.cs b/src/NHibernate/Criterion/SqlFunctionProjection.cs index c958ce055f9..95d5298fa77 100644 --- a/src/NHibernate/Criterion/SqlFunctionProjection.cs +++ b/src/NHibernate/Criterion/SqlFunctionProjection.cs @@ -1,11 +1,10 @@ using System; -using System.Collections; using System.Collections.Generic; +using System.Linq; using NHibernate.Dialect.Function; using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.Type; -using NHibernate.Util; namespace NHibernate.Criterion { @@ -16,6 +15,7 @@ public class SqlFunctionProjection : SimpleProjection private readonly ISQLFunction function; private readonly string functionName; private readonly IType returnType; + private readonly IProjection returnTypeProjection; public SqlFunctionProjection(string functionName, IType returnType, params IProjection[] args) { @@ -31,6 +31,13 @@ public SqlFunctionProjection(ISQLFunction function, IType returnType, params IPr this.args = args; } + public SqlFunctionProjection(string functionName, IProjection returnTypeProjection, params IProjection[] args) + { + this.functionName = functionName; + this.returnTypeProjection = returnTypeProjection; + this.args = args; + } + public override bool IsAggregate { get { return false; } @@ -107,10 +114,18 @@ private static SqlString GetProjectionArgument(ICriteriaQuery criteriaQuery, ICr } public override IType[] GetTypes(ICriteria criteria, ICriteriaQuery criteriaQuery) + { + var type = GetReturnType(criteria, criteriaQuery); + return type != null ? new[] {type} : Array.Empty(); + } + + private IType GetReturnType(ICriteria criteria, ICriteriaQuery criteriaQuery) { ISQLFunction sqlFunction = GetFunction(criteriaQuery); - IType type = sqlFunction.ReturnType(returnType, criteriaQuery.Factory); - return new IType[] {type}; + + var resultType = returnType ?? returnTypeProjection?.GetTypes(criteria, criteriaQuery).FirstOrDefault(); + + return sqlFunction.ReturnType(resultType, criteriaQuery.Factory); } public override TypedValue[] GetTypedValues(ICriteria criteria, ICriteriaQuery criteriaQuery)