Skip to content

Commit a499f19

Browse files
committed
Fix auto flush for Enumerable and AsyncEnumerable methods
1 parent 1cb4e08 commit a499f19

File tree

7 files changed

+147
-67
lines changed

7 files changed

+147
-67
lines changed

src/NHibernate.Test/Async/GenericTest/Methods/Fixture.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ protected override string[] Mappings
2424
{
2525
get
2626
{
27-
return new string[] { "One.hbm.xml", "Many.hbm.xml" };
27+
return new string[] { "One.hbm.xml", "Many.hbm.xml", "Simple.hbm.xml" };
2828
}
2929
}
3030

@@ -49,12 +49,15 @@ protected override void OnSetUp()
4949
many2.One = one;
5050
one.Manies.Add( many2 );
5151

52-
using( ISession s = OpenSession() )
52+
var simple = new Simple(1) {Count = 1};
53+
54+
using ( ISession s = OpenSession() )
5355
using( ITransaction t = s.BeginTransaction() )
5456
{
5557
s.Save( one );
5658
s.Save( many1 );
5759
s.Save( many2 );
60+
s.Save(simple, 1);
5861
t.Commit();
5962
}
6063
}
@@ -66,6 +69,7 @@ protected override void OnTearDown()
6669
{
6770
session.Delete( "from Many" );
6871
session.Delete( "from One" );
72+
session.Delete("from Simple");
6973
tx.Commit();
7074
}
7175
base.OnTearDown();

src/NHibernate.Test/GenericTest/Methods/Fixture.cs

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ protected override string[] Mappings
1414
{
1515
get
1616
{
17-
return new string[] { "One.hbm.xml", "Many.hbm.xml" };
17+
return new string[] { "One.hbm.xml", "Many.hbm.xml", "Simple.hbm.xml" };
1818
}
1919
}
2020

@@ -39,12 +39,15 @@ protected override void OnSetUp()
3939
many2.One = one;
4040
one.Manies.Add( many2 );
4141

42-
using( ISession s = OpenSession() )
42+
var simple = new Simple(1) {Count = 1};
43+
44+
using ( ISession s = OpenSession() )
4345
using( ITransaction t = s.BeginTransaction() )
4446
{
4547
s.Save( one );
4648
s.Save( many1 );
4749
s.Save( many2 );
50+
s.Save(simple, 1);
4851
t.Commit();
4952
}
5053
}
@@ -56,6 +59,7 @@ protected override void OnTearDown()
5659
{
5760
session.Delete( "from Many" );
5861
session.Delete( "from One" );
62+
session.Delete("from Simple");
5963
tx.Commit();
6064
}
6165
base.OnTearDown();
@@ -106,6 +110,40 @@ public void QueryEnumerable()
106110
}
107111
}
108112

113+
[Test]
114+
public void AutoFlushQueryEnumerable()
115+
{
116+
using (var s = OpenSession())
117+
using (var t = s.BeginTransaction())
118+
{
119+
Assert.That(s.FlushMode, Is.EqualTo(FlushMode.Auto));
120+
var results = s.CreateQuery("from Simple").Enumerable<Simple>();
121+
122+
var id = 2;
123+
var simple = new Simple(id) {Count = id};
124+
s.Save(simple, id);
125+
var enumerator = results.GetEnumerator();
126+
127+
Assert.That(enumerator.MoveNext(), Is.True);
128+
Assert.That(enumerator.MoveNext(), Is.True);
129+
Assert.That(enumerator.MoveNext(), Is.False);
130+
enumerator.Dispose();
131+
132+
id++;
133+
simple = new Simple(id) {Count = id};
134+
s.Save(simple, id);
135+
enumerator = results.GetEnumerator();
136+
137+
Assert.That(enumerator.MoveNext(), Is.True);
138+
Assert.That(enumerator.MoveNext(), Is.True);
139+
Assert.That(enumerator.MoveNext(), Is.True);
140+
Assert.That(enumerator.MoveNext(), Is.False);
141+
enumerator.Dispose();
142+
143+
t.Rollback();
144+
}
145+
}
146+
109147
[Test]
110148
public async Task QueryEnumerableAsync()
111149
{
@@ -120,6 +158,40 @@ public async Task QueryEnumerableAsync()
120158
}
121159
}
122160

161+
[Test]
162+
public async Task AutoFlushQueryEnumerableAsync()
163+
{
164+
using (var s = OpenSession())
165+
using (var t = s.BeginTransaction())
166+
{
167+
Assert.That(s.FlushMode, Is.EqualTo(FlushMode.Auto));
168+
var results = s.CreateQuery("from Simple").AsyncEnumerable<Simple>();
169+
170+
var id = 2;
171+
var simple = new Simple(id) {Count = id};
172+
s.Save(simple, id);
173+
var enumerator = results.GetAsyncEnumerator();
174+
175+
Assert.That(await enumerator.MoveNextAsync(), Is.True);
176+
Assert.That(await enumerator.MoveNextAsync(), Is.True);
177+
Assert.That(await enumerator.MoveNextAsync(), Is.False);
178+
await enumerator.DisposeAsync();
179+
180+
id++;
181+
simple = new Simple(id) {Count = id};
182+
s.Save(simple, id);
183+
enumerator = results.GetAsyncEnumerator();
184+
185+
Assert.That(await enumerator.MoveNextAsync(), Is.True);
186+
Assert.That(await enumerator.MoveNextAsync(), Is.True);
187+
Assert.That(await enumerator.MoveNextAsync(), Is.True);
188+
Assert.That(await enumerator.MoveNextAsync(), Is.False);
189+
await enumerator.DisposeAsync();
190+
191+
await t.RollbackAsync();
192+
}
193+
}
194+
123195
[Test]
124196
public void Filter()
125197
{

src/NHibernate/Async/Impl/SessionImpl.cs

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -302,38 +302,38 @@ public override async Task<IQueryTranslator[]> GetQueriesAsync(IQueryExpression
302302
/// <inheritdoc />
303303
// Since v5.3
304304
[Obsolete("Use AsyncEnumerable extension method instead.")]
305-
public override async Task<IEnumerable<T>> EnumerableAsync<T>(IQueryExpression queryExpression, QueryParameters queryParameters, CancellationToken cancellationToken)
305+
public override Task<IEnumerable<T>> EnumerableAsync<T>(IQueryExpression queryExpression, QueryParameters queryParameters, CancellationToken cancellationToken)
306306
{
307-
cancellationToken.ThrowIfCancellationRequested();
308-
using (BeginProcess())
307+
if (cancellationToken.IsCancellationRequested)
309308
{
310-
queryParameters.ValidateParameters();
311-
var plan = GetHQLQueryPlan(queryExpression, true);
312-
await (AutoFlushIfRequiredAsync(plan.QuerySpaces, cancellationToken)).ConfigureAwait(false);
313-
314-
using (SuspendAutoFlush()) //stops flush being called multiple times if this method is recursively called
315-
{
316-
return plan.PerformIterate<T>(queryParameters, this);
317-
}
309+
return Task.FromCanceled<IEnumerable<T>>(cancellationToken);
310+
}
311+
try
312+
{
313+
return Task.FromResult<IEnumerable<T>>(Enumerable<T>(queryExpression, queryParameters));
314+
}
315+
catch (Exception ex)
316+
{
317+
return Task.FromException<IEnumerable<T>>(ex);
318318
}
319319
}
320320

321321
/// <inheritdoc />
322322
// Since v5.3
323323
[Obsolete("Use AsyncEnumerable extension method instead.")]
324-
public override async Task<IEnumerable> EnumerableAsync(IQueryExpression queryExpression, QueryParameters queryParameters, CancellationToken cancellationToken)
324+
public override Task<IEnumerable> EnumerableAsync(IQueryExpression queryExpression, QueryParameters queryParameters, CancellationToken cancellationToken)
325325
{
326-
cancellationToken.ThrowIfCancellationRequested();
327-
using (BeginProcess())
326+
if (cancellationToken.IsCancellationRequested)
328327
{
329-
queryParameters.ValidateParameters();
330-
var plan = GetHQLQueryPlan(queryExpression, true);
331-
await (AutoFlushIfRequiredAsync(plan.QuerySpaces, cancellationToken)).ConfigureAwait(false);
332-
333-
using (SuspendAutoFlush()) //stops flush being called multiple times if this method is recursively called
334-
{
335-
return plan.PerformIterate(queryParameters, this);
336-
}
328+
return Task.FromCanceled<IEnumerable>(cancellationToken);
329+
}
330+
try
331+
{
332+
return Task.FromResult<IEnumerable>(Enumerable(queryExpression, queryParameters));
333+
}
334+
catch (Exception ex)
335+
{
336+
return Task.FromException<IEnumerable>(ex);
337337
}
338338
}
339339

src/NHibernate/Async/Loader/Hql/QueryLoader.cs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -90,22 +90,26 @@ protected override async Task<object[]> GetResultRowAsync(object[] row, DbDataRe
9090
internal async Task<InitializeEnumerableResult> InitializeEnumerableAsync(QueryParameters queryParameters, ISessionImplementor session, CancellationToken cancellationToken)
9191
{
9292
cancellationToken.ThrowIfCancellationRequested();
93-
Stopwatch stopWatch = null;
94-
if (session.Factory.Statistics.IsStatisticsEnabled)
93+
await (session.AutoFlushIfRequiredAsync(_queryTranslator.QuerySpaces, cancellationToken)).ConfigureAwait(false);
94+
using (session.SuspendAutoFlush())
9595
{
96-
stopWatch = Stopwatch.StartNew();
97-
}
96+
Stopwatch stopWatch = null;
97+
if (session.Factory.Statistics.IsStatisticsEnabled)
98+
{
99+
stopWatch = Stopwatch.StartNew();
100+
}
98101

99-
var command = await (PrepareQueryCommandAsync(queryParameters, false, session, cancellationToken)).ConfigureAwait(false);
100-
var dataReader = await (GetResultSetAsync(command, queryParameters, session, null, cancellationToken)).ConfigureAwait(false);
101-
if (stopWatch != null)
102-
{
103-
stopWatch.Stop();
104-
session.Factory.StatisticsImplementor.QueryExecuted("HQL: " + _queryTranslator.QueryString, 0, stopWatch.Elapsed);
105-
session.Factory.StatisticsImplementor.QueryExecuted(QueryIdentifier, 0, stopWatch.Elapsed);
106-
}
102+
var command = await (PrepareQueryCommandAsync(queryParameters, false, session, cancellationToken)).ConfigureAwait(false);
103+
var dataReader = await (GetResultSetAsync(command, queryParameters, session, null, cancellationToken)).ConfigureAwait(false);
104+
if (stopWatch != null)
105+
{
106+
stopWatch.Stop();
107+
session.Factory.StatisticsImplementor.QueryExecuted("HQL: " + _queryTranslator.QueryString, 0, stopWatch.Elapsed);
108+
session.Factory.StatisticsImplementor.QueryExecuted(QueryIdentifier, 0, stopWatch.Elapsed);
109+
}
107110

108-
return new InitializeEnumerableResult(command, dataReader);
111+
return new InitializeEnumerableResult(command, dataReader);
112+
}
109113
}
110114
}
111115
}

src/NHibernate/Engine/ISessionImplementor.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ internal static void AutoFlushIfRequired(this ISessionImplementor implementor, I
6464
(implementor as AbstractSessionImpl)?.AutoFlushIfRequired(querySpaces);
6565
}
6666

67+
internal static IDisposable SuspendAutoFlush(this ISessionImplementor implementor)
68+
{
69+
return (implementor as IEventSource)?.SuspendAutoFlush();
70+
}
71+
6772
/// <summary>
6873
/// Returns an <see cref="IAsyncEnumerable{T}" /> which can be enumerated asynchronously.
6974
/// </summary>

src/NHibernate/Impl/SessionImpl.cs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -585,12 +585,9 @@ public override IEnumerable<T> Enumerable<T>(IQueryExpression queryExpression, Q
585585
{
586586
queryParameters.ValidateParameters();
587587
var plan = GetHQLQueryPlan(queryExpression, true);
588-
AutoFlushIfRequired(plan.QuerySpaces);
589588

590-
using (SuspendAutoFlush()) //stops flush being called multiple times if this method is recursively called
591-
{
592-
return plan.PerformIterate<T>(queryParameters, this);
593-
}
589+
// AutoFlushIfRequired will be called when iterating through the enumerable
590+
return plan.PerformIterate<T>(queryParameters, this);
594591
}
595592
}
596593

@@ -600,12 +597,9 @@ public override IAsyncEnumerable<T> AsyncEnumerable<T>(IQueryExpression queryExp
600597
{
601598
queryParameters.ValidateParameters();
602599
var plan = GetHQLQueryPlan(queryExpression, true);
603-
AutoFlushIfRequired(plan.QuerySpaces);
604600

605-
using (SuspendAutoFlush()) //stops flush being called multiple times if this method is recursively called
606-
{
607-
return plan.PerformAsyncIterate<T>(queryParameters, this);
608-
}
601+
// AutoFlushIfRequired will be called when iterating through the enumerable
602+
return plan.PerformAsyncIterate<T>(queryParameters, this);
609603
}
610604
}
611605

@@ -615,12 +609,9 @@ public override IEnumerable Enumerable(IQueryExpression queryExpression, QueryPa
615609
{
616610
queryParameters.ValidateParameters();
617611
var plan = GetHQLQueryPlan(queryExpression, true);
618-
AutoFlushIfRequired(plan.QuerySpaces);
619612

620-
using (SuspendAutoFlush()) //stops flush being called multiple times if this method is recursively called
621-
{
622-
return plan.PerformIterate(queryParameters, this);
623-
}
613+
// AutoFlushIfRequired will be called when iterating through the enumerable
614+
return plan.PerformIterate(queryParameters, this);
624615
}
625616
}
626617

src/NHibernate/Loader/Hql/QueryLoader.cs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -486,22 +486,26 @@ internal AsyncEnumerableImpl<T> GetAsyncEnumerable<T>(QueryParameters queryParam
486486

487487
internal InitializeEnumerableResult InitializeEnumerable(QueryParameters queryParameters, ISessionImplementor session)
488488
{
489-
Stopwatch stopWatch = null;
490-
if (session.Factory.Statistics.IsStatisticsEnabled)
489+
session.AutoFlushIfRequired(_queryTranslator.QuerySpaces);
490+
using (session.SuspendAutoFlush())
491491
{
492-
stopWatch = Stopwatch.StartNew();
493-
}
492+
Stopwatch stopWatch = null;
493+
if (session.Factory.Statistics.IsStatisticsEnabled)
494+
{
495+
stopWatch = Stopwatch.StartNew();
496+
}
494497

495-
var command = PrepareQueryCommand(queryParameters, false, session);
496-
var dataReader = GetResultSet(command, queryParameters, session, null);
497-
if (stopWatch != null)
498-
{
499-
stopWatch.Stop();
500-
session.Factory.StatisticsImplementor.QueryExecuted("HQL: " + _queryTranslator.QueryString, 0, stopWatch.Elapsed);
501-
session.Factory.StatisticsImplementor.QueryExecuted(QueryIdentifier, 0, stopWatch.Elapsed);
502-
}
498+
var command = PrepareQueryCommand(queryParameters, false, session);
499+
var dataReader = GetResultSet(command, queryParameters, session, null);
500+
if (stopWatch != null)
501+
{
502+
stopWatch.Stop();
503+
session.Factory.StatisticsImplementor.QueryExecuted("HQL: " + _queryTranslator.QueryString, 0, stopWatch.Elapsed);
504+
session.Factory.StatisticsImplementor.QueryExecuted(QueryIdentifier, 0, stopWatch.Elapsed);
505+
}
503506

504-
return new InitializeEnumerableResult(command, dataReader);
507+
return new InitializeEnumerableResult(command, dataReader);
508+
}
505509
}
506510

507511
protected override void ResetEffectiveExpectedType(IEnumerable<IParameterSpecification> parameterSpecs, QueryParameters queryParameters)

0 commit comments

Comments
 (0)