Skip to content

Commit d2e8cc3

Browse files
Add support for single-argument truncate to dialects that do not support it natively (#1597)
- Add truncate to PostgreSQLDialect - Add support for single-argument round to SybaseSQLAnywhere10Dialect Co-authored-by: Alexander Zaytsev <hazzik@users.noreply.github.com>
1 parent 8993a44 commit d2e8cc3

File tree

8 files changed

+137
-45
lines changed

8 files changed

+137
-45
lines changed

src/NHibernate.Test/Hql/HQLFunctions.cs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,65 @@ public void Round()
568568
}
569569
}
570570

571+
[Test]
572+
public void Truncate()
573+
{
574+
AssumeFunctionSupported("truncate");
575+
576+
using (var s = OpenSession())
577+
{
578+
var a1 = new Animal("a1", 1.87f);
579+
s.Save(a1);
580+
var m1 = new MaterialResource("m1", "18", MaterialResource.MaterialState.Available) { Cost = 51.76m };
581+
s.Save(m1);
582+
s.Flush();
583+
}
584+
using (var s = OpenSession())
585+
{
586+
var roundF = s.CreateQuery("select truncate(a.BodyWeight) from Animal a").UniqueResult<float>();
587+
Assert.That(roundF, Is.EqualTo(1), "Selecting truncate(double) failed.");
588+
var countF =
589+
s
590+
.CreateQuery("select count(*) from Animal a where truncate(a.BodyWeight) = :c")
591+
.SetInt32("c", 1)
592+
.UniqueResult<long>();
593+
Assert.That(countF, Is.EqualTo(1), "Filtering truncate(double) failed.");
594+
595+
roundF = s.CreateQuery("select truncate(a.BodyWeight, 1) from Animal a").UniqueResult<float>();
596+
Assert.That(roundF, Is.EqualTo(1.8f).Within(0.01f), "Selecting truncate(double, 1) failed.");
597+
countF =
598+
s
599+
.CreateQuery("select count(*) from Animal a where truncate(a.BodyWeight, 1) between :c1 and :c2")
600+
.SetDouble("c1", 1.79)
601+
.SetDouble("c2", 1.81)
602+
.UniqueResult<long>();
603+
Assert.That(countF, Is.EqualTo(1), "Filtering truncate(double, 1) failed.");
604+
605+
var roundD = s.CreateQuery("select truncate(m.Cost) from MaterialResource m").UniqueResult<decimal?>();
606+
Assert.That(roundD, Is.EqualTo(51), "Selecting truncate(decimal) failed.");
607+
var count =
608+
s
609+
.CreateQuery("select count(*) from MaterialResource m where truncate(m.Cost) = :c")
610+
.SetInt32("c", 51)
611+
.UniqueResult<long>();
612+
Assert.That(count, Is.EqualTo(1), "Filtering truncate(decimal) failed.");
613+
614+
roundD = s.CreateQuery("select truncate(m.Cost, 1) from MaterialResource m").UniqueResult<decimal?>();
615+
Assert.That(roundD, Is.EqualTo(51.7m), "Selecting truncate(decimal, 1) failed.");
616+
617+
if (TestDialect.HasBrokenDecimalType)
618+
// SQLite fails the equality test due to using double instead, wich requires a tolerance.
619+
return;
620+
621+
count =
622+
s
623+
.CreateQuery("select count(*) from MaterialResource m where truncate(m.Cost, 1) = :c")
624+
.SetDecimal("c", 51.7m)
625+
.UniqueResult<long>();
626+
Assert.That(count, Is.EqualTo(1), "Filtering truncate(decimal, 1) failed.");
627+
}
628+
}
629+
571630
[Test]
572631
public void Mod()
573632
{

src/NHibernate/Dialect/Function/RoundFunction.cs

Lines changed: 0 additions & 32 deletions
This file was deleted.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using System;
2+
using System.Collections;
3+
using System.Linq;
4+
using NHibernate.Engine;
5+
using NHibernate.SqlCommand;
6+
using NHibernate.Type;
7+
8+
namespace NHibernate.Dialect.Function
9+
{
10+
/// <summary>
11+
/// A SQL function which substitutes required missing parameters with defaults.
12+
/// </summary>
13+
[Serializable]
14+
internal class StandardSQLFunctionWithRequiredParameters : StandardSQLFunction
15+
{
16+
private readonly object[] _requiredArgs;
17+
18+
/// <inheritdoc />
19+
public StandardSQLFunctionWithRequiredParameters(string name, object[] requiredArgs)
20+
: base(name)
21+
{
22+
_requiredArgs = requiredArgs;
23+
}
24+
25+
/// <inheritdoc />
26+
public StandardSQLFunctionWithRequiredParameters(string name, IType typeValue, object[] requiredArgs)
27+
: base(name, typeValue)
28+
{
29+
_requiredArgs = requiredArgs;
30+
}
31+
32+
/// <inheritdoc />
33+
public override SqlString Render(IList args, ISessionFactoryImplementor factory)
34+
{
35+
var combinedArgs =
36+
args.Cast<object>()
37+
.Concat(_requiredArgs.Skip(args.Count))
38+
.ToArray();
39+
return base.Render(combinedArgs, factory);
40+
}
41+
}
42+
}

src/NHibernate/Dialect/MsSql2000Dialect.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ protected virtual void RegisterFunctions()
286286
RegisterFunction("ceiling", new StandardSQLFunction("ceiling"));
287287
RegisterFunction("ceil", new StandardSQLFunction("ceiling"));
288288
RegisterFunction("floor", new StandardSQLFunction("floor"));
289-
RegisterFunction("round", new RoundEmulatingSingleParameterFunction());
290-
RegisterFunction("truncate", new SQLFunctionTemplate(null, "round(?1, ?2, 1)"));
289+
RegisterFunction("round", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0"}));
290+
RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0", "1"}));
291291

292292
RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double));
293293

src/NHibernate/Dialect/MsSqlCeDialect.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ protected virtual void RegisterFunctions()
194194
RegisterFunction("concat", new VarArgsSQLFunction(NHibernateUtil.String, "(", "+", ")"));
195195
RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))"));
196196

197-
RegisterFunction("round", new RoundEmulatingSingleParameterFunction());
198-
RegisterFunction("truncate", new SQLFunctionTemplate(null, "round(?1, ?2, 1)"));
197+
RegisterFunction("round", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0"}));
198+
RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0", "1"}));
199199

200200
RegisterFunction("bit_length", new SQLFunctionTemplate(NHibernateUtil.Int32, "datalength(?1) * 8"));
201201
RegisterFunction("extract", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(?1, ?3)"));

src/NHibernate/Dialect/MySQLDialect.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ protected virtual void RegisterFunctions()
261261
RegisterFunction("ceiling", new StandardSQLFunction("ceiling"));
262262
RegisterFunction("floor", new StandardSQLFunction("floor"));
263263
RegisterFunction("round", new StandardSQLFunction("round"));
264-
RegisterFunction("truncate", new StandardSafeSQLFunction("truncate", 2));
264+
RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("truncate", new object[] {null, "0"}));
265265

266266
RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double));
267267

src/NHibernate/Dialect/PostgreSQLDialect.cs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ public PostgreSQLDialect()
6969
RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))"));
7070

7171
RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32));
72-
RegisterFunction("round", new RoundFunction());
72+
RegisterFunction("round", new RoundFunction(false));
73+
RegisterFunction("truncate", new RoundFunction(true));
74+
RegisterFunction("trunc", new RoundFunction(true));
7375

7476
// Trigonometric functions.
7577
RegisterFunction("acos", new StandardSQLFunction("acos", NHibernateUtil.Double));
@@ -322,12 +324,34 @@ public override string CurrentTimestampSelectString
322324
private class RoundFunction : ISQLFunction
323325
{
324326
private static readonly ISQLFunction Round = new StandardSQLFunction("round");
327+
private static readonly ISQLFunction Truncate = new StandardSQLFunction("trunc");
325328

326-
// PostgreSQL round with two arguments only accepts decimal as input, thus the cast.
329+
// PostgreSQL round/trunc with two arguments only accepts decimal as input, thus the cast.
327330
// It also yields only decimal, but for emulating similar behavior to other databases, we need
328331
// to have it converted to the original input type, which will be done by NHibernate thanks to
329332
// not specifying the function type.
330333
private static readonly ISQLFunction RoundWith2Params = new SQLFunctionTemplate(null, "round(cast(?1 as numeric), ?2)");
334+
private static readonly ISQLFunction TruncateWith2Params = new SQLFunctionTemplate(null, "trunc(cast(?1 as numeric), ?2)");
335+
336+
private readonly ISQLFunction _singleParamFunction;
337+
private readonly ISQLFunction _twoParamFunction;
338+
private readonly string _name;
339+
340+
public RoundFunction(bool truncate)
341+
{
342+
if (truncate)
343+
{
344+
_singleParamFunction = Truncate;
345+
_twoParamFunction = TruncateWith2Params;
346+
_name = "truncate";
347+
}
348+
else
349+
{
350+
_singleParamFunction = Round;
351+
_twoParamFunction = RoundWith2Params;
352+
_name = "round";
353+
}
354+
}
331355

332356
public IType ReturnType(IType columnType, IMapping mapping) => columnType;
333357

@@ -337,10 +361,10 @@ private class RoundFunction : ISQLFunction
337361

338362
public SqlString Render(IList args, ISessionFactoryImplementor factory)
339363
{
340-
return args.Count == 2 ? RoundWith2Params.Render(args, factory) : Round.Render(args, factory);
364+
return args.Count == 2 ? _twoParamFunction.Render(args, factory) : _singleParamFunction.Render(args, factory);
341365
}
342366

343-
public override string ToString() => "round";
367+
public override string ToString() => _name;
344368
}
345369
}
346370
}

src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,13 @@ protected virtual void RegisterMathFunctions()
140140
RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double));
141141
RegisterFunction("rand", new StandardSQLFunction("rand", NHibernateUtil.Double));
142142
RegisterFunction("remainder", new StandardSQLFunction("remainder"));
143-
RegisterFunction("round", new StandardSQLFunction("round"));
143+
RegisterFunction("round", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0"}));
144144
RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32));
145145
RegisterFunction("sin", new StandardSQLFunction("sin", NHibernateUtil.Double));
146146
RegisterFunction("sqrt", new StandardSQLFunction("sqrt", NHibernateUtil.Double));
147147
RegisterFunction("tan", new StandardSQLFunction("tan", NHibernateUtil.Double));
148-
RegisterFunction("truncate", new StandardSQLFunction("truncate"));
148+
RegisterFunction("truncnum", new StandardSQLFunctionWithRequiredParameters("truncnum", new object[] {null, "0"}));
149+
RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("truncnum", new object[] {null, "0"}));
149150
}
150151

151152
protected virtual void RegisterXmlFunctions()
@@ -343,8 +344,6 @@ protected virtual void RegisterMiscellaneousFunctions()
343344
RegisterFunction("transactsql", new StandardSQLFunction("transactsql", NHibernateUtil.String));
344345
RegisterFunction("varexists", new StandardSQLFunction("varexists", NHibernateUtil.Int32));
345346
RegisterFunction("watcomsql", new StandardSQLFunction("watcomsql", NHibernateUtil.String));
346-
RegisterFunction("truncnum", new StandardSafeSQLFunction("truncnum", 2));
347-
RegisterFunction("truncate", new StandardSafeSQLFunction("truncnum", 2));
348347
}
349348

350349
#region private static readonly string[] DialectKeywords = { ... }

0 commit comments

Comments
 (0)