Skip to content

Commit 8530ca4

Browse files
maca88fredericDelaporte
authored andcommitted
Fix GenericBatchingBatcher to call IInterceptor.OnPrepareStatement (#2285)
1 parent 92133f0 commit 8530ca4

File tree

6 files changed

+157
-8
lines changed

6 files changed

+157
-8
lines changed

src/NHibernate.Test/Ado/GenericBatchingBatcherFixture.cs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using NHibernate.AdoNet;
66
using NHibernate.Cfg;
77
using NHibernate.Dialect;
8+
using NHibernate.Linq;
9+
using NHibernate.SqlCommand;
810
using NUnit.Framework;
911
using Environment = NHibernate.Cfg.Environment;
1012

@@ -102,6 +104,62 @@ public void MassivePerformanceTest(bool batched)
102104
}
103105
}
104106

107+
[Test]
108+
public void InterceptorOnPrepareStatementTest()
109+
{
110+
var interceptor = new DatabaseInterceptor();
111+
using (var sqlLog = new SqlLogSpy())
112+
using (var s = Sfi.WithOptions().Interceptor(interceptor).OpenSession())
113+
using (var tx = s.BeginTransaction())
114+
{
115+
s.SetBatchSize(5);
116+
for (var i = 0; i < 20; i++)
117+
{
118+
s.Save(new VerySimple { Id = 1 + i, Name = $"Fabio{i}", Weight = 1.45 + i });
119+
}
120+
121+
tx.Commit();
122+
123+
Assert.That(interceptor.TotalCalls, Is.EqualTo(1));
124+
var log = sqlLog.GetWholeLog();
125+
Assert.That(FindAllOccurrences(log, "/* TEST */"), Is.EqualTo(20));
126+
}
127+
128+
interceptor = new DatabaseInterceptor();
129+
using (var sqlLog = new SqlLogSpy())
130+
using (var s = Sfi.WithOptions().Interceptor(interceptor).OpenSession())
131+
using (var tx = s.BeginTransaction())
132+
{
133+
var future = s.Query<VerySimple>().ToFuture();
134+
s.Query<VerySimple>().Where(o => o.Weight > 0).ToFuture();
135+
136+
using (var enumerator = future.GetEnumerable().GetEnumerator())
137+
{
138+
while (enumerator.MoveNext()) { }
139+
}
140+
141+
tx.Commit();
142+
143+
var totalCalls = Sfi.ConnectionProvider.Driver.SupportsMultipleQueries ? 1 : 2;
144+
Assert.That(interceptor.TotalCalls, Is.EqualTo(totalCalls));
145+
var log = sqlLog.GetWholeLog();
146+
Assert.That(FindAllOccurrences(log, "/* TEST */"), Is.EqualTo(totalCalls));
147+
}
148+
149+
Cleanup();
150+
}
151+
152+
private class DatabaseInterceptor : EmptyInterceptor
153+
{
154+
public int TotalCalls { get; private set; }
155+
156+
public override SqlString OnPrepareStatement(SqlString sql)
157+
{
158+
TotalCalls++;
159+
return sql.Append("/* TEST */");
160+
}
161+
}
162+
105163
private void BatchInsert(int totalRecords)
106164
{
107165
Sfi.Statistics.Clear();

src/NHibernate.Test/Async/Ado/GenericBatchingBatcherFixture.cs

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
using NHibernate.AdoNet;
1616
using NHibernate.Cfg;
1717
using NHibernate.Dialect;
18+
using NHibernate.Linq;
19+
using NHibernate.SqlCommand;
1820
using NUnit.Framework;
1921
using Environment = NHibernate.Cfg.Environment;
20-
using NHibernate.Linq;
2122

2223
namespace NHibernate.Test.Ado
2324
{
@@ -115,6 +116,62 @@ public async Task MassivePerformanceTestAsync(bool batched)
115116
}
116117
}
117118

119+
[Test]
120+
public async Task InterceptorOnPrepareStatementTestAsync()
121+
{
122+
var interceptor = new DatabaseInterceptor();
123+
using (var sqlLog = new SqlLogSpy())
124+
using (var s = Sfi.WithOptions().Interceptor(interceptor).OpenSession())
125+
using (var tx = s.BeginTransaction())
126+
{
127+
s.SetBatchSize(5);
128+
for (var i = 0; i < 20; i++)
129+
{
130+
await (s.SaveAsync(new VerySimple { Id = 1 + i, Name = $"Fabio{i}", Weight = 1.45 + i }));
131+
}
132+
133+
await (tx.CommitAsync());
134+
135+
Assert.That(interceptor.TotalCalls, Is.EqualTo(1));
136+
var log = sqlLog.GetWholeLog();
137+
Assert.That(FindAllOccurrences(log, "/* TEST */"), Is.EqualTo(20));
138+
}
139+
140+
interceptor = new DatabaseInterceptor();
141+
using (var sqlLog = new SqlLogSpy())
142+
using (var s = Sfi.WithOptions().Interceptor(interceptor).OpenSession())
143+
using (var tx = s.BeginTransaction())
144+
{
145+
var future = s.Query<VerySimple>().ToFuture();
146+
s.Query<VerySimple>().Where(o => o.Weight > 0).ToFuture();
147+
148+
using (var enumerator = (await (future.GetEnumerableAsync())).GetEnumerator())
149+
{
150+
while (enumerator.MoveNext()) { }
151+
}
152+
153+
await (tx.CommitAsync());
154+
155+
var totalCalls = Sfi.ConnectionProvider.Driver.SupportsMultipleQueries ? 1 : 2;
156+
Assert.That(interceptor.TotalCalls, Is.EqualTo(totalCalls));
157+
var log = sqlLog.GetWholeLog();
158+
Assert.That(FindAllOccurrences(log, "/* TEST */"), Is.EqualTo(totalCalls));
159+
}
160+
161+
await (CleanupAsync());
162+
}
163+
164+
private class DatabaseInterceptor : EmptyInterceptor
165+
{
166+
public int TotalCalls { get; private set; }
167+
168+
public override SqlString OnPrepareStatement(SqlString sql)
169+
{
170+
TotalCalls++;
171+
return sql.Append("/* TEST */");
172+
}
173+
}
174+
118175
private async Task BatchInsertAsync(int totalRecords, CancellationToken cancellationToken = default(CancellationToken))
119176
{
120177
Sfi.Statistics.Clear();

src/NHibernate/AdoNet/AbstractBatcher.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,16 @@ protected DbCommand CurrentCommand
8181

8282
public DbCommand Generate(CommandType type, SqlString sqlString, SqlType[] parameterTypes)
8383
{
84-
SqlString sql = GetSQL(sqlString);
84+
return Generate(type, sqlString, parameterTypes, false);
85+
}
86+
87+
private DbCommand Generate(CommandType type, SqlString sqlString, SqlType[] parameterTypes, bool batch)
88+
{
89+
var sql = GetSQL(sqlString);
90+
if (batch)
91+
{
92+
OnPreparedBatchStatement(sql);
93+
}
8594

8695
var cmd = _factory.ConnectionProvider.Driver.GenerateCommand(type, sql, parameterTypes);
8796
LogOpenPreparedCommand();
@@ -141,7 +150,7 @@ public virtual DbCommand PrepareBatchCommand(CommandType type, SqlString sql, Sq
141150
}
142151
else
143152
{
144-
_batchCommand = PrepareCommand(type, sql, parameterTypes); // calls ExecuteBatch()
153+
_batchCommand = PrepareCommand(type, sql, parameterTypes, true); // calls ExecuteBatch()
145154
_batchCommandSql = sql;
146155
_batchCommandParameterTypes = parameterTypes;
147156
}
@@ -150,14 +159,19 @@ public virtual DbCommand PrepareBatchCommand(CommandType type, SqlString sql, Sq
150159
}
151160

152161
public DbCommand PrepareCommand(CommandType type, SqlString sql, SqlType[] parameterTypes)
162+
{
163+
return PrepareCommand(type, sql, parameterTypes, false);
164+
}
165+
166+
private DbCommand PrepareCommand(CommandType type, SqlString sql, SqlType[] parameterTypes, bool batch)
153167
{
154168
OnPreparedCommand();
155169

156170
// do not actually prepare the Command here - instead just generate it because
157171
// if the command is associated with an ADO.NET Transaction/Connection while
158172
// another open one Command is doing something then an exception will be
159173
// thrown.
160-
return Generate(type, sql, parameterTypes);
174+
return Generate(type, sql, parameterTypes, batch);
161175
}
162176

163177
protected virtual void OnPreparedCommand()
@@ -167,6 +181,8 @@ protected virtual void OnPreparedCommand()
167181
ExecuteBatch();
168182
}
169183

184+
internal virtual void OnPreparedBatchStatement(SqlString sqlString) { }
185+
170186
public DbCommand PrepareQueryCommand(CommandType type, SqlString sql, SqlType[] parameterTypes)
171187
{
172188
// do not actually prepare the Command here - instead just generate it because

src/NHibernate/AdoNet/GenericBatchingBatcher.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ protected override void Dispose(bool isDisposing)
140140
_currentBatch.Clear();
141141
}
142142

143+
internal override void OnPreparedBatchStatement(SqlString sqlString)
144+
{
145+
_currentBatch.CurrentStatement = sqlString;
146+
}
147+
143148
private partial class BatchingCommandSet
144149
{
145150
private readonly string _statementTerminator;
@@ -172,6 +177,8 @@ public BatchingCommandSet(GenericBatchingBatcher batcher, char statementTerminat
172177

173178
public int CountOfParameters { get; private set; }
174179

180+
public SqlString CurrentStatement { get; set; }
181+
175182
public void Append(DbParameterCollection parameters)
176183
{
177184
if (CountOfCommands > 0)
@@ -183,7 +190,7 @@ public void Append(DbParameterCollection parameters)
183190
_commandType = _batcher.CurrentCommand.CommandType;
184191
}
185192

186-
_sql.Add(_batcher.CurrentCommandSql.Copy());
193+
_sql.Add(CurrentStatement);
187194
_sqlTypes.AddRange(_batcher.CurrentCommandParameterTypes);
188195

189196
foreach (DbParameter parameter in parameters)
@@ -207,6 +214,7 @@ public int ExecuteNonQuery()
207214
{
208215
return 0;
209216
}
217+
210218
var batcherCommand = _batcher.Driver.GenerateCommand(
211219
_commandType,
212220
_sql.ToSqlString(),

src/NHibernate/Async/AdoNet/AbstractBatcher.cs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,24 @@ public virtual async Task<DbCommand> PrepareBatchCommandAsync(CommandType type,
7979
}
8080
else
8181
{
82-
_batchCommand = await (PrepareCommandAsync(type, sql, parameterTypes, cancellationToken)).ConfigureAwait(false); // calls ExecuteBatch()
82+
_batchCommand = await (PrepareCommandAsync(type, sql, parameterTypes, true, cancellationToken)).ConfigureAwait(false); // calls ExecuteBatch()
8383
_batchCommandSql = sql;
8484
_batchCommandParameterTypes = parameterTypes;
8585
}
8686

8787
return _batchCommand;
8888
}
8989

90-
public async Task<DbCommand> PrepareCommandAsync(CommandType type, SqlString sql, SqlType[] parameterTypes, CancellationToken cancellationToken)
90+
public Task<DbCommand> PrepareCommandAsync(CommandType type, SqlString sql, SqlType[] parameterTypes, CancellationToken cancellationToken)
91+
{
92+
if (cancellationToken.IsCancellationRequested)
93+
{
94+
return Task.FromCanceled<DbCommand>(cancellationToken);
95+
}
96+
return PrepareCommandAsync(type, sql, parameterTypes, false, cancellationToken);
97+
}
98+
99+
private async Task<DbCommand> PrepareCommandAsync(CommandType type, SqlString sql, SqlType[] parameterTypes, bool batch, CancellationToken cancellationToken)
91100
{
92101
cancellationToken.ThrowIfCancellationRequested();
93102
await (OnPreparedCommandAsync(cancellationToken)).ConfigureAwait(false);
@@ -96,7 +105,7 @@ public async Task<DbCommand> PrepareCommandAsync(CommandType type, SqlString sql
96105
// if the command is associated with an ADO.NET Transaction/Connection while
97106
// another open one Command is doing something then an exception will be
98107
// thrown.
99-
return Generate(type, sql, parameterTypes);
108+
return Generate(type, sql, parameterTypes, batch);
100109
}
101110

102111
protected virtual Task OnPreparedCommandAsync(CancellationToken cancellationToken)

src/NHibernate/Async/AdoNet/GenericBatchingBatcher.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ public async Task<int> ExecuteNonQueryAsync(CancellationToken cancellationToken)
8989
{
9090
return 0;
9191
}
92+
9293
var batcherCommand = _batcher.Driver.GenerateCommand(
9394
_commandType,
9495
_sql.ToSqlString(),

0 commit comments

Comments
 (0)