From 9c73f0499abb5f8de1f9d2e648fc9f1804fa8436 Mon Sep 17 00:00:00 2001 From: Onur Gumus Date: Sun, 17 Aug 2014 15:49:27 +0300 Subject: [PATCH 1/3] NH-2285: Support for LockMode in linq provider --- src/NHibernate.Test/Linq/QueryLock.cs | 47 +++++++++++++++++++ src/NHibernate/Linq/LinqExtensionMethods.cs | 9 ++++ src/NHibernate/Linq/LockExpressionNode.cs | 29 ++++++++++++ src/NHibernate/Linq/LockResultOperator.cs | 39 +++++++++++++++ src/NHibernate/Linq/NhRelinqQueryParser.cs | 6 +-- .../Linq/ReWriters/ResultOperatorRewriter.cs | 1 + .../Linq/Visitors/QueryModelVisitor.cs | 1 + .../ResultOperatorProcessors/ProcessLock.cs | 10 ++++ .../Visitors/SubQueryFromClauseFlattener.cs | 1 + 9 files changed, 140 insertions(+), 3 deletions(-) create mode 100644 src/NHibernate.Test/Linq/QueryLock.cs create mode 100644 src/NHibernate/Linq/LockExpressionNode.cs create mode 100644 src/NHibernate/Linq/LockResultOperator.cs create mode 100644 src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs diff --git a/src/NHibernate.Test/Linq/QueryLock.cs b/src/NHibernate.Test/Linq/QueryLock.cs new file mode 100644 index 00000000000..694ab9443f1 --- /dev/null +++ b/src/NHibernate.Test/Linq/QueryLock.cs @@ -0,0 +1,47 @@ +using System.Linq; +using NHibernate.AdoNet; +using NHibernate.Cfg; +using NHibernate.Engine; +using NHibernate.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.Linq +{ + public class QueryLock : LinqTestCase + { + + [Test] + public void CanSetLockLinqQueries() + { + var result = (from e in db.Customers + where e.CompanyName == "Corp" + select e).SetLockMode(LockMode.Upgrade).ToList(); + + } + + + [Test] + public void CanSetLockOnLinqPagingQuery() + { + var result = (from e in db.Customers + where e.CompanyName == "Corp" + select e).Skip(5).Take(5).SetLockMode(LockMode.Upgrade).ToList(); + } + + + [Test] + public void CanLockBeforeSkipOnLinqOrderedPageQuery() + { + var result = (from e in db.Customers + orderby e.CompanyName + select e) + .SetLockMode(LockMode.Upgrade).Skip(5).Take(5).ToList(); + + + } + + + } + +} + diff --git a/src/NHibernate/Linq/LinqExtensionMethods.cs b/src/NHibernate/Linq/LinqExtensionMethods.cs index 18473d80128..55bf7733784 100644 --- a/src/NHibernate/Linq/LinqExtensionMethods.cs +++ b/src/NHibernate/Linq/LinqExtensionMethods.cs @@ -2514,6 +2514,15 @@ public static IQueryable CacheRegion(this IQueryable query, string regi public static IQueryable Timeout(this IQueryable query, int timeout) => query.WithOptions(o => o.SetTimeout(timeout)); + public static IQueryable SetLockMode(this IQueryable query, LockMode lockMode) + { + var method = ReflectHelper.GetMethod(() => SetLockMode(query, lockMode)); + + var callExpression = Expression.Call(method, query.Expression, Expression.Constant(lockMode)); + + return new NhQueryable(query.Provider, callExpression); + } + /// /// Allows to specify the parameter NHibernate type to use for a literal in a queryable expression. /// diff --git a/src/NHibernate/Linq/LockExpressionNode.cs b/src/NHibernate/Linq/LockExpressionNode.cs new file mode 100644 index 00000000000..43fc7411284 --- /dev/null +++ b/src/NHibernate/Linq/LockExpressionNode.cs @@ -0,0 +1,29 @@ +using System.Linq.Expressions; +using Remotion.Linq.Clauses; +using Remotion.Linq.Parsing.Structure.IntermediateModel; + +namespace NHibernate.Linq +{ + internal class LockExpressionNode : ResultOperatorExpressionNodeBase + { + private readonly MethodCallExpressionParseInfo _parseInfo; + private readonly ConstantExpression _lockMode; + + public LockExpressionNode(MethodCallExpressionParseInfo parseInfo, ConstantExpression lockMode) + : base(parseInfo, null, null) + { + _parseInfo = parseInfo; + _lockMode = lockMode; + } + + public override Expression Resolve(ParameterExpression inputParameter, Expression expressionToBeResolved, ClauseGenerationContext clauseGenerationContext) + { + return Source.Resolve(inputParameter, expressionToBeResolved, clauseGenerationContext); + } + + protected override ResultOperatorBase CreateResultOperator(ClauseGenerationContext clauseGenerationContext) + { + return new LockResultOperator(_parseInfo, _lockMode); + } + } +} diff --git a/src/NHibernate/Linq/LockResultOperator.cs b/src/NHibernate/Linq/LockResultOperator.cs new file mode 100644 index 00000000000..641c91af141 --- /dev/null +++ b/src/NHibernate/Linq/LockResultOperator.cs @@ -0,0 +1,39 @@ +using System; +using System.Linq.Expressions; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.StreamedData; +using Remotion.Linq.Parsing.Structure.IntermediateModel; + +namespace NHibernate.Linq +{ + internal class LockResultOperator : ResultOperatorBase + { + public MethodCallExpressionParseInfo ParseInfo { get; } + public ConstantExpression LockMode { get; } + + public LockResultOperator(MethodCallExpressionParseInfo parseInfo, ConstantExpression lockMode) + { + ParseInfo = parseInfo; + LockMode = lockMode; + } + + public override IStreamedData ExecuteInMemory(IStreamedData input) + { + throw new NotImplementedException(); + } + + public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo) + { + return inputInfo; + } + + public override ResultOperatorBase Clone(CloneContext cloneContext) + { + throw new NotImplementedException(); + } + + public override void TransformExpressions(Func transformation) + { + } + } +} diff --git a/src/NHibernate/Linq/NhRelinqQueryParser.cs b/src/NHibernate/Linq/NhRelinqQueryParser.cs index 01a72df3b81..0e3b698dd79 100644 --- a/src/NHibernate/Linq/NhRelinqQueryParser.cs +++ b/src/NHibernate/Linq/NhRelinqQueryParser.cs @@ -1,5 +1,4 @@ using System.Collections; -using System.Linq; using System.Linq.Expressions; using System.Reflection; using NHibernate.Linq.ExpressionTransformers; @@ -7,11 +6,9 @@ using NHibernate.Util; using Remotion.Linq; using Remotion.Linq.EagerFetching.Parsing; -using Remotion.Linq.Parsing.ExpressionVisitors; using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; using Remotion.Linq.Parsing.Structure; using Remotion.Linq.Parsing.Structure.ExpressionTreeProcessors; -using Remotion.Linq.Parsing.Structure.IntermediateModel; using Remotion.Linq.Parsing.Structure.NodeTypeProviders; namespace NHibernate.Linq @@ -83,6 +80,9 @@ public NHibernateNodeTypeProvider() methodInfoRegistry.Register( new[] { ReflectHelper.GetMethodDefinition(() => EagerFetchingExtensionMethods.ThenFetchMany(null, null)) }, typeof(ThenFetchManyExpressionNode)); + methodInfoRegistry.Register( + new[] { ReflectHelper.GetMethodDefinition(() => LinqExtensionMethods.SetLockMode(null, LockMode.Read)) }, + typeof(LockExpressionNode)); var nodeTypeProvider = ExpressionTreeParser.CreateDefaultNodeTypeProvider(); nodeTypeProvider.InnerProviders.Add(methodInfoRegistry); diff --git a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs index 9f1142fbe2a..b05dddca40c 100644 --- a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs +++ b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs @@ -68,6 +68,7 @@ private class ResultOperatorExpressionRewriter : RelinqExpressionVisitor typeof(OfTypeResultOperator), typeof(CastResultOperator), typeof(AsQueryableResultOperator), + typeof(LockResultOperator), }; private readonly List resultOperators = new List(); diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 0fa769938a5..3567ad7c970 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -141,6 +141,7 @@ static QueryModelVisitor() ResultOperatorMap.Add(); ResultOperatorMap.Add(); ResultOperatorMap.Add(); + ResultOperatorMap.Add(); } private QueryModelVisitor(VisitorParameters visitorParameters, bool root, QueryModel queryModel, diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs new file mode 100644 index 00000000000..69420cc8888 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs @@ -0,0 +1,10 @@ +namespace NHibernate.Linq.Visitors.ResultOperatorProcessors +{ + internal class ProcessLock : IResultOperatorProcessor + { + public void Process(LockResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) + { + tree.AddAdditionalCriteria((q, p) => q.SetLockMode(queryModelVisitor.Model.MainFromClause.ItemName, (LockMode)resultOperator.LockMode.Value)); + } + } +} diff --git a/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs b/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs index d8d1f2ca58d..f48cc1a77de 100644 --- a/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs +++ b/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs @@ -12,6 +12,7 @@ public class SubQueryFromClauseFlattener : NhQueryModelVisitorBase { private static readonly System.Type[] FlattenableResultOperators = { + typeof (LockResultOperator), typeof (FetchOneRequest), typeof (FetchManyRequest) }; From e1ff88d16b4bce76df7a8692b0a5a1131a96a818 Mon Sep 17 00:00:00 2001 From: Oskar Berggren Date: Sun, 20 Nov 2016 16:12:48 +0100 Subject: [PATCH 2/3] Strengthen asserts to verify actual locking has taken place - Add test case to verify that a change in lock mode will avoid reusing cached query. --- src/NHibernate.Test/Linq/QueryLock.cs | 100 +++++++++++++++++++++----- 1 file changed, 84 insertions(+), 16 deletions(-) diff --git a/src/NHibernate.Test/Linq/QueryLock.cs b/src/NHibernate.Test/Linq/QueryLock.cs index 694ab9443f1..b739b5e7264 100644 --- a/src/NHibernate.Test/Linq/QueryLock.cs +++ b/src/NHibernate.Test/Linq/QueryLock.cs @@ -1,47 +1,115 @@ using System.Linq; -using NHibernate.AdoNet; -using NHibernate.Cfg; -using NHibernate.Engine; +using System.Transactions; +using NHibernate.Dialect; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Exceptions; using NHibernate.Linq; using NUnit.Framework; + namespace NHibernate.Test.Linq { public class QueryLock : LinqTestCase { - [Test] public void CanSetLockLinqQueries() { - var result = (from e in db.Customers - where e.CompanyName == "Corp" - select e).SetLockMode(LockMode.Upgrade).ToList(); + using (session.BeginTransaction()) + { + var result = (from e in db.Customers + select e).SetLockMode(LockMode.Upgrade).ToList(); + Assert.That(result, Has.Count.EqualTo(91)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + AssertSeparateTransactionIsLockedOut(result[0].CustomerId); + } } [Test] public void CanSetLockOnLinqPagingQuery() { - var result = (from e in db.Customers - where e.CompanyName == "Corp" - select e).Skip(5).Take(5).SetLockMode(LockMode.Upgrade).ToList(); - } + using (session.BeginTransaction()) + { + var result = (from e in db.Customers + select e).Skip(5).Take(5).SetLockMode(LockMode.Upgrade).ToList(); + Assert.That(result, Has.Count.EqualTo(5)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + AssertSeparateTransactionIsLockedOut(result[0].CustomerId); + } + } [Test] public void CanLockBeforeSkipOnLinqOrderedPageQuery() { - var result = (from e in db.Customers - orderby e.CompanyName - select e) - .SetLockMode(LockMode.Upgrade).Skip(5).Take(5).ToList(); + using (session.BeginTransaction()) + { + var result = (from e in db.Customers + orderby e.CompanyName + select e) + .SetLockMode(LockMode.Upgrade).Skip(5).Take(5).ToList(); + Assert.That(result, Has.Count.EqualTo(5)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + AssertSeparateTransactionIsLockedOut(result[0].CustomerId); + } + } + private void AssertSeparateTransactionIsLockedOut(string customerId) + { + using (new TransactionScope(TransactionScopeOption.Suppress)) + using (var s2 = OpenSession()) + using (s2.BeginTransaction()) + { + // TODO: We should try to verify that the exception actually IS a locking failure and not something unrelated. + Assert.Throws( + () => + { + var result2 = ( + from e in s2.Query() + where e.CustomerId == customerId + select e + ).SetLockMode(LockMode.UpgradeNoWait) + .WithOptions(o => o.SetTimeout(5)) + .ToList(); + Assert.That(result2, Is.Not.Null); + }, + "Expected an exception to indicate locking failure due to already locked."); + } } + [Test] + [Description("Verify that different lock modes are respected even if the query is otherwise exactly the same.")] + public void CanChangeLockModeForQuery() + { + // Limit to a few dialects where we know the "nowait" keyword is used to make life easier. + Assume.That(Dialect is MsSql2000Dialect || Dialect is Oracle8iDialect || Dialect is PostgreSQL81Dialect); + + using (session.BeginTransaction()) + { + var result = BuildQueryableAllCustomers(db.Customers, LockMode.Upgrade).ToList(); + Assert.That(result, Has.Count.EqualTo(91)); - } + using (var logSpy = new SqlLogSpy()) + { + // Only difference in query is the lockmode - make sure it gets picked up. + var result2 = BuildQueryableAllCustomers(session.Query(), LockMode.UpgradeNoWait) + .ToList(); + Assert.That(result2, Has.Count.EqualTo(91)); + + Assert.That(logSpy.GetWholeLog().ToLower(), Does.Contain("nowait")); + } + } + } + private static IQueryable BuildQueryableAllCustomers( + IQueryable dbCustomers, + LockMode lockMode) + { + return (from e in dbCustomers select e).SetLockMode(lockMode).WithOptions(o => o.SetTimeout(5)); + } + } } + From 6dea781be45127fa362cfd3fe4ffb731bf5d065d Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Tue, 31 Jul 2018 00:11:36 +1200 Subject: [PATCH 3/3] Correctly resolve referenced query source for the lock - Rename SetLockMode to WithLock - Add overload to support locking on collections in the query - Restrict tests only to dialects that do support locks - Ignore paging tests if dialect does not support combining locks with paging --- src/NHibernate.Test/Async/Linq/QueryLock.cs | 266 ++++++++++++++++++ src/NHibernate.Test/Linq/QueryLock.cs | 149 +++++++++- src/NHibernate.Test/TestDialect.cs | 15 +- .../TestDialects/FirebirdTestDialect.cs | 5 + .../TestDialects/MsSql2008TestDialect.cs | 15 + .../TestDialects/MsSqlCe40TestDialect.cs | 5 + .../TestDialects/Oracle10gTestDialect.cs | 14 + .../TestDialects/SQLiteTestDialect.cs | 5 + .../SapSQLAnywhere17TestDialect.cs | 5 + src/NHibernate/Linq/LinqExtensionMethods.cs | 10 +- src/NHibernate/Linq/LockExpressionNode.cs | 18 +- src/NHibernate/Linq/LockResultOperator.cs | 11 +- src/NHibernate/Linq/NhRelinqQueryParser.cs | 8 +- .../QueryReferenceExpressionFlattener.cs | 1 + .../ResultOperatorProcessors/ProcessLock.cs | 3 +- .../Visitors/SubQueryFromClauseFlattener.cs | 12 +- 16 files changed, 515 insertions(+), 27 deletions(-) create mode 100644 src/NHibernate.Test/Async/Linq/QueryLock.cs create mode 100644 src/NHibernate.Test/TestDialects/MsSql2008TestDialect.cs create mode 100644 src/NHibernate.Test/TestDialects/Oracle10gTestDialect.cs diff --git a/src/NHibernate.Test/Async/Linq/QueryLock.cs b/src/NHibernate.Test/Async/Linq/QueryLock.cs new file mode 100644 index 00000000000..4ecb0404c6c --- /dev/null +++ b/src/NHibernate.Test/Async/Linq/QueryLock.cs @@ -0,0 +1,266 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Linq; +using System.Transactions; +using NHibernate.Dialect; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Driver; +using NHibernate.Engine; +using NHibernate.Exceptions; +using NHibernate.Linq; +using NUnit.Framework; + + +namespace NHibernate.Test.Linq +{ + using System.Threading.Tasks; + using System.Threading; + [TestFixture] + public class QueryLockAsync : LinqTestCase + { + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return TestDialect.SupportsSelectForUpdate; + } + + protected override bool AppliesTo(ISessionFactoryImplementor factory) + { + return !(factory.ConnectionProvider.Driver is OdbcDriver); + } + + [Test] + public async Task CanSetLockLinqQueriesOuterAsync() + { + using (session.BeginTransaction()) + { + var result = await ((from e in db.Customers + select e).WithLock(LockMode.Upgrade).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(91)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + await (AssertSeparateTransactionIsLockedOutAsync(result[0].CustomerId)); + } + } + + [Test] + public async Task CanSetLockLinqQueriesAsync() + { + using (session.BeginTransaction()) + { + var result = await ((from e in db.Customers.WithLock(LockMode.Upgrade) + select e).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(91)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + await (AssertSeparateTransactionIsLockedOutAsync(result[0].CustomerId)); + } + } + + [Test] + public async Task CanSetLockOnJoinHqlAsync() + { + using (session.BeginTransaction()) + { + await (session + .CreateQuery("select o from Customer c join c.Orders o") + .SetLockMode("o", LockMode.Upgrade) + .ListAsync()); + } + } + + [Test] + public async Task CanSetLockOnJoinAsync() + { + using (session.BeginTransaction()) + { + var result = await ((from c in db.Customers + from o in c.Orders.WithLock(LockMode.Upgrade) + select o).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public async Task CanSetLockOnJoinOuterAsync() + { + using (session.BeginTransaction()) + { + var result = await ((from c in db.Customers + from o in c.Orders + select o).WithLock(LockMode.Upgrade).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public void CanSetLockOnJoinOuterNotSupportedAsync() + { + using (session.BeginTransaction()) + { + var query = ( + from c in db.Customers + from o in c.Orders + select new {o, c} + ).WithLock(LockMode.Upgrade); + + Assert.ThrowsAsync(() => query.ToListAsync()); + } + } + + [Test] + public async Task CanSetLockOnJoinOuter2HqlAsync() + { + using (session.BeginTransaction()) + { + await (session + .CreateQuery("select o, c from Customer c join c.Orders o") + .SetLockMode("o", LockMode.Upgrade) + .SetLockMode("c", LockMode.Upgrade) + .ListAsync()); + } + } + + [Test] + public async Task CanSetLockOnBothJoinAndMainAsync() + { + using (session.BeginTransaction()) + { + var result = await (( + from c in db.Customers.WithLock(LockMode.Upgrade) + from o in c.Orders.WithLock(LockMode.Upgrade) + select new {o, c} + ).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0].o), Is.EqualTo(LockMode.Upgrade)); + Assert.That(session.GetCurrentLockMode(result[0].c), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public async Task CanSetLockOnBothJoinAndMainComplexAsync() + { + using (session.BeginTransaction()) + { + var result = await (( + from c in db.Customers.Where(x => true).WithLock(LockMode.Upgrade) + from o in c.Orders.Where(x => true).WithLock(LockMode.Upgrade) + select new {o, c} + ).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0].o), Is.EqualTo(LockMode.Upgrade)); + Assert.That(session.GetCurrentLockMode(result[0].c), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public async Task CanSetLockOnLinqPagingQueryAsync() + { + Assume.That(TestDialect.SupportsSelectForUpdateWithPaging, Is.True, "Dialect does not support locking in subqueries"); + using (session.BeginTransaction()) + { + var result = await ((from e in db.Customers + select e).Skip(5).Take(5).WithLock(LockMode.Upgrade).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(5)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + + await (AssertSeparateTransactionIsLockedOutAsync(result[0].CustomerId)); + } + } + + [Test] + public async Task CanLockBeforeSkipOnLinqOrderedPageQueryAsync() + { + Assume.That(TestDialect.SupportsSelectForUpdateWithPaging, Is.True, "Dialect does not support locking in subqueries"); + using (session.BeginTransaction()) + { + var result = await ((from e in db.Customers + orderby e.CompanyName + select e) + .WithLock(LockMode.Upgrade).Skip(5).Take(5).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(5)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + + await (AssertSeparateTransactionIsLockedOutAsync(result[0].CustomerId)); + } + } + + private Task AssertSeparateTransactionIsLockedOutAsync(string customerId, CancellationToken cancellationToken = default(CancellationToken)) + { + try + { + using (new TransactionScope(TransactionScopeOption.Suppress, TransactionScopeAsyncFlowOption.Enabled)) + using (var s2 = OpenSession()) + using (s2.BeginTransaction()) + { + // TODO: We should try to verify that the exception actually IS a locking failure and not something unrelated. + Assert.ThrowsAsync( + async () => + { + var result2 = await (( + from e in s2.Query() + where e.CustomerId == customerId + select e + ).WithLock(LockMode.UpgradeNoWait) + .WithOptions(o => o.SetTimeout(5)) + .ToListAsync(cancellationToken)); + Assert.That(result2, Is.Not.Null); + }, + "Expected an exception to indicate locking failure due to already locked."); + } + return Task.CompletedTask; + } + catch (Exception ex) + { + return Task.FromException(ex); + } + } + + [Test] + [Description("Verify that different lock modes are respected even if the query is otherwise exactly the same.")] + public async Task CanChangeLockModeForQueryAsync() + { + // Limit to a few dialects where we know the "nowait" keyword is used to make life easier. + Assume.That(Dialect is MsSql2000Dialect || Dialect is Oracle8iDialect || Dialect is PostgreSQL81Dialect); + + using (session.BeginTransaction()) + { + var result = await (BuildQueryableAllCustomers(db.Customers, LockMode.Upgrade).ToListAsync()); + Assert.That(result, Has.Count.EqualTo(91)); + + using (var logSpy = new SqlLogSpy()) + { + // Only difference in query is the lockmode - make sure it gets picked up. + var result2 = await (BuildQueryableAllCustomers(session.Query(), LockMode.UpgradeNoWait) + .ToListAsync()); + Assert.That(result2, Has.Count.EqualTo(91)); + + Assert.That(logSpy.GetWholeLog().ToLower(), Does.Contain("nowait")); + } + } + } + + private static IQueryable BuildQueryableAllCustomers( + IQueryable dbCustomers, + LockMode lockMode) + { + return (from e in dbCustomers select e).WithLock(lockMode).WithOptions(o => o.SetTimeout(5)); + } + } +} diff --git a/src/NHibernate.Test/Linq/QueryLock.cs b/src/NHibernate.Test/Linq/QueryLock.cs index b739b5e7264..ab6396ebc43 100644 --- a/src/NHibernate.Test/Linq/QueryLock.cs +++ b/src/NHibernate.Test/Linq/QueryLock.cs @@ -1,7 +1,10 @@ -using System.Linq; +using System; +using System.Linq; using System.Transactions; using NHibernate.Dialect; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Driver; +using NHibernate.Engine; using NHibernate.Exceptions; using NHibernate.Linq; using NUnit.Framework; @@ -9,15 +12,40 @@ namespace NHibernate.Test.Linq { + [TestFixture] public class QueryLock : LinqTestCase { + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return TestDialect.SupportsSelectForUpdate; + } + + protected override bool AppliesTo(ISessionFactoryImplementor factory) + { + return !(factory.ConnectionProvider.Driver is OdbcDriver); + } + [Test] - public void CanSetLockLinqQueries() + public void CanSetLockLinqQueriesOuter() { using (session.BeginTransaction()) { var result = (from e in db.Customers - select e).SetLockMode(LockMode.Upgrade).ToList(); + select e).WithLock(LockMode.Upgrade).ToList(); + + Assert.That(result, Has.Count.EqualTo(91)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + AssertSeparateTransactionIsLockedOut(result[0].CustomerId); + } + } + + [Test] + public void CanSetLockLinqQueries() + { + using (session.BeginTransaction()) + { + var result = (from e in db.Customers.WithLock(LockMode.Upgrade) + select e).ToList(); Assert.That(result, Has.Count.EqualTo(91)); Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); @@ -25,17 +53,120 @@ public void CanSetLockLinqQueries() } } + [Test] + public void CanSetLockOnJoinHql() + { + using (session.BeginTransaction()) + { + session + .CreateQuery("select o from Customer c join c.Orders o") + .SetLockMode("o", LockMode.Upgrade) + .List(); + } + } + + [Test] + public void CanSetLockOnJoin() + { + using (session.BeginTransaction()) + { + var result = (from c in db.Customers + from o in c.Orders.WithLock(LockMode.Upgrade) + select o).ToList(); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public void CanSetLockOnJoinOuter() + { + using (session.BeginTransaction()) + { + var result = (from c in db.Customers + from o in c.Orders + select o).WithLock(LockMode.Upgrade).ToList(); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public void CanSetLockOnJoinOuterNotSupported() + { + using (session.BeginTransaction()) + { + var query = ( + from c in db.Customers + from o in c.Orders + select new {o, c} + ).WithLock(LockMode.Upgrade); + + Assert.Throws(() => query.ToList()); + } + } + + [Test] + public void CanSetLockOnJoinOuter2Hql() + { + using (session.BeginTransaction()) + { + session + .CreateQuery("select o, c from Customer c join c.Orders o") + .SetLockMode("o", LockMode.Upgrade) + .SetLockMode("c", LockMode.Upgrade) + .List(); + } + } + + [Test] + public void CanSetLockOnBothJoinAndMain() + { + using (session.BeginTransaction()) + { + var result = ( + from c in db.Customers.WithLock(LockMode.Upgrade) + from o in c.Orders.WithLock(LockMode.Upgrade) + select new {o, c} + ).ToList(); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0].o), Is.EqualTo(LockMode.Upgrade)); + Assert.That(session.GetCurrentLockMode(result[0].c), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public void CanSetLockOnBothJoinAndMainComplex() + { + using (session.BeginTransaction()) + { + var result = ( + from c in db.Customers.Where(x => true).WithLock(LockMode.Upgrade) + from o in c.Orders.Where(x => true).WithLock(LockMode.Upgrade) + select new {o, c} + ).ToList(); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0].o), Is.EqualTo(LockMode.Upgrade)); + Assert.That(session.GetCurrentLockMode(result[0].c), Is.EqualTo(LockMode.Upgrade)); + } + } [Test] public void CanSetLockOnLinqPagingQuery() { + Assume.That(TestDialect.SupportsSelectForUpdateWithPaging, Is.True, "Dialect does not support locking in subqueries"); using (session.BeginTransaction()) { var result = (from e in db.Customers - select e).Skip(5).Take(5).SetLockMode(LockMode.Upgrade).ToList(); + select e).Skip(5).Take(5).WithLock(LockMode.Upgrade).ToList(); Assert.That(result, Has.Count.EqualTo(5)); Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + AssertSeparateTransactionIsLockedOut(result[0].CustomerId); } } @@ -43,15 +174,17 @@ public void CanSetLockOnLinqPagingQuery() [Test] public void CanLockBeforeSkipOnLinqOrderedPageQuery() { + Assume.That(TestDialect.SupportsSelectForUpdateWithPaging, Is.True, "Dialect does not support locking in subqueries"); using (session.BeginTransaction()) { var result = (from e in db.Customers orderby e.CompanyName select e) - .SetLockMode(LockMode.Upgrade).Skip(5).Take(5).ToList(); + .WithLock(LockMode.Upgrade).Skip(5).Take(5).ToList(); Assert.That(result, Has.Count.EqualTo(5)); Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + AssertSeparateTransactionIsLockedOut(result[0].CustomerId); } } @@ -70,7 +203,7 @@ private void AssertSeparateTransactionIsLockedOut(string customerId) from e in s2.Query() where e.CustomerId == customerId select e - ).SetLockMode(LockMode.UpgradeNoWait) + ).WithLock(LockMode.UpgradeNoWait) .WithOptions(o => o.SetTimeout(5)) .ToList(); Assert.That(result2, Is.Not.Null); @@ -107,9 +240,7 @@ private static IQueryable BuildQueryableAllCustomers( IQueryable dbCustomers, LockMode lockMode) { - return (from e in dbCustomers select e).SetLockMode(lockMode).WithOptions(o => o.SetTimeout(5)); + return (from e in dbCustomers select e).WithLock(lockMode).WithOptions(o => o.SetTimeout(5)); } } } - - diff --git a/src/NHibernate.Test/TestDialect.cs b/src/NHibernate.Test/TestDialect.cs index 246c3078f9d..1e23fae97c3 100644 --- a/src/NHibernate.Test/TestDialect.cs +++ b/src/NHibernate.Test/TestDialect.cs @@ -46,7 +46,20 @@ public bool HasIdentityNativeGenerator public virtual bool SupportsNullCharactersInUtfStrings => true; - public virtual bool SupportsSelectForUpdateOnOuterJoin => true; + /// + /// Some databases do not support SELECT FOR UPDATE + /// + public virtual bool SupportsSelectForUpdate => true; + + /// + /// Some databases do not support SELECT FOR UPDATE with paging + /// + public virtual bool SupportsSelectForUpdateWithPaging => SupportsSelectForUpdate; + + /// + /// Some databases do not support SELECT FOR UPDATE in conjunction with outer joins + /// + public virtual bool SupportsSelectForUpdateOnOuterJoin => SupportsSelectForUpdate; public virtual bool SupportsHavingWithoutGroupBy => true; diff --git a/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs b/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs index 09b259a98f5..9d8d065f60b 100644 --- a/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs +++ b/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs @@ -12,5 +12,10 @@ public FirebirdTestDialect(Dialect.Dialect dialect) : base(dialect) /// Non-integer arguments are rounded before the division takes place. So, “7.5 mod 2.5” gives 2 (8 mod 3), not 0. /// public override bool SupportsModuloOnDecimal => false; + + /// + /// Does not support update locks + /// + public override bool SupportsSelectForUpdate => false; } } diff --git a/src/NHibernate.Test/TestDialects/MsSql2008TestDialect.cs b/src/NHibernate.Test/TestDialects/MsSql2008TestDialect.cs new file mode 100644 index 00000000000..bf4c375e761 --- /dev/null +++ b/src/NHibernate.Test/TestDialects/MsSql2008TestDialect.cs @@ -0,0 +1,15 @@ +namespace NHibernate.Test.TestDialects +{ + public class MsSql2008TestDialect : TestDialect + { + public MsSql2008TestDialect(Dialect.Dialect dialect) + : base(dialect) + { + } + + /// + /// Does not support SELECT FOR UPDATE with paging + /// + public override bool SupportsSelectForUpdateWithPaging => false; + } +} diff --git a/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs b/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs index 89156042f21..062ca3396c2 100644 --- a/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs +++ b/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs @@ -30,5 +30,10 @@ public MsSqlCe40TestDialect(Dialect.Dialect dialect) : base(dialect) /// Modulo is not supported on real, float, money, and numeric data types. [ Data type = numeric ] /// public override bool SupportsModuloOnDecimal => false; + + /// + /// Does not support update locks + /// + public override bool SupportsSelectForUpdate => false; } } diff --git a/src/NHibernate.Test/TestDialects/Oracle10gTestDialect.cs b/src/NHibernate.Test/TestDialects/Oracle10gTestDialect.cs new file mode 100644 index 00000000000..ef117344223 --- /dev/null +++ b/src/NHibernate.Test/TestDialects/Oracle10gTestDialect.cs @@ -0,0 +1,14 @@ +namespace NHibernate.Test.TestDialects +{ + public class Oracle10gTestDialect : TestDialect + { + public Oracle10gTestDialect(Dialect.Dialect dialect) : base(dialect) + { + } + + /// + /// Does not support SELECT FOR UPDATE with paging + /// + public override bool SupportsSelectForUpdateWithPaging => false; + } +} diff --git a/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs b/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs index da04e00f2f0..934425ed12f 100644 --- a/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs +++ b/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs @@ -46,5 +46,10 @@ public override bool SupportsHavingWithoutGroupBy } public override bool SupportsModuloOnDecimal => false; + + /// + /// Does not support update locks + /// + public override bool SupportsSelectForUpdate => false; } } diff --git a/src/NHibernate.Test/TestDialects/SapSQLAnywhere17TestDialect.cs b/src/NHibernate.Test/TestDialects/SapSQLAnywhere17TestDialect.cs index 3d08930b062..35a35092ce3 100644 --- a/src/NHibernate.Test/TestDialects/SapSQLAnywhere17TestDialect.cs +++ b/src/NHibernate.Test/TestDialects/SapSQLAnywhere17TestDialect.cs @@ -38,5 +38,10 @@ public SapSQLAnywhere17TestDialect(Dialect.Dialect dialect) /// numeric. See https://stackoverflow.com/q/52558715/1178314. /// public override bool HasBrokenTypeInferenceOnSelectedParameters => true; + + /// + /// Does not support SELECT FOR UPDATE + /// + public override bool SupportsSelectForUpdate => false; } } diff --git a/src/NHibernate/Linq/LinqExtensionMethods.cs b/src/NHibernate/Linq/LinqExtensionMethods.cs index 55bf7733784..0d7c8182997 100644 --- a/src/NHibernate/Linq/LinqExtensionMethods.cs +++ b/src/NHibernate/Linq/LinqExtensionMethods.cs @@ -2514,15 +2514,21 @@ public static IQueryable CacheRegion(this IQueryable query, string regi public static IQueryable Timeout(this IQueryable query, int timeout) => query.WithOptions(o => o.SetTimeout(timeout)); - public static IQueryable SetLockMode(this IQueryable query, LockMode lockMode) + public static IQueryable WithLock(this IQueryable query, LockMode lockMode) { - var method = ReflectHelper.GetMethod(() => SetLockMode(query, lockMode)); + var method = ReflectHelper.GetMethod(() => WithLock(query, lockMode)); var callExpression = Expression.Call(method, query.Expression, Expression.Constant(lockMode)); return new NhQueryable(query.Provider, callExpression); } + public static IEnumerable WithLock(this IEnumerable query, LockMode lockMode) + { + throw new InvalidOperationException( + "The NHibernate.Linq.LinqExtensionMethods.WithLock(IEnumerable, LockMode) method can only be used in a Linq expression."); + } + /// /// Allows to specify the parameter NHibernate type to use for a literal in a queryable expression. /// diff --git a/src/NHibernate/Linq/LockExpressionNode.cs b/src/NHibernate/Linq/LockExpressionNode.cs index 43fc7411284..622e260dcac 100644 --- a/src/NHibernate/Linq/LockExpressionNode.cs +++ b/src/NHibernate/Linq/LockExpressionNode.cs @@ -1,19 +1,23 @@ +using System; using System.Linq.Expressions; using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Parsing.Structure.IntermediateModel; namespace NHibernate.Linq { internal class LockExpressionNode : ResultOperatorExpressionNodeBase { - private readonly MethodCallExpressionParseInfo _parseInfo; + private static readonly ParameterExpression Parameter = Expression.Parameter(typeof(object)); + private readonly ConstantExpression _lockMode; + private readonly ResolvedExpressionCache _cache; public LockExpressionNode(MethodCallExpressionParseInfo parseInfo, ConstantExpression lockMode) : base(parseInfo, null, null) { - _parseInfo = parseInfo; _lockMode = lockMode; + _cache = new ResolvedExpressionCache(this); } public override Expression Resolve(ParameterExpression inputParameter, Expression expressionToBeResolved, ClauseGenerationContext clauseGenerationContext) @@ -23,7 +27,15 @@ public override Expression Resolve(ParameterExpression inputParameter, Expressio protected override ResultOperatorBase CreateResultOperator(ClauseGenerationContext clauseGenerationContext) { - return new LockResultOperator(_parseInfo, _lockMode); + //Resolve identity expression (_=>_). Normally this would be resolved into QuerySourceReferenceExpression. + + var expression = _cache.GetOrCreate( + r => r.GetResolvedExpression(Parameter, Parameter, clauseGenerationContext)); + + if (!(expression is QuerySourceReferenceExpression qsrExpression)) + throw new NotSupportedException($"WithLock is not supported on {expression}"); + + return new LockResultOperator(qsrExpression, _lockMode); } } } diff --git a/src/NHibernate/Linq/LockResultOperator.cs b/src/NHibernate/Linq/LockResultOperator.cs index 641c91af141..2fc841258d4 100644 --- a/src/NHibernate/Linq/LockResultOperator.cs +++ b/src/NHibernate/Linq/LockResultOperator.cs @@ -1,6 +1,7 @@ using System; using System.Linq.Expressions; using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.StreamedData; using Remotion.Linq.Parsing.Structure.IntermediateModel; @@ -8,12 +9,15 @@ namespace NHibernate.Linq { internal class LockResultOperator : ResultOperatorBase { - public MethodCallExpressionParseInfo ParseInfo { get; } + private QuerySourceReferenceExpression _qsrExpression; + + public IQuerySource QuerySource => _qsrExpression.ReferencedQuerySource; + public ConstantExpression LockMode { get; } - public LockResultOperator(MethodCallExpressionParseInfo parseInfo, ConstantExpression lockMode) + public LockResultOperator(QuerySourceReferenceExpression qsrExpression, ConstantExpression lockMode) { - ParseInfo = parseInfo; + _qsrExpression = qsrExpression; LockMode = lockMode; } @@ -34,6 +38,7 @@ public override ResultOperatorBase Clone(CloneContext cloneContext) public override void TransformExpressions(Func transformation) { + _qsrExpression = (QuerySourceReferenceExpression) transformation(_qsrExpression); } } } diff --git a/src/NHibernate/Linq/NhRelinqQueryParser.cs b/src/NHibernate/Linq/NhRelinqQueryParser.cs index 0e3b698dd79..af920331057 100644 --- a/src/NHibernate/Linq/NhRelinqQueryParser.cs +++ b/src/NHibernate/Linq/NhRelinqQueryParser.cs @@ -1,4 +1,6 @@ using System.Collections; +using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using System.Reflection; using NHibernate.Linq.ExpressionTransformers; @@ -81,7 +83,11 @@ public NHibernateNodeTypeProvider() new[] { ReflectHelper.GetMethodDefinition(() => EagerFetchingExtensionMethods.ThenFetchMany(null, null)) }, typeof(ThenFetchManyExpressionNode)); methodInfoRegistry.Register( - new[] { ReflectHelper.GetMethodDefinition(() => LinqExtensionMethods.SetLockMode(null, LockMode.Read)) }, + new[] + { + ReflectHelper.GetMethodDefinition(() => default(IQueryable).WithLock(LockMode.Read)), + ReflectHelper.GetMethodDefinition(() => default(IEnumerable).WithLock(LockMode.Read)) + }, typeof(LockExpressionNode)); var nodeTypeProvider = ExpressionTreeParser.CreateDefaultNodeTypeProvider(); diff --git a/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs b/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs index 80cf888d41e..be66bfa727f 100644 --- a/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs +++ b/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs @@ -15,6 +15,7 @@ public class QueryReferenceExpressionFlattener : RelinqExpressionVisitor internal static readonly System.Type[] FlattenableResultOperators = { + typeof(LockResultOperator), typeof(FetchOneRequest), typeof(FetchManyRequest) }; diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs index 69420cc8888..5ddbe360b8e 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs @@ -4,7 +4,8 @@ internal class ProcessLock : IResultOperatorProcessor { public void Process(LockResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) { - tree.AddAdditionalCriteria((q, p) => q.SetLockMode(queryModelVisitor.Model.MainFromClause.ItemName, (LockMode)resultOperator.LockMode.Value)); + var alias = queryModelVisitor.VisitorParameters.QuerySourceNamer.GetName(resultOperator.QuerySource); + tree.AddAdditionalCriteria((q, p) => q.SetLockMode(alias, (LockMode) resultOperator.LockMode.Value)); } } } diff --git a/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs b/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs index f48cc1a77de..d47b7cb47d1 100644 --- a/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs +++ b/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs @@ -12,9 +12,9 @@ public class SubQueryFromClauseFlattener : NhQueryModelVisitorBase { private static readonly System.Type[] FlattenableResultOperators = { - typeof (LockResultOperator), - typeof (FetchOneRequest), - typeof (FetchManyRequest) + typeof(LockResultOperator), + typeof(FetchOneRequest), + typeof(FetchManyRequest) }; public static void ReWrite(QueryModel queryModel) @@ -24,16 +24,14 @@ public static void ReWrite(QueryModel queryModel) public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { - var subQueryExpression = fromClause.FromExpression as SubQueryExpression; - if (subQueryExpression != null) + if (fromClause.FromExpression is SubQueryExpression subQueryExpression) FlattenSubQuery(subQueryExpression, fromClause, queryModel, index + 1); base.VisitAdditionalFromClause(fromClause, queryModel, index); } public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) { - var subQueryExpression = fromClause.FromExpression as SubQueryExpression; - if (subQueryExpression != null) + if (fromClause.FromExpression is SubQueryExpression subQueryExpression) FlattenSubQuery(subQueryExpression, fromClause, queryModel, 0); base.VisitMainFromClause(fromClause, queryModel); }