diff --git a/src/NHibernate.Test/Async/Linq/QueryLock.cs b/src/NHibernate.Test/Async/Linq/QueryLock.cs new file mode 100644 index 00000000000..4ecb0404c6c --- /dev/null +++ b/src/NHibernate.Test/Async/Linq/QueryLock.cs @@ -0,0 +1,266 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Linq; +using System.Transactions; +using NHibernate.Dialect; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Driver; +using NHibernate.Engine; +using NHibernate.Exceptions; +using NHibernate.Linq; +using NUnit.Framework; + + +namespace NHibernate.Test.Linq +{ + using System.Threading.Tasks; + using System.Threading; + [TestFixture] + public class QueryLockAsync : LinqTestCase + { + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return TestDialect.SupportsSelectForUpdate; + } + + protected override bool AppliesTo(ISessionFactoryImplementor factory) + { + return !(factory.ConnectionProvider.Driver is OdbcDriver); + } + + [Test] + public async Task CanSetLockLinqQueriesOuterAsync() + { + using (session.BeginTransaction()) + { + var result = await ((from e in db.Customers + select e).WithLock(LockMode.Upgrade).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(91)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + await (AssertSeparateTransactionIsLockedOutAsync(result[0].CustomerId)); + } + } + + [Test] + public async Task CanSetLockLinqQueriesAsync() + { + using (session.BeginTransaction()) + { + var result = await ((from e in db.Customers.WithLock(LockMode.Upgrade) + select e).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(91)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + await (AssertSeparateTransactionIsLockedOutAsync(result[0].CustomerId)); + } + } + + [Test] + public async Task CanSetLockOnJoinHqlAsync() + { + using (session.BeginTransaction()) + { + await (session + .CreateQuery("select o from Customer c join c.Orders o") + .SetLockMode("o", LockMode.Upgrade) + .ListAsync()); + } + } + + [Test] + public async Task CanSetLockOnJoinAsync() + { + using (session.BeginTransaction()) + { + var result = await ((from c in db.Customers + from o in c.Orders.WithLock(LockMode.Upgrade) + select o).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public async Task CanSetLockOnJoinOuterAsync() + { + using (session.BeginTransaction()) + { + var result = await ((from c in db.Customers + from o in c.Orders + select o).WithLock(LockMode.Upgrade).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public void CanSetLockOnJoinOuterNotSupportedAsync() + { + using (session.BeginTransaction()) + { + var query = ( + from c in db.Customers + from o in c.Orders + select new {o, c} + ).WithLock(LockMode.Upgrade); + + Assert.ThrowsAsync(() => query.ToListAsync()); + } + } + + [Test] + public async Task CanSetLockOnJoinOuter2HqlAsync() + { + using (session.BeginTransaction()) + { + await (session + .CreateQuery("select o, c from Customer c join c.Orders o") + .SetLockMode("o", LockMode.Upgrade) + .SetLockMode("c", LockMode.Upgrade) + .ListAsync()); + } + } + + [Test] + public async Task CanSetLockOnBothJoinAndMainAsync() + { + using (session.BeginTransaction()) + { + var result = await (( + from c in db.Customers.WithLock(LockMode.Upgrade) + from o in c.Orders.WithLock(LockMode.Upgrade) + select new {o, c} + ).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0].o), Is.EqualTo(LockMode.Upgrade)); + Assert.That(session.GetCurrentLockMode(result[0].c), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public async Task CanSetLockOnBothJoinAndMainComplexAsync() + { + using (session.BeginTransaction()) + { + var result = await (( + from c in db.Customers.Where(x => true).WithLock(LockMode.Upgrade) + from o in c.Orders.Where(x => true).WithLock(LockMode.Upgrade) + select new {o, c} + ).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0].o), Is.EqualTo(LockMode.Upgrade)); + Assert.That(session.GetCurrentLockMode(result[0].c), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public async Task CanSetLockOnLinqPagingQueryAsync() + { + Assume.That(TestDialect.SupportsSelectForUpdateWithPaging, Is.True, "Dialect does not support locking in subqueries"); + using (session.BeginTransaction()) + { + var result = await ((from e in db.Customers + select e).Skip(5).Take(5).WithLock(LockMode.Upgrade).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(5)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + + await (AssertSeparateTransactionIsLockedOutAsync(result[0].CustomerId)); + } + } + + [Test] + public async Task CanLockBeforeSkipOnLinqOrderedPageQueryAsync() + { + Assume.That(TestDialect.SupportsSelectForUpdateWithPaging, Is.True, "Dialect does not support locking in subqueries"); + using (session.BeginTransaction()) + { + var result = await ((from e in db.Customers + orderby e.CompanyName + select e) + .WithLock(LockMode.Upgrade).Skip(5).Take(5).ToListAsync()); + + Assert.That(result, Has.Count.EqualTo(5)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + + await (AssertSeparateTransactionIsLockedOutAsync(result[0].CustomerId)); + } + } + + private Task AssertSeparateTransactionIsLockedOutAsync(string customerId, CancellationToken cancellationToken = default(CancellationToken)) + { + try + { + using (new TransactionScope(TransactionScopeOption.Suppress, TransactionScopeAsyncFlowOption.Enabled)) + using (var s2 = OpenSession()) + using (s2.BeginTransaction()) + { + // TODO: We should try to verify that the exception actually IS a locking failure and not something unrelated. + Assert.ThrowsAsync( + async () => + { + var result2 = await (( + from e in s2.Query() + where e.CustomerId == customerId + select e + ).WithLock(LockMode.UpgradeNoWait) + .WithOptions(o => o.SetTimeout(5)) + .ToListAsync(cancellationToken)); + Assert.That(result2, Is.Not.Null); + }, + "Expected an exception to indicate locking failure due to already locked."); + } + return Task.CompletedTask; + } + catch (Exception ex) + { + return Task.FromException(ex); + } + } + + [Test] + [Description("Verify that different lock modes are respected even if the query is otherwise exactly the same.")] + public async Task CanChangeLockModeForQueryAsync() + { + // Limit to a few dialects where we know the "nowait" keyword is used to make life easier. + Assume.That(Dialect is MsSql2000Dialect || Dialect is Oracle8iDialect || Dialect is PostgreSQL81Dialect); + + using (session.BeginTransaction()) + { + var result = await (BuildQueryableAllCustomers(db.Customers, LockMode.Upgrade).ToListAsync()); + Assert.That(result, Has.Count.EqualTo(91)); + + using (var logSpy = new SqlLogSpy()) + { + // Only difference in query is the lockmode - make sure it gets picked up. + var result2 = await (BuildQueryableAllCustomers(session.Query(), LockMode.UpgradeNoWait) + .ToListAsync()); + Assert.That(result2, Has.Count.EqualTo(91)); + + Assert.That(logSpy.GetWholeLog().ToLower(), Does.Contain("nowait")); + } + } + } + + private static IQueryable BuildQueryableAllCustomers( + IQueryable dbCustomers, + LockMode lockMode) + { + return (from e in dbCustomers select e).WithLock(lockMode).WithOptions(o => o.SetTimeout(5)); + } + } +} diff --git a/src/NHibernate.Test/Linq/QueryLock.cs b/src/NHibernate.Test/Linq/QueryLock.cs new file mode 100644 index 00000000000..ab6396ebc43 --- /dev/null +++ b/src/NHibernate.Test/Linq/QueryLock.cs @@ -0,0 +1,246 @@ +using System; +using System.Linq; +using System.Transactions; +using NHibernate.Dialect; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Driver; +using NHibernate.Engine; +using NHibernate.Exceptions; +using NHibernate.Linq; +using NUnit.Framework; + + +namespace NHibernate.Test.Linq +{ + [TestFixture] + public class QueryLock : LinqTestCase + { + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return TestDialect.SupportsSelectForUpdate; + } + + protected override bool AppliesTo(ISessionFactoryImplementor factory) + { + return !(factory.ConnectionProvider.Driver is OdbcDriver); + } + + [Test] + public void CanSetLockLinqQueriesOuter() + { + using (session.BeginTransaction()) + { + var result = (from e in db.Customers + select e).WithLock(LockMode.Upgrade).ToList(); + + Assert.That(result, Has.Count.EqualTo(91)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + AssertSeparateTransactionIsLockedOut(result[0].CustomerId); + } + } + + [Test] + public void CanSetLockLinqQueries() + { + using (session.BeginTransaction()) + { + var result = (from e in db.Customers.WithLock(LockMode.Upgrade) + select e).ToList(); + + Assert.That(result, Has.Count.EqualTo(91)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + AssertSeparateTransactionIsLockedOut(result[0].CustomerId); + } + } + + [Test] + public void CanSetLockOnJoinHql() + { + using (session.BeginTransaction()) + { + session + .CreateQuery("select o from Customer c join c.Orders o") + .SetLockMode("o", LockMode.Upgrade) + .List(); + } + } + + [Test] + public void CanSetLockOnJoin() + { + using (session.BeginTransaction()) + { + var result = (from c in db.Customers + from o in c.Orders.WithLock(LockMode.Upgrade) + select o).ToList(); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public void CanSetLockOnJoinOuter() + { + using (session.BeginTransaction()) + { + var result = (from c in db.Customers + from o in c.Orders + select o).WithLock(LockMode.Upgrade).ToList(); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public void CanSetLockOnJoinOuterNotSupported() + { + using (session.BeginTransaction()) + { + var query = ( + from c in db.Customers + from o in c.Orders + select new {o, c} + ).WithLock(LockMode.Upgrade); + + Assert.Throws(() => query.ToList()); + } + } + + [Test] + public void CanSetLockOnJoinOuter2Hql() + { + using (session.BeginTransaction()) + { + session + .CreateQuery("select o, c from Customer c join c.Orders o") + .SetLockMode("o", LockMode.Upgrade) + .SetLockMode("c", LockMode.Upgrade) + .List(); + } + } + + [Test] + public void CanSetLockOnBothJoinAndMain() + { + using (session.BeginTransaction()) + { + var result = ( + from c in db.Customers.WithLock(LockMode.Upgrade) + from o in c.Orders.WithLock(LockMode.Upgrade) + select new {o, c} + ).ToList(); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0].o), Is.EqualTo(LockMode.Upgrade)); + Assert.That(session.GetCurrentLockMode(result[0].c), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public void CanSetLockOnBothJoinAndMainComplex() + { + using (session.BeginTransaction()) + { + var result = ( + from c in db.Customers.Where(x => true).WithLock(LockMode.Upgrade) + from o in c.Orders.Where(x => true).WithLock(LockMode.Upgrade) + select new {o, c} + ).ToList(); + + Assert.That(result, Has.Count.EqualTo(830)); + Assert.That(session.GetCurrentLockMode(result[0].o), Is.EqualTo(LockMode.Upgrade)); + Assert.That(session.GetCurrentLockMode(result[0].c), Is.EqualTo(LockMode.Upgrade)); + } + } + + [Test] + public void CanSetLockOnLinqPagingQuery() + { + Assume.That(TestDialect.SupportsSelectForUpdateWithPaging, Is.True, "Dialect does not support locking in subqueries"); + using (session.BeginTransaction()) + { + var result = (from e in db.Customers + select e).Skip(5).Take(5).WithLock(LockMode.Upgrade).ToList(); + + Assert.That(result, Has.Count.EqualTo(5)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + + AssertSeparateTransactionIsLockedOut(result[0].CustomerId); + } + } + + [Test] + public void CanLockBeforeSkipOnLinqOrderedPageQuery() + { + Assume.That(TestDialect.SupportsSelectForUpdateWithPaging, Is.True, "Dialect does not support locking in subqueries"); + using (session.BeginTransaction()) + { + var result = (from e in db.Customers + orderby e.CompanyName + select e) + .WithLock(LockMode.Upgrade).Skip(5).Take(5).ToList(); + + Assert.That(result, Has.Count.EqualTo(5)); + Assert.That(session.GetCurrentLockMode(result[0]), Is.EqualTo(LockMode.Upgrade)); + + AssertSeparateTransactionIsLockedOut(result[0].CustomerId); + } + } + + private void AssertSeparateTransactionIsLockedOut(string customerId) + { + using (new TransactionScope(TransactionScopeOption.Suppress)) + using (var s2 = OpenSession()) + using (s2.BeginTransaction()) + { + // TODO: We should try to verify that the exception actually IS a locking failure and not something unrelated. + Assert.Throws( + () => + { + var result2 = ( + from e in s2.Query() + where e.CustomerId == customerId + select e + ).WithLock(LockMode.UpgradeNoWait) + .WithOptions(o => o.SetTimeout(5)) + .ToList(); + Assert.That(result2, Is.Not.Null); + }, + "Expected an exception to indicate locking failure due to already locked."); + } + } + + [Test] + [Description("Verify that different lock modes are respected even if the query is otherwise exactly the same.")] + public void CanChangeLockModeForQuery() + { + // Limit to a few dialects where we know the "nowait" keyword is used to make life easier. + Assume.That(Dialect is MsSql2000Dialect || Dialect is Oracle8iDialect || Dialect is PostgreSQL81Dialect); + + using (session.BeginTransaction()) + { + var result = BuildQueryableAllCustomers(db.Customers, LockMode.Upgrade).ToList(); + Assert.That(result, Has.Count.EqualTo(91)); + + using (var logSpy = new SqlLogSpy()) + { + // Only difference in query is the lockmode - make sure it gets picked up. + var result2 = BuildQueryableAllCustomers(session.Query(), LockMode.UpgradeNoWait) + .ToList(); + Assert.That(result2, Has.Count.EqualTo(91)); + + Assert.That(logSpy.GetWholeLog().ToLower(), Does.Contain("nowait")); + } + } + } + + private static IQueryable BuildQueryableAllCustomers( + IQueryable dbCustomers, + LockMode lockMode) + { + return (from e in dbCustomers select e).WithLock(lockMode).WithOptions(o => o.SetTimeout(5)); + } + } +} diff --git a/src/NHibernate.Test/TestDialect.cs b/src/NHibernate.Test/TestDialect.cs index 246c3078f9d..1e23fae97c3 100644 --- a/src/NHibernate.Test/TestDialect.cs +++ b/src/NHibernate.Test/TestDialect.cs @@ -46,7 +46,20 @@ public bool HasIdentityNativeGenerator public virtual bool SupportsNullCharactersInUtfStrings => true; - public virtual bool SupportsSelectForUpdateOnOuterJoin => true; + /// + /// Some databases do not support SELECT FOR UPDATE + /// + public virtual bool SupportsSelectForUpdate => true; + + /// + /// Some databases do not support SELECT FOR UPDATE with paging + /// + public virtual bool SupportsSelectForUpdateWithPaging => SupportsSelectForUpdate; + + /// + /// Some databases do not support SELECT FOR UPDATE in conjunction with outer joins + /// + public virtual bool SupportsSelectForUpdateOnOuterJoin => SupportsSelectForUpdate; public virtual bool SupportsHavingWithoutGroupBy => true; diff --git a/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs b/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs index 09b259a98f5..9d8d065f60b 100644 --- a/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs +++ b/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs @@ -12,5 +12,10 @@ public FirebirdTestDialect(Dialect.Dialect dialect) : base(dialect) /// Non-integer arguments are rounded before the division takes place. So, “7.5 mod 2.5” gives 2 (8 mod 3), not 0. /// public override bool SupportsModuloOnDecimal => false; + + /// + /// Does not support update locks + /// + public override bool SupportsSelectForUpdate => false; } } diff --git a/src/NHibernate.Test/TestDialects/MsSql2008TestDialect.cs b/src/NHibernate.Test/TestDialects/MsSql2008TestDialect.cs new file mode 100644 index 00000000000..bf4c375e761 --- /dev/null +++ b/src/NHibernate.Test/TestDialects/MsSql2008TestDialect.cs @@ -0,0 +1,15 @@ +namespace NHibernate.Test.TestDialects +{ + public class MsSql2008TestDialect : TestDialect + { + public MsSql2008TestDialect(Dialect.Dialect dialect) + : base(dialect) + { + } + + /// + /// Does not support SELECT FOR UPDATE with paging + /// + public override bool SupportsSelectForUpdateWithPaging => false; + } +} diff --git a/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs b/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs index 89156042f21..062ca3396c2 100644 --- a/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs +++ b/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs @@ -30,5 +30,10 @@ public MsSqlCe40TestDialect(Dialect.Dialect dialect) : base(dialect) /// Modulo is not supported on real, float, money, and numeric data types. [ Data type = numeric ] /// public override bool SupportsModuloOnDecimal => false; + + /// + /// Does not support update locks + /// + public override bool SupportsSelectForUpdate => false; } } diff --git a/src/NHibernate.Test/TestDialects/Oracle10gTestDialect.cs b/src/NHibernate.Test/TestDialects/Oracle10gTestDialect.cs new file mode 100644 index 00000000000..ef117344223 --- /dev/null +++ b/src/NHibernate.Test/TestDialects/Oracle10gTestDialect.cs @@ -0,0 +1,14 @@ +namespace NHibernate.Test.TestDialects +{ + public class Oracle10gTestDialect : TestDialect + { + public Oracle10gTestDialect(Dialect.Dialect dialect) : base(dialect) + { + } + + /// + /// Does not support SELECT FOR UPDATE with paging + /// + public override bool SupportsSelectForUpdateWithPaging => false; + } +} diff --git a/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs b/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs index da04e00f2f0..934425ed12f 100644 --- a/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs +++ b/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs @@ -46,5 +46,10 @@ public override bool SupportsHavingWithoutGroupBy } public override bool SupportsModuloOnDecimal => false; + + /// + /// Does not support update locks + /// + public override bool SupportsSelectForUpdate => false; } } diff --git a/src/NHibernate.Test/TestDialects/SapSQLAnywhere17TestDialect.cs b/src/NHibernate.Test/TestDialects/SapSQLAnywhere17TestDialect.cs index 3d08930b062..35a35092ce3 100644 --- a/src/NHibernate.Test/TestDialects/SapSQLAnywhere17TestDialect.cs +++ b/src/NHibernate.Test/TestDialects/SapSQLAnywhere17TestDialect.cs @@ -38,5 +38,10 @@ public SapSQLAnywhere17TestDialect(Dialect.Dialect dialect) /// numeric. See https://stackoverflow.com/q/52558715/1178314. /// public override bool HasBrokenTypeInferenceOnSelectedParameters => true; + + /// + /// Does not support SELECT FOR UPDATE + /// + public override bool SupportsSelectForUpdate => false; } } diff --git a/src/NHibernate/Linq/LinqExtensionMethods.cs b/src/NHibernate/Linq/LinqExtensionMethods.cs index 18473d80128..0d7c8182997 100644 --- a/src/NHibernate/Linq/LinqExtensionMethods.cs +++ b/src/NHibernate/Linq/LinqExtensionMethods.cs @@ -2514,6 +2514,21 @@ public static IQueryable CacheRegion(this IQueryable query, string regi public static IQueryable Timeout(this IQueryable query, int timeout) => query.WithOptions(o => o.SetTimeout(timeout)); + public static IQueryable WithLock(this IQueryable query, LockMode lockMode) + { + var method = ReflectHelper.GetMethod(() => WithLock(query, lockMode)); + + var callExpression = Expression.Call(method, query.Expression, Expression.Constant(lockMode)); + + return new NhQueryable(query.Provider, callExpression); + } + + public static IEnumerable WithLock(this IEnumerable query, LockMode lockMode) + { + throw new InvalidOperationException( + "The NHibernate.Linq.LinqExtensionMethods.WithLock(IEnumerable, LockMode) method can only be used in a Linq expression."); + } + /// /// Allows to specify the parameter NHibernate type to use for a literal in a queryable expression. /// diff --git a/src/NHibernate/Linq/LockExpressionNode.cs b/src/NHibernate/Linq/LockExpressionNode.cs new file mode 100644 index 00000000000..622e260dcac --- /dev/null +++ b/src/NHibernate/Linq/LockExpressionNode.cs @@ -0,0 +1,41 @@ +using System; +using System.Linq.Expressions; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing.Structure.IntermediateModel; + +namespace NHibernate.Linq +{ + internal class LockExpressionNode : ResultOperatorExpressionNodeBase + { + private static readonly ParameterExpression Parameter = Expression.Parameter(typeof(object)); + + private readonly ConstantExpression _lockMode; + private readonly ResolvedExpressionCache _cache; + + public LockExpressionNode(MethodCallExpressionParseInfo parseInfo, ConstantExpression lockMode) + : base(parseInfo, null, null) + { + _lockMode = lockMode; + _cache = new ResolvedExpressionCache(this); + } + + public override Expression Resolve(ParameterExpression inputParameter, Expression expressionToBeResolved, ClauseGenerationContext clauseGenerationContext) + { + return Source.Resolve(inputParameter, expressionToBeResolved, clauseGenerationContext); + } + + protected override ResultOperatorBase CreateResultOperator(ClauseGenerationContext clauseGenerationContext) + { + //Resolve identity expression (_=>_). Normally this would be resolved into QuerySourceReferenceExpression. + + var expression = _cache.GetOrCreate( + r => r.GetResolvedExpression(Parameter, Parameter, clauseGenerationContext)); + + if (!(expression is QuerySourceReferenceExpression qsrExpression)) + throw new NotSupportedException($"WithLock is not supported on {expression}"); + + return new LockResultOperator(qsrExpression, _lockMode); + } + } +} diff --git a/src/NHibernate/Linq/LockResultOperator.cs b/src/NHibernate/Linq/LockResultOperator.cs new file mode 100644 index 00000000000..2fc841258d4 --- /dev/null +++ b/src/NHibernate/Linq/LockResultOperator.cs @@ -0,0 +1,44 @@ +using System; +using System.Linq.Expressions; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.StreamedData; +using Remotion.Linq.Parsing.Structure.IntermediateModel; + +namespace NHibernate.Linq +{ + internal class LockResultOperator : ResultOperatorBase + { + private QuerySourceReferenceExpression _qsrExpression; + + public IQuerySource QuerySource => _qsrExpression.ReferencedQuerySource; + + public ConstantExpression LockMode { get; } + + public LockResultOperator(QuerySourceReferenceExpression qsrExpression, ConstantExpression lockMode) + { + _qsrExpression = qsrExpression; + LockMode = lockMode; + } + + public override IStreamedData ExecuteInMemory(IStreamedData input) + { + throw new NotImplementedException(); + } + + public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo) + { + return inputInfo; + } + + public override ResultOperatorBase Clone(CloneContext cloneContext) + { + throw new NotImplementedException(); + } + + public override void TransformExpressions(Func transformation) + { + _qsrExpression = (QuerySourceReferenceExpression) transformation(_qsrExpression); + } + } +} diff --git a/src/NHibernate/Linq/NhRelinqQueryParser.cs b/src/NHibernate/Linq/NhRelinqQueryParser.cs index 01a72df3b81..af920331057 100644 --- a/src/NHibernate/Linq/NhRelinqQueryParser.cs +++ b/src/NHibernate/Linq/NhRelinqQueryParser.cs @@ -1,4 +1,5 @@ using System.Collections; +using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -7,11 +8,9 @@ using NHibernate.Util; using Remotion.Linq; using Remotion.Linq.EagerFetching.Parsing; -using Remotion.Linq.Parsing.ExpressionVisitors; using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; using Remotion.Linq.Parsing.Structure; using Remotion.Linq.Parsing.Structure.ExpressionTreeProcessors; -using Remotion.Linq.Parsing.Structure.IntermediateModel; using Remotion.Linq.Parsing.Structure.NodeTypeProviders; namespace NHibernate.Linq @@ -83,6 +82,13 @@ public NHibernateNodeTypeProvider() methodInfoRegistry.Register( new[] { ReflectHelper.GetMethodDefinition(() => EagerFetchingExtensionMethods.ThenFetchMany(null, null)) }, typeof(ThenFetchManyExpressionNode)); + methodInfoRegistry.Register( + new[] + { + ReflectHelper.GetMethodDefinition(() => default(IQueryable).WithLock(LockMode.Read)), + ReflectHelper.GetMethodDefinition(() => default(IEnumerable).WithLock(LockMode.Read)) + }, + typeof(LockExpressionNode)); var nodeTypeProvider = ExpressionTreeParser.CreateDefaultNodeTypeProvider(); nodeTypeProvider.InnerProviders.Add(methodInfoRegistry); diff --git a/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs b/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs index 80cf888d41e..be66bfa727f 100644 --- a/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs +++ b/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs @@ -15,6 +15,7 @@ public class QueryReferenceExpressionFlattener : RelinqExpressionVisitor internal static readonly System.Type[] FlattenableResultOperators = { + typeof(LockResultOperator), typeof(FetchOneRequest), typeof(FetchManyRequest) }; diff --git a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs index 9f1142fbe2a..b05dddca40c 100644 --- a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs +++ b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs @@ -68,6 +68,7 @@ private class ResultOperatorExpressionRewriter : RelinqExpressionVisitor typeof(OfTypeResultOperator), typeof(CastResultOperator), typeof(AsQueryableResultOperator), + typeof(LockResultOperator), }; private readonly List resultOperators = new List(); diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 0fa769938a5..3567ad7c970 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -141,6 +141,7 @@ static QueryModelVisitor() ResultOperatorMap.Add(); ResultOperatorMap.Add(); ResultOperatorMap.Add(); + ResultOperatorMap.Add(); } private QueryModelVisitor(VisitorParameters visitorParameters, bool root, QueryModel queryModel, diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs new file mode 100644 index 00000000000..5ddbe360b8e --- /dev/null +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs @@ -0,0 +1,11 @@ +namespace NHibernate.Linq.Visitors.ResultOperatorProcessors +{ + internal class ProcessLock : IResultOperatorProcessor + { + public void Process(LockResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) + { + var alias = queryModelVisitor.VisitorParameters.QuerySourceNamer.GetName(resultOperator.QuerySource); + tree.AddAdditionalCriteria((q, p) => q.SetLockMode(alias, (LockMode) resultOperator.LockMode.Value)); + } + } +} diff --git a/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs b/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs index d8d1f2ca58d..d47b7cb47d1 100644 --- a/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs +++ b/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs @@ -12,8 +12,9 @@ public class SubQueryFromClauseFlattener : NhQueryModelVisitorBase { private static readonly System.Type[] FlattenableResultOperators = { - typeof (FetchOneRequest), - typeof (FetchManyRequest) + typeof(LockResultOperator), + typeof(FetchOneRequest), + typeof(FetchManyRequest) }; public static void ReWrite(QueryModel queryModel) @@ -23,16 +24,14 @@ public static void ReWrite(QueryModel queryModel) public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { - var subQueryExpression = fromClause.FromExpression as SubQueryExpression; - if (subQueryExpression != null) + if (fromClause.FromExpression is SubQueryExpression subQueryExpression) FlattenSubQuery(subQueryExpression, fromClause, queryModel, index + 1); base.VisitAdditionalFromClause(fromClause, queryModel, index); } public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) { - var subQueryExpression = fromClause.FromExpression as SubQueryExpression; - if (subQueryExpression != null) + if (fromClause.FromExpression is SubQueryExpression subQueryExpression) FlattenSubQuery(subQueryExpression, fromClause, queryModel, 0); base.VisitMainFromClause(fromClause, queryModel); }