diff --git a/src/NHibernate.Test/Hql/HQLFunctions.cs b/src/NHibernate.Test/Hql/HQLFunctions.cs index d9845162689..2d54fc6ebaf 100644 --- a/src/NHibernate.Test/Hql/HQLFunctions.cs +++ b/src/NHibernate.Test/Hql/HQLFunctions.cs @@ -568,6 +568,65 @@ public void Round() } } + [Test] + public void Truncate() + { + AssumeFunctionSupported("truncate"); + + using (var s = OpenSession()) + { + var a1 = new Animal("a1", 1.87f); + s.Save(a1); + var m1 = new MaterialResource("m1", "18", MaterialResource.MaterialState.Available) { Cost = 51.76m }; + s.Save(m1); + s.Flush(); + } + using (var s = OpenSession()) + { + var roundF = s.CreateQuery("select truncate(a.BodyWeight) from Animal a").UniqueResult(); + Assert.That(roundF, Is.EqualTo(1), "Selecting truncate(double) failed."); + var countF = + s + .CreateQuery("select count(*) from Animal a where truncate(a.BodyWeight) = :c") + .SetInt32("c", 1) + .UniqueResult(); + Assert.That(countF, Is.EqualTo(1), "Filtering truncate(double) failed."); + + roundF = s.CreateQuery("select truncate(a.BodyWeight, 1) from Animal a").UniqueResult(); + Assert.That(roundF, Is.EqualTo(1.8f).Within(0.01f), "Selecting truncate(double, 1) failed."); + countF = + s + .CreateQuery("select count(*) from Animal a where truncate(a.BodyWeight, 1) between :c1 and :c2") + .SetDouble("c1", 1.79) + .SetDouble("c2", 1.81) + .UniqueResult(); + Assert.That(countF, Is.EqualTo(1), "Filtering truncate(double, 1) failed."); + + var roundD = s.CreateQuery("select truncate(m.Cost) from MaterialResource m").UniqueResult(); + Assert.That(roundD, Is.EqualTo(51), "Selecting truncate(decimal) failed."); + var count = + s + .CreateQuery("select count(*) from MaterialResource m where truncate(m.Cost) = :c") + .SetInt32("c", 51) + .UniqueResult(); + Assert.That(count, Is.EqualTo(1), "Filtering truncate(decimal) failed."); + + roundD = s.CreateQuery("select truncate(m.Cost, 1) from MaterialResource m").UniqueResult(); + Assert.That(roundD, Is.EqualTo(51.7m), "Selecting truncate(decimal, 1) failed."); + + if (TestDialect.HasBrokenDecimalType) + // SQLite fails the equality test due to using double instead, wich requires a tolerance. + return; + + count = + s + .CreateQuery("select count(*) from MaterialResource m where truncate(m.Cost, 1) = :c") + .SetDecimal("c", 51.7m) + .UniqueResult(); + Assert.That(count, Is.EqualTo(1), "Filtering truncate(decimal, 1) failed."); + } + } + [Test] public void Mod() { diff --git a/src/NHibernate/Dialect/Function/RoundFunction.cs b/src/NHibernate/Dialect/Function/RoundFunction.cs deleted file mode 100644 index 58e7c203208..00000000000 --- a/src/NHibernate/Dialect/Function/RoundFunction.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System.Collections; -using NHibernate.Engine; -using NHibernate.SqlCommand; -using NHibernate.Type; -using System; - -namespace NHibernate.Dialect.Function -{ - /// - /// Provides a round implementation that supports single parameter round by translating to two parameters round. - /// - [Serializable] - public class RoundEmulatingSingleParameterFunction : ISQLFunction - { - private static readonly ISQLFunction SingleParamRound = new SQLFunctionTemplate(null, "round(?1, 0)"); - - private static readonly ISQLFunction Round = new StandardSQLFunction("round"); - - public IType ReturnType(IType columnType, IMapping mapping) => columnType; - - public bool HasArguments => true; - - public bool HasParenthesesIfNoArguments => true; - - public SqlString Render(IList args, ISessionFactoryImplementor factory) - { - return args.Count == 1 ? SingleParamRound.Render(args, factory) : Round.Render(args, factory); - } - - public override string ToString() => "round"; - } -} diff --git a/src/NHibernate/Dialect/Function/StandardSQLFunctionWithRequiredParameters.cs b/src/NHibernate/Dialect/Function/StandardSQLFunctionWithRequiredParameters.cs new file mode 100644 index 00000000000..8f82fb4a8e5 --- /dev/null +++ b/src/NHibernate/Dialect/Function/StandardSQLFunctionWithRequiredParameters.cs @@ -0,0 +1,42 @@ +using System; +using System.Collections; +using System.Linq; +using NHibernate.Engine; +using NHibernate.SqlCommand; +using NHibernate.Type; + +namespace NHibernate.Dialect.Function +{ + /// + /// A SQL function which substitutes required missing parameters with defaults. + /// + [Serializable] + internal class StandardSQLFunctionWithRequiredParameters : StandardSQLFunction + { + private readonly object[] _requiredArgs; + + /// + public StandardSQLFunctionWithRequiredParameters(string name, object[] requiredArgs) + : base(name) + { + _requiredArgs = requiredArgs; + } + + /// + public StandardSQLFunctionWithRequiredParameters(string name, IType typeValue, object[] requiredArgs) + : base(name, typeValue) + { + _requiredArgs = requiredArgs; + } + + /// + public override SqlString Render(IList args, ISessionFactoryImplementor factory) + { + var combinedArgs = + args.Cast() + .Concat(_requiredArgs.Skip(args.Count)) + .ToArray(); + return base.Render(combinedArgs, factory); + } + } +} diff --git a/src/NHibernate/Dialect/MsSql2000Dialect.cs b/src/NHibernate/Dialect/MsSql2000Dialect.cs index 3553d5923a6..25d973c66b6 100644 --- a/src/NHibernate/Dialect/MsSql2000Dialect.cs +++ b/src/NHibernate/Dialect/MsSql2000Dialect.cs @@ -286,8 +286,8 @@ protected virtual void RegisterFunctions() RegisterFunction("ceiling", new StandardSQLFunction("ceiling")); RegisterFunction("ceil", new StandardSQLFunction("ceiling")); RegisterFunction("floor", new StandardSQLFunction("floor")); - RegisterFunction("round", new RoundEmulatingSingleParameterFunction()); - RegisterFunction("truncate", new SQLFunctionTemplate(null, "round(?1, ?2, 1)")); + RegisterFunction("round", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0"})); + RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0", "1"})); RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double)); diff --git a/src/NHibernate/Dialect/MsSqlCeDialect.cs b/src/NHibernate/Dialect/MsSqlCeDialect.cs index 3aebc47e5e3..6feb7b0c889 100644 --- a/src/NHibernate/Dialect/MsSqlCeDialect.cs +++ b/src/NHibernate/Dialect/MsSqlCeDialect.cs @@ -194,8 +194,8 @@ protected virtual void RegisterFunctions() RegisterFunction("concat", new VarArgsSQLFunction(NHibernateUtil.String, "(", "+", ")")); RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))")); - RegisterFunction("round", new RoundEmulatingSingleParameterFunction()); - RegisterFunction("truncate", new SQLFunctionTemplate(null, "round(?1, ?2, 1)")); + RegisterFunction("round", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0"})); + RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0", "1"})); RegisterFunction("bit_length", new SQLFunctionTemplate(NHibernateUtil.Int32, "datalength(?1) * 8")); RegisterFunction("extract", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(?1, ?3)")); diff --git a/src/NHibernate/Dialect/MySQLDialect.cs b/src/NHibernate/Dialect/MySQLDialect.cs index 3292dabaffd..bd76c80adee 100644 --- a/src/NHibernate/Dialect/MySQLDialect.cs +++ b/src/NHibernate/Dialect/MySQLDialect.cs @@ -261,7 +261,7 @@ protected virtual void RegisterFunctions() RegisterFunction("ceiling", new StandardSQLFunction("ceiling")); RegisterFunction("floor", new StandardSQLFunction("floor")); RegisterFunction("round", new StandardSQLFunction("round")); - RegisterFunction("truncate", new StandardSafeSQLFunction("truncate", 2)); + RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("truncate", new object[] {null, "0"})); RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double)); diff --git a/src/NHibernate/Dialect/PostgreSQLDialect.cs b/src/NHibernate/Dialect/PostgreSQLDialect.cs index 33510e20378..70169cc14fb 100644 --- a/src/NHibernate/Dialect/PostgreSQLDialect.cs +++ b/src/NHibernate/Dialect/PostgreSQLDialect.cs @@ -69,7 +69,9 @@ public PostgreSQLDialect() RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))")); RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32)); - RegisterFunction("round", new RoundFunction()); + RegisterFunction("round", new RoundFunction(false)); + RegisterFunction("truncate", new RoundFunction(true)); + RegisterFunction("trunc", new RoundFunction(true)); // Trigonometric functions. RegisterFunction("acos", new StandardSQLFunction("acos", NHibernateUtil.Double)); @@ -322,12 +324,34 @@ public override string CurrentTimestampSelectString private class RoundFunction : ISQLFunction { private static readonly ISQLFunction Round = new StandardSQLFunction("round"); + private static readonly ISQLFunction Truncate = new StandardSQLFunction("trunc"); - // PostgreSQL round with two arguments only accepts decimal as input, thus the cast. + // PostgreSQL round/trunc with two arguments only accepts decimal as input, thus the cast. // It also yields only decimal, but for emulating similar behavior to other databases, we need // to have it converted to the original input type, which will be done by NHibernate thanks to // not specifying the function type. private static readonly ISQLFunction RoundWith2Params = new SQLFunctionTemplate(null, "round(cast(?1 as numeric), ?2)"); + private static readonly ISQLFunction TruncateWith2Params = new SQLFunctionTemplate(null, "trunc(cast(?1 as numeric), ?2)"); + + private readonly ISQLFunction _singleParamFunction; + private readonly ISQLFunction _twoParamFunction; + private readonly string _name; + + public RoundFunction(bool truncate) + { + if (truncate) + { + _singleParamFunction = Truncate; + _twoParamFunction = TruncateWith2Params; + _name = "truncate"; + } + else + { + _singleParamFunction = Round; + _twoParamFunction = RoundWith2Params; + _name = "round"; + } + } public IType ReturnType(IType columnType, IMapping mapping) => columnType; @@ -337,10 +361,10 @@ private class RoundFunction : ISQLFunction public SqlString Render(IList args, ISessionFactoryImplementor factory) { - return args.Count == 2 ? RoundWith2Params.Render(args, factory) : Round.Render(args, factory); + return args.Count == 2 ? _twoParamFunction.Render(args, factory) : _singleParamFunction.Render(args, factory); } - public override string ToString() => "round"; + public override string ToString() => _name; } } } diff --git a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs index 09f8c6491b6..4fe1c0c1faf 100644 --- a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs +++ b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs @@ -140,12 +140,13 @@ protected virtual void RegisterMathFunctions() RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double)); RegisterFunction("rand", new StandardSQLFunction("rand", NHibernateUtil.Double)); RegisterFunction("remainder", new StandardSQLFunction("remainder")); - RegisterFunction("round", new StandardSQLFunction("round")); + RegisterFunction("round", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0"})); RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32)); RegisterFunction("sin", new StandardSQLFunction("sin", NHibernateUtil.Double)); RegisterFunction("sqrt", new StandardSQLFunction("sqrt", NHibernateUtil.Double)); RegisterFunction("tan", new StandardSQLFunction("tan", NHibernateUtil.Double)); - RegisterFunction("truncate", new StandardSQLFunction("truncate")); + RegisterFunction("truncnum", new StandardSQLFunctionWithRequiredParameters("truncnum", new object[] {null, "0"})); + RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("truncnum", new object[] {null, "0"})); } protected virtual void RegisterXmlFunctions() @@ -343,8 +344,6 @@ protected virtual void RegisterMiscellaneousFunctions() RegisterFunction("transactsql", new StandardSQLFunction("transactsql", NHibernateUtil.String)); RegisterFunction("varexists", new StandardSQLFunction("varexists", NHibernateUtil.Int32)); RegisterFunction("watcomsql", new StandardSQLFunction("watcomsql", NHibernateUtil.String)); - RegisterFunction("truncnum", new StandardSafeSQLFunction("truncnum", 2)); - RegisterFunction("truncate", new StandardSafeSQLFunction("truncnum", 2)); } #region private static readonly string[] DialectKeywords = { ... }