Skip to content

Commit c0f5042

Browse files
maca88hazzik
andauthored
Reduce cast usage for COUNT aggregate and add support for Mssql count_big (#2061)
Co-authored-by: Alexander Zaytsev <hazzik@gmail.com>
1 parent 5cf0cf4 commit c0f5042

File tree

15 files changed

+270
-28
lines changed

15 files changed

+270
-28
lines changed

src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using System;
1212
using System.Linq;
1313
using NHibernate.Cfg;
14+
using NHibernate.Dialect;
1415
using NUnit.Framework;
1516
using NHibernate.Linq;
1617

@@ -110,5 +111,50 @@ into temp
110111

111112
Assert.That(result.Count, Is.EqualTo(77));
112113
}
114+
115+
[Test]
116+
public async Task CheckSqlFunctionNameLongCountAsync()
117+
{
118+
var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
119+
using (var sqlLog = new SqlLogSpy())
120+
{
121+
var result = await (db.Orders.LongCountAsync());
122+
Assert.That(result, Is.EqualTo(830));
123+
124+
var log = sqlLog.GetWholeLog();
125+
Assert.That(log, Does.Contain($"{name}("));
126+
}
127+
}
128+
129+
[Test]
130+
public async Task CheckSqlFunctionNameForCountAsync()
131+
{
132+
using (var sqlLog = new SqlLogSpy())
133+
{
134+
var result = await (db.Orders.CountAsync());
135+
Assert.That(result, Is.EqualTo(830));
136+
137+
var log = sqlLog.GetWholeLog();
138+
Assert.That(log, Does.Contain("count("));
139+
}
140+
}
141+
142+
[Test]
143+
public async Task CheckMssqlCountCastAsync()
144+
{
145+
if (!(Dialect is MsSql2000Dialect))
146+
{
147+
Assert.Ignore();
148+
}
149+
150+
using (var sqlLog = new SqlLogSpy())
151+
{
152+
var result = await (db.Orders.CountAsync());
153+
Assert.That(result, Is.EqualTo(830));
154+
155+
var log = sqlLog.GetWholeLog();
156+
Assert.That(log, Does.Not.Contain("cast("));
157+
}
158+
}
113159
}
114160
}

src/NHibernate.Test/Linq/ByMethod/CountTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Linq;
33
using NHibernate.Cfg;
4+
using NHibernate.Dialect;
45
using NUnit.Framework;
56

67
namespace NHibernate.Test.Linq.ByMethod
@@ -98,5 +99,50 @@ into temp
9899

99100
Assert.That(result.Count, Is.EqualTo(77));
100101
}
102+
103+
[Test]
104+
public void CheckSqlFunctionNameLongCount()
105+
{
106+
var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
107+
using (var sqlLog = new SqlLogSpy())
108+
{
109+
var result = db.Orders.LongCount();
110+
Assert.That(result, Is.EqualTo(830));
111+
112+
var log = sqlLog.GetWholeLog();
113+
Assert.That(log, Does.Contain($"{name}("));
114+
}
115+
}
116+
117+
[Test]
118+
public void CheckSqlFunctionNameForCount()
119+
{
120+
using (var sqlLog = new SqlLogSpy())
121+
{
122+
var result = db.Orders.Count();
123+
Assert.That(result, Is.EqualTo(830));
124+
125+
var log = sqlLog.GetWholeLog();
126+
Assert.That(log, Does.Contain("count("));
127+
}
128+
}
129+
130+
[Test]
131+
public void CheckMssqlCountCast()
132+
{
133+
if (!(Dialect is MsSql2000Dialect))
134+
{
135+
Assert.Ignore();
136+
}
137+
138+
using (var sqlLog = new SqlLogSpy())
139+
{
140+
var result = db.Orders.Count();
141+
Assert.That(result, Is.EqualTo(830));
142+
143+
var log = sqlLog.GetWholeLog();
144+
Assert.That(log, Does.Not.Contain("cast("));
145+
}
146+
}
101147
}
102148
}

src/NHibernate/Dialect/Dialect.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ public abstract partial class Dialect
5555
static Dialect()
5656
{
5757
StandardAggregateFunctions["count"] = new CountQueryFunctionInfo();
58+
StandardAggregateFunctions["count_big"] = new CountQueryFunctionInfo();
5859
StandardAggregateFunctions["avg"] = new AvgQueryFunctionInfo();
5960
StandardAggregateFunctions["max"] = new ClassicAggregateFunction("max", false);
6061
StandardAggregateFunctions["min"] = new ClassicAggregateFunction("min", false);

src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
using System;
22
using System.Collections;
3-
using System.Text;
3+
using System.Collections.Generic;
4+
using System.Linq;
45
using NHibernate.Engine;
56
using NHibernate.SqlCommand;
67
using NHibernate.Type;
7-
using NHibernate.Util;
88

99
namespace NHibernate.Dialect.Function
1010
{
1111
[Serializable]
12-
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar
12+
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended
1313
{
1414
private IType returnType = null;
1515
private readonly string name;
@@ -45,6 +45,15 @@ public virtual IType ReturnType(IType columnType, IMapping mapping)
4545
return returnType ?? columnType;
4646
}
4747

48+
/// <inheritdoc />
49+
public virtual IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError)
50+
{
51+
return ReturnType(argumentTypes.FirstOrDefault(), mapping);
52+
}
53+
54+
/// <inheritdoc />
55+
public string FunctionName => name;
56+
4857
public bool HasArguments
4958
{
5059
get { return true; }

src/NHibernate/Dialect/Function/ISQLFunction.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System.Collections;
2+
using System.Collections.Generic;
3+
using System.Linq;
24
using NHibernate.Engine;
35
using NHibernate.SqlCommand;
46
using NHibernate.Type;
@@ -41,4 +43,45 @@ public interface ISQLFunction
4143
/// <returns>SQL fragment for the function.</returns>
4244
SqlString Render(IList args, ISessionFactoryImplementor factory);
4345
}
46+
47+
// 6.0 TODO: Remove
48+
internal static class SQLFunctionExtensions
49+
{
50+
/// <summary>
51+
/// Get the type that will be effectively returned by the underlying database.
52+
/// </summary>
53+
/// <param name="sqlFunction">The sql function.</param>
54+
/// <param name="argumentTypes">The types of arguments.</param>
55+
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
56+
/// <param name="throwOnError">Whether to throw when the number of arguments is invalid or they are not supported.</param>
57+
/// <returns>The type returned by the underlying database or <see langword="null"/> when the number of arguments
58+
/// is invalid or they are not supported.</returns>
59+
/// <exception cref="QueryException">When <paramref name="throwOnError"/> is set to <see langword="true"/> and the
60+
/// number of arguments is invalid or they are not supported.</exception>
61+
public static IType GetEffectiveReturnType(
62+
this ISQLFunction sqlFunction,
63+
IEnumerable<IType> argumentTypes,
64+
IMapping mapping,
65+
bool throwOnError)
66+
{
67+
if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction))
68+
{
69+
try
70+
{
71+
return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping);
72+
}
73+
catch (QueryException)
74+
{
75+
if (throwOnError)
76+
{
77+
throw;
78+
}
79+
80+
return null;
81+
}
82+
}
83+
84+
return extendedSqlFunction.GetEffectiveReturnType(argumentTypes, mapping, throwOnError);
85+
}
86+
}
4487
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System.Collections.Generic;
2+
using NHibernate.Engine;
3+
using NHibernate.Type;
4+
5+
namespace NHibernate.Dialect.Function
6+
{
7+
// 6.0 TODO: Merge into ISQLFunction
8+
internal interface ISQLFunctionExtended : ISQLFunction
9+
{
10+
/// <summary>
11+
/// The function name or <see langword="null"/> when multiple functions/operators/statements are used.
12+
/// </summary>
13+
string FunctionName { get; }
14+
15+
/// <summary>
16+
/// Get the type that will be effectively returned by the underlying database.
17+
/// </summary>
18+
/// <param name="argumentTypes">The types of arguments.</param>
19+
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
20+
/// <param name="throwOnError">Whether to throw when the number of arguments is invalid or they are not supported.</param>
21+
/// <returns>The type returned by the underlying database or <see langword="null"/> when the number of arguments
22+
/// is invalid or they are not supported.</returns>
23+
/// <exception cref="QueryException">When <paramref name="throwOnError"/> is set to <see langword="true"/> and the
24+
/// number of arguments is invalid or they are not supported.</exception>
25+
IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError);
26+
}
27+
}

src/NHibernate/Dialect/MsSql2000Dialect.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ protected virtual void RegisterKeywords()
286286

287287
protected virtual void RegisterFunctions()
288288
{
289-
RegisterFunction("count", new CountBigQueryFunction());
289+
RegisterFunction("count", new CountQueryFunction());
290+
RegisterFunction("count_big", new CountBigQueryFunction());
290291

291292
RegisterFunction("abs", new StandardSQLFunction("abs"));
292293
RegisterFunction("absval", new StandardSQLFunction("absval"));
@@ -704,11 +705,15 @@ protected virtual string GetSelectExistingObject(string catalog, string schema,
704705
[Serializable]
705706
protected class CountBigQueryFunction : ClassicAggregateFunction
706707
{
707-
public CountBigQueryFunction() : base("count_big", true) { }
708+
public CountBigQueryFunction() : base("count_big", true, NHibernateUtil.Int64) { }
709+
}
708710

709-
public override IType ReturnType(IType columnType, IMapping mapping)
711+
[Serializable]
712+
private class CountQueryFunction : CountQueryFunctionInfo
713+
{
714+
public override IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError)
710715
{
711-
return NHibernateUtil.Int64;
716+
return NHibernateUtil.Int32;
712717
}
713718
}
714719

src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,12 @@ private void EndFunctionTemplate(IASTNode m)
310310
}
311311
}
312312

313+
private void OutAggregateFunctionName(IASTNode m)
314+
{
315+
var aggregateNode = (AggregateNode) m;
316+
Out(aggregateNode.FunctionName);
317+
}
318+
313319
private void CommaBetweenParameters(String comma)
314320
{
315321
writer.CommaBetweenParameters(comma);

src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ selectExpr
150150
;
151151
152152
count
153-
: ^(COUNT { Out("count("); } ( distinctOrAll ) ? countExpr { Out(")"); } )
153+
: ^(c=COUNT { OutAggregateFunctionName(c); Out("("); } ( distinctOrAll ) ? countExpr { Out(")"); } )
154154
;
155155
156156
distinctOrAll
@@ -344,7 +344,7 @@ caseExpr
344344
;
345345
346346
aggregate
347-
: ^(a=AGGREGATE { Out(a); Out("("); } expr { Out(")"); } )
347+
: ^(a=AGGREGATE { OutAggregateFunctionName(a); Out("("); } expr { Out(")"); } )
348348
;
349349
350350

src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using Antlr.Runtime;
3+
using NHibernate.Dialect.Function;
34
using NHibernate.Type;
45
using NHibernate.Hql.Ast.ANTLR.Util;
56

@@ -19,6 +20,19 @@ public AggregateNode(IToken token)
1920
{
2021
}
2122

23+
public string FunctionName
24+
{
25+
get
26+
{
27+
if (SessionFactoryHelper.FindSQLFunction(Text) is ISQLFunctionExtended sqlFunction)
28+
{
29+
return sqlFunction.FunctionName;
30+
}
31+
32+
return Text;
33+
}
34+
}
35+
2236
public override IType DataType
2337
{
2438
get
@@ -31,6 +45,7 @@ public override IType DataType
3145
base.DataType = value;
3246
}
3347
}
48+
3449
public override void SetScalarColumnText(int i)
3550
{
3651
ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i);

src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Antlr.Runtime;
2+
using NHibernate.Dialect.Function;
23
using NHibernate.Hql.Ast.ANTLR.Util;
34
using NHibernate.Type;
45

@@ -9,7 +10,7 @@ namespace NHibernate.Hql.Ast.ANTLR.Tree
910
/// Author: josh
1011
/// Ported by: Steve Strong
1112
/// </summary>
12-
class CountNode : AbstractSelectExpression, ISelectExpression
13+
class CountNode : AggregateNode, ISelectExpression
1314
{
1415
public CountNode(IToken token) : base(token)
1516
{
@@ -26,9 +27,5 @@ public override IType DataType
2627
base.DataType = value;
2728
}
2829
}
29-
public override void SetScalarColumnText(int i)
30-
{
31-
ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i);
32-
}
3330
}
3431
}

src/NHibernate/Hql/Ast/HqlTreeBuilder.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ public HqlCount Count(HqlExpression child)
307307
return new HqlCount(_factory, child);
308308
}
309309

310+
public HqlCountBig CountBig(HqlExpression child)
311+
{
312+
return new HqlCountBig(_factory, child);
313+
}
314+
310315
public HqlRowStar RowStar()
311316
{
312317
return new HqlRowStar(_factory);

src/NHibernate/Hql/Ast/HqlTreeNode.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,19 @@ public HqlCount(IASTFactory factory, HqlExpression child)
697697
}
698698
}
699699

700+
public class HqlCountBig : HqlExpression
701+
{
702+
public HqlCountBig(IASTFactory factory)
703+
: base(HqlSqlWalker.COUNT, "count_big", factory)
704+
{
705+
}
706+
707+
public HqlCountBig(IASTFactory factory, HqlExpression child)
708+
: base(HqlSqlWalker.COUNT, "count_big", factory, child)
709+
{
710+
}
711+
}
712+
700713
public class HqlAs : HqlExpression
701714
{
702715
public HqlAs(IASTFactory factory, HqlExpression expression, System.Type type)

0 commit comments

Comments
 (0)