diff --git a/src/NHibernate.Test/Ado/GenericBatchingBatcherFixture.cs b/src/NHibernate.Test/Ado/GenericBatchingBatcherFixture.cs index b6ce4cd96ba..0188bda143f 100644 --- a/src/NHibernate.Test/Ado/GenericBatchingBatcherFixture.cs +++ b/src/NHibernate.Test/Ado/GenericBatchingBatcherFixture.cs @@ -5,6 +5,8 @@ using NHibernate.AdoNet; using NHibernate.Cfg; using NHibernate.Dialect; +using NHibernate.Linq; +using NHibernate.SqlCommand; using NUnit.Framework; using Environment = NHibernate.Cfg.Environment; @@ -102,6 +104,62 @@ public void MassivePerformanceTest(bool batched) } } + [Test] + public void InterceptorOnPrepareStatementTest() + { + var interceptor = new DatabaseInterceptor(); + using (var sqlLog = new SqlLogSpy()) + using (var s = Sfi.WithOptions().Interceptor(interceptor).OpenSession()) + using (var tx = s.BeginTransaction()) + { + s.SetBatchSize(5); + for (var i = 0; i < 20; i++) + { + s.Save(new VerySimple { Id = 1 + i, Name = $"Fabio{i}", Weight = 1.45 + i }); + } + + tx.Commit(); + + Assert.That(interceptor.TotalCalls, Is.EqualTo(1)); + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "/* TEST */"), Is.EqualTo(20)); + } + + interceptor = new DatabaseInterceptor(); + using (var sqlLog = new SqlLogSpy()) + using (var s = Sfi.WithOptions().Interceptor(interceptor).OpenSession()) + using (var tx = s.BeginTransaction()) + { + var future = s.Query().ToFuture(); + s.Query().Where(o => o.Weight > 0).ToFuture(); + + using (var enumerator = future.GetEnumerable().GetEnumerator()) + { + while (enumerator.MoveNext()) { } + } + + tx.Commit(); + + var totalCalls = Sfi.ConnectionProvider.Driver.SupportsMultipleQueries ? 1 : 2; + Assert.That(interceptor.TotalCalls, Is.EqualTo(totalCalls)); + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "/* TEST */"), Is.EqualTo(totalCalls)); + } + + Cleanup(); + } + + private class DatabaseInterceptor : EmptyInterceptor + { + public int TotalCalls { get; private set; } + + public override SqlString OnPrepareStatement(SqlString sql) + { + TotalCalls++; + return sql.Append("/* TEST */"); + } + } + private void BatchInsert(int totalRecords) { Sfi.Statistics.Clear(); diff --git a/src/NHibernate.Test/Async/Ado/GenericBatchingBatcherFixture.cs b/src/NHibernate.Test/Async/Ado/GenericBatchingBatcherFixture.cs index 99bbb4098b3..f5b701d70b2 100644 --- a/src/NHibernate.Test/Async/Ado/GenericBatchingBatcherFixture.cs +++ b/src/NHibernate.Test/Async/Ado/GenericBatchingBatcherFixture.cs @@ -15,9 +15,10 @@ using NHibernate.AdoNet; using NHibernate.Cfg; using NHibernate.Dialect; +using NHibernate.Linq; +using NHibernate.SqlCommand; using NUnit.Framework; using Environment = NHibernate.Cfg.Environment; -using NHibernate.Linq; namespace NHibernate.Test.Ado { @@ -115,6 +116,62 @@ public async Task MassivePerformanceTestAsync(bool batched) } } + [Test] + public async Task InterceptorOnPrepareStatementTestAsync() + { + var interceptor = new DatabaseInterceptor(); + using (var sqlLog = new SqlLogSpy()) + using (var s = Sfi.WithOptions().Interceptor(interceptor).OpenSession()) + using (var tx = s.BeginTransaction()) + { + s.SetBatchSize(5); + for (var i = 0; i < 20; i++) + { + await (s.SaveAsync(new VerySimple { Id = 1 + i, Name = $"Fabio{i}", Weight = 1.45 + i })); + } + + await (tx.CommitAsync()); + + Assert.That(interceptor.TotalCalls, Is.EqualTo(1)); + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "/* TEST */"), Is.EqualTo(20)); + } + + interceptor = new DatabaseInterceptor(); + using (var sqlLog = new SqlLogSpy()) + using (var s = Sfi.WithOptions().Interceptor(interceptor).OpenSession()) + using (var tx = s.BeginTransaction()) + { + var future = s.Query().ToFuture(); + s.Query().Where(o => o.Weight > 0).ToFuture(); + + using (var enumerator = (await (future.GetEnumerableAsync())).GetEnumerator()) + { + while (enumerator.MoveNext()) { } + } + + await (tx.CommitAsync()); + + var totalCalls = Sfi.ConnectionProvider.Driver.SupportsMultipleQueries ? 1 : 2; + Assert.That(interceptor.TotalCalls, Is.EqualTo(totalCalls)); + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "/* TEST */"), Is.EqualTo(totalCalls)); + } + + await (CleanupAsync()); + } + + private class DatabaseInterceptor : EmptyInterceptor + { + public int TotalCalls { get; private set; } + + public override SqlString OnPrepareStatement(SqlString sql) + { + TotalCalls++; + return sql.Append("/* TEST */"); + } + } + private async Task BatchInsertAsync(int totalRecords, CancellationToken cancellationToken = default(CancellationToken)) { Sfi.Statistics.Clear(); diff --git a/src/NHibernate/AdoNet/AbstractBatcher.cs b/src/NHibernate/AdoNet/AbstractBatcher.cs index 867c6fa997d..2c05c4a3fb5 100644 --- a/src/NHibernate/AdoNet/AbstractBatcher.cs +++ b/src/NHibernate/AdoNet/AbstractBatcher.cs @@ -81,7 +81,16 @@ protected DbCommand CurrentCommand public DbCommand Generate(CommandType type, SqlString sqlString, SqlType[] parameterTypes) { - SqlString sql = GetSQL(sqlString); + return Generate(type, sqlString, parameterTypes, false); + } + + private DbCommand Generate(CommandType type, SqlString sqlString, SqlType[] parameterTypes, bool batch) + { + var sql = GetSQL(sqlString); + if (batch) + { + OnPreparedBatchStatement(sql); + } var cmd = _factory.ConnectionProvider.Driver.GenerateCommand(type, sql, parameterTypes); LogOpenPreparedCommand(); @@ -141,7 +150,7 @@ public virtual DbCommand PrepareBatchCommand(CommandType type, SqlString sql, Sq } else { - _batchCommand = PrepareCommand(type, sql, parameterTypes); // calls ExecuteBatch() + _batchCommand = PrepareCommand(type, sql, parameterTypes, true); // calls ExecuteBatch() _batchCommandSql = sql; _batchCommandParameterTypes = parameterTypes; } @@ -150,6 +159,11 @@ public virtual DbCommand PrepareBatchCommand(CommandType type, SqlString sql, Sq } public DbCommand PrepareCommand(CommandType type, SqlString sql, SqlType[] parameterTypes) + { + return PrepareCommand(type, sql, parameterTypes, false); + } + + private DbCommand PrepareCommand(CommandType type, SqlString sql, SqlType[] parameterTypes, bool batch) { OnPreparedCommand(); @@ -157,7 +171,7 @@ public DbCommand PrepareCommand(CommandType type, SqlString sql, SqlType[] param // if the command is associated with an ADO.NET Transaction/Connection while // another open one Command is doing something then an exception will be // thrown. - return Generate(type, sql, parameterTypes); + return Generate(type, sql, parameterTypes, batch); } protected virtual void OnPreparedCommand() @@ -167,6 +181,8 @@ protected virtual void OnPreparedCommand() ExecuteBatch(); } + internal virtual void OnPreparedBatchStatement(SqlString sqlString) { } + public DbCommand PrepareQueryCommand(CommandType type, SqlString sql, SqlType[] parameterTypes) { // do not actually prepare the Command here - instead just generate it because diff --git a/src/NHibernate/AdoNet/GenericBatchingBatcher.cs b/src/NHibernate/AdoNet/GenericBatchingBatcher.cs index c7499067cd4..eac197df446 100644 --- a/src/NHibernate/AdoNet/GenericBatchingBatcher.cs +++ b/src/NHibernate/AdoNet/GenericBatchingBatcher.cs @@ -140,6 +140,11 @@ protected override void Dispose(bool isDisposing) _currentBatch.Clear(); } + internal override void OnPreparedBatchStatement(SqlString sqlString) + { + _currentBatch.CurrentStatement = sqlString; + } + private partial class BatchingCommandSet { private readonly string _statementTerminator; @@ -172,6 +177,8 @@ public BatchingCommandSet(GenericBatchingBatcher batcher, char statementTerminat public int CountOfParameters { get; private set; } + public SqlString CurrentStatement { get; set; } + public void Append(DbParameterCollection parameters) { if (CountOfCommands > 0) @@ -183,7 +190,7 @@ public void Append(DbParameterCollection parameters) _commandType = _batcher.CurrentCommand.CommandType; } - _sql.Add(_batcher.CurrentCommandSql.Copy()); + _sql.Add(CurrentStatement); _sqlTypes.AddRange(_batcher.CurrentCommandParameterTypes); foreach (DbParameter parameter in parameters) @@ -207,6 +214,7 @@ public int ExecuteNonQuery() { return 0; } + var batcherCommand = _batcher.Driver.GenerateCommand( _commandType, _sql.ToSqlString(), diff --git a/src/NHibernate/Async/AdoNet/AbstractBatcher.cs b/src/NHibernate/Async/AdoNet/AbstractBatcher.cs index 9ac50481b66..b90a3a14c48 100644 --- a/src/NHibernate/Async/AdoNet/AbstractBatcher.cs +++ b/src/NHibernate/Async/AdoNet/AbstractBatcher.cs @@ -79,7 +79,7 @@ public virtual async Task PrepareBatchCommandAsync(CommandType type, } else { - _batchCommand = await (PrepareCommandAsync(type, sql, parameterTypes, cancellationToken)).ConfigureAwait(false); // calls ExecuteBatch() + _batchCommand = await (PrepareCommandAsync(type, sql, parameterTypes, true, cancellationToken)).ConfigureAwait(false); // calls ExecuteBatch() _batchCommandSql = sql; _batchCommandParameterTypes = parameterTypes; } @@ -87,7 +87,16 @@ public virtual async Task PrepareBatchCommandAsync(CommandType type, return _batchCommand; } - public async Task PrepareCommandAsync(CommandType type, SqlString sql, SqlType[] parameterTypes, CancellationToken cancellationToken) + public Task PrepareCommandAsync(CommandType type, SqlString sql, SqlType[] parameterTypes, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + return PrepareCommandAsync(type, sql, parameterTypes, false, cancellationToken); + } + + private async Task PrepareCommandAsync(CommandType type, SqlString sql, SqlType[] parameterTypes, bool batch, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); await (OnPreparedCommandAsync(cancellationToken)).ConfigureAwait(false); @@ -96,7 +105,7 @@ public async Task PrepareCommandAsync(CommandType type, SqlString sql // if the command is associated with an ADO.NET Transaction/Connection while // another open one Command is doing something then an exception will be // thrown. - return Generate(type, sql, parameterTypes); + return Generate(type, sql, parameterTypes, batch); } protected virtual Task OnPreparedCommandAsync(CancellationToken cancellationToken) diff --git a/src/NHibernate/Async/AdoNet/GenericBatchingBatcher.cs b/src/NHibernate/Async/AdoNet/GenericBatchingBatcher.cs index f769b727a06..549dc3fee4e 100644 --- a/src/NHibernate/Async/AdoNet/GenericBatchingBatcher.cs +++ b/src/NHibernate/Async/AdoNet/GenericBatchingBatcher.cs @@ -89,6 +89,7 @@ public async Task ExecuteNonQueryAsync(CancellationToken cancellationToken) { return 0; } + var batcherCommand = _batcher.Driver.GenerateCommand( _commandType, _sql.ToSqlString(),