diff --git a/src/NHibernate.Test/Async/Linq/ConstantTest.cs b/src/NHibernate.Test/Async/Linq/ConstantTest.cs index 04b1f02fbf3..20c7daab902 100644 --- a/src/NHibernate.Test/Async/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Async/Linq/ConstantTest.cs @@ -252,9 +252,10 @@ public async Task DmlPlansAreCachedAsync() { await (db.Customers.Where(c => c.CustomerId == "UNKNOWN").UpdateAsync(x => new Customer {CompanyName = "Constant1"})); await (db.Customers.Where(c => c.CustomerId == "ALFKI").UpdateAsync(x => new Customer {CompanyName = x.CompanyName})); + await (db.Customers.Where(c => c.CustomerId == "UNKNOWN").UpdateAsync(x => new Customer {ContactName = "Constant1"})); Assert.That( cache, - Has.Count.EqualTo(2), + Has.Count.EqualTo(3), "Query plans should be cached."); using (var spy = new LogSpy(queryPlanCacheType)) @@ -264,6 +265,7 @@ public async Task DmlPlansAreCachedAsync() { await (db.Customers.Where(c => c.CustomerId == "ANATR").UpdateAsync(x => new Customer {CompanyName = x.CompanyName})); await (db.Customers.Where(c => c.CustomerId == "UNKNOWN").UpdateAsync(x => new Customer {CompanyName = "Constant2"})); + await (db.Customers.Where(c => c.CustomerId == "UNKNOWN").UpdateAsync(x => new Customer {ContactName = "Constant2"})); var sqlEvents = sqlSpy.Appender.GetEvents(); Assert.That( @@ -272,11 +274,17 @@ public async Task DmlPlansAreCachedAsync() "Unexpected constant parameter value"); Assert.That( sqlEvents[1].RenderedMessage, - Does.Contain("UNKNOWN").And.Contain("Constant2").And.Not.Contain("Constant1"), + Does.Contain("UNKNOWN").And.Contain("Constant2").And.Contain("CompanyName").IgnoreCase + .And.Not.Contain("Constant1"), + "Unexpected constant parameter value"); + Assert.That( + sqlEvents[2].RenderedMessage, + Does.Contain("UNKNOWN").And.Contain("Constant2").And.Contain("ContactName").IgnoreCase + .And.Not.Contain("Constant1"), "Unexpected constant parameter value"); } - Assert.That(cache, Has.Count.EqualTo(2), "Additional queries should not cause a plan to be cached."); + Assert.That(cache, Has.Count.EqualTo(3), "Additional queries should not cause a plan to be cached."); Assert.That( spy.GetWholeLog(), Does @@ -284,7 +292,7 @@ public async Task DmlPlansAreCachedAsync() .And.Not.Contain("unable to locate HQL query plan in cache")); await (db.Customers.Where(c => c.CustomerId == "ANATR").UpdateAsync(x => new Customer {ContactName = x.ContactName})); - Assert.That(cache, Has.Count.EqualTo(3), "Query should be cached"); + Assert.That(cache, Has.Count.EqualTo(4), "Query should be cached"); } } } diff --git a/src/NHibernate.Test/Linq/ConstantTest.cs b/src/NHibernate.Test/Linq/ConstantTest.cs index 8afb0aa77d3..e8f77edac8c 100644 --- a/src/NHibernate.Test/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Linq/ConstantTest.cs @@ -273,9 +273,10 @@ public void DmlPlansAreCached() { db.Customers.Where(c => c.CustomerId == "UNKNOWN").Update(x => new Customer {CompanyName = "Constant1"}); db.Customers.Where(c => c.CustomerId == "ALFKI").Update(x => new Customer {CompanyName = x.CompanyName}); + db.Customers.Where(c => c.CustomerId == "UNKNOWN").Update(x => new Customer {ContactName = "Constant1"}); Assert.That( cache, - Has.Count.EqualTo(2), + Has.Count.EqualTo(3), "Query plans should be cached."); using (var spy = new LogSpy(queryPlanCacheType)) @@ -285,6 +286,7 @@ public void DmlPlansAreCached() { db.Customers.Where(c => c.CustomerId == "ANATR").Update(x => new Customer {CompanyName = x.CompanyName}); db.Customers.Where(c => c.CustomerId == "UNKNOWN").Update(x => new Customer {CompanyName = "Constant2"}); + db.Customers.Where(c => c.CustomerId == "UNKNOWN").Update(x => new Customer {ContactName = "Constant2"}); var sqlEvents = sqlSpy.Appender.GetEvents(); Assert.That( @@ -293,11 +295,17 @@ public void DmlPlansAreCached() "Unexpected constant parameter value"); Assert.That( sqlEvents[1].RenderedMessage, - Does.Contain("UNKNOWN").And.Contain("Constant2").And.Not.Contain("Constant1"), + Does.Contain("UNKNOWN").And.Contain("Constant2").And.Contain("CompanyName").IgnoreCase + .And.Not.Contain("Constant1"), + "Unexpected constant parameter value"); + Assert.That( + sqlEvents[2].RenderedMessage, + Does.Contain("UNKNOWN").And.Contain("Constant2").And.Contain("ContactName").IgnoreCase + .And.Not.Contain("Constant1"), "Unexpected constant parameter value"); } - Assert.That(cache, Has.Count.EqualTo(2), "Additional queries should not cause a plan to be cached."); + Assert.That(cache, Has.Count.EqualTo(3), "Additional queries should not cause a plan to be cached."); Assert.That( spy.GetWholeLog(), Does @@ -305,7 +313,7 @@ public void DmlPlansAreCached() .And.Not.Contain("unable to locate HQL query plan in cache")); db.Customers.Where(c => c.CustomerId == "ANATR").Update(x => new Customer {ContactName = x.ContactName}); - Assert.That(cache, Has.Count.EqualTo(3), "Query should be cached"); + Assert.That(cache, Has.Count.EqualTo(4), "Query should be cached"); } } } diff --git a/src/NHibernate/Linq/DmlExpressionRewriter.cs b/src/NHibernate/Linq/DmlExpressionRewriter.cs index baf8e24f2c6..782318a64a7 100644 --- a/src/NHibernate/Linq/DmlExpressionRewriter.cs +++ b/src/NHibernate/Linq/DmlExpressionRewriter.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using System.Reflection; using NHibernate.Linq.Visitors; using NHibernate.Util; @@ -10,10 +9,6 @@ namespace NHibernate.Linq { public class DmlExpressionRewriter { - static readonly ConstructorInfo DictionaryConstructorInfo = typeof(Dictionary).GetConstructor(new[] {typeof(int)}); - - static readonly MethodInfo DictionaryAddMethodInfo = ReflectHelper.GetMethod>(d => d.Add(null, null)); - readonly IReadOnlyCollection _parameters; readonly Dictionary _assignments = new Dictionary(); @@ -80,39 +75,25 @@ void AddSettersFromAssignment(MemberAssignment assignment, string path) } /// - /// Converts the assignments into a lambda expression, which creates a Dictionary<string,object%gt;. + /// Converts the assignments into block of assignments /// /// /// A lambda expression representing the assignments. - static LambdaExpression ConvertAssignmentsToDictionaryExpression(IReadOnlyDictionary assignments) + static LambdaExpression ConvertAssignmentsToBlockExpression(IReadOnlyDictionary assignments) { var param = Expression.Parameter(typeof(TSource)); - var inits = new List(); + var variableAndAssignmentDic = new Dictionary(assignments.Count); foreach (var set in assignments) { var setter = set.Value; if (setter is LambdaExpression setterLambda) setter = setterLambda.Body.Replace(setterLambda.Parameters.First(), param); - inits.Add( - Expression.ElementInit( - DictionaryAddMethodInfo, - Expression.Constant(set.Key), - Expression.Convert(setter, typeof(object)))); + + var var = Expression.Variable(typeof(object), set.Key); + variableAndAssignmentDic[var] = Expression.Assign(var, Expression.Convert(setter, typeof(object))); } - //The ListInit is intentionally "infected" with the lambda parameter (param), in the form of an IIF. - //The only relevance is to make sure that the ListInit is not evaluated by the PartialEvaluatingExpressionTreeVisitor, - //which could turn it into a Constant - var listInit = Expression.ListInit( - Expression.New( - DictionaryConstructorInfo, - Expression.Condition( - Expression.Equal(param, Expression.Constant(null, typeof(TSource))), - Expression.Constant(assignments.Count), - Expression.Constant(assignments.Count))), - inits); - - return Expression.Lambda(listInit, param); + return Expression.Lambda(Expression.Block(variableAndAssignmentDic.Keys, variableAndAssignmentDic.Values), param); } public static Expression PrepareExpression(Expression sourceExpression, Expression> expression) @@ -151,7 +132,7 @@ public static Expression PrepareExpressionFromAnonymous(Expression sour public static Expression PrepareExpression(Expression sourceExpression, IReadOnlyDictionary assignments) { - var lambda = ConvertAssignmentsToDictionaryExpression(assignments); + var lambda = ConvertAssignmentsToBlockExpression(assignments); return Expression.Call( ReflectionCache.QueryableMethods.SelectDefinition.MakeGenericMethod(typeof(TSource), lambda.Body.Type), diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index d508a7ecddb..50ec325c47f 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -88,16 +88,13 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter ParameterDescriptors = requiredHqlParameters.AsReadOnly(); - if (QueryMode == QueryMode.Select && CanCachePlan) - { - CanCachePlan = - // If some constants do not have matching HQL parameters, their values from first query will - // be embedded in the plan and reused for subsequent queries: do not cache the plan. - !ParameterValuesByName + CanCachePlan = CanCachePlan && + // If some constants do not have matching HQL parameters, their values from first query will + // be embedded in the plan and reused for subsequent queries: do not cache the plan. + !ParameterValuesByName .Keys .Except(requiredHqlParameters.Select(p => p.Name)) .Any(); - } // The ast node may be altered by caller, duplicate it for preserving the original one. return DuplicateTree(ExpressionToHqlTranslationResults.Statement.AstNode); diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 648b752678f..f043d30c51f 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -438,20 +438,20 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que private void VisitInsertClause(Expression expression) { - var listInit = expression as ListInitExpression + var assignments = expression as BlockExpression ?? throw new QueryException("Malformed insert expression"); var insertedType = VisitorParameters.TargetEntityType; var idents = new List(); var selectColumns = new List(); //Extract the insert clause from the projected ListInit - foreach (var assignment in listInit.Initializers) + foreach (BinaryExpression assignment in assignments.Expressions) { - var member = (ConstantExpression)assignment.Arguments[0]; - var value = assignment.Arguments[1]; + var propName = ((ParameterExpression) assignment.Left).Name; + var value = assignment.Right; //The target property - idents.Add(_hqlTree.TreeBuilder.Ident((string)member.Value)); + idents.Add(_hqlTree.TreeBuilder.Ident(propName)); var valueHql = HqlGeneratorExpressionVisitor.Visit(value, VisitorParameters).AsExpression(); selectColumns.Add(valueHql); @@ -467,16 +467,15 @@ private void VisitInsertClause(Expression expression) private void VisitUpdateClause(Expression expression) { - var listInit = expression as ListInitExpression + var assignments = expression as BlockExpression ?? throw new QueryException("Malformed update expression"); - foreach (var initializer in listInit.Initializers) + foreach (BinaryExpression assigment in assignments.Expressions) { - var member = (ConstantExpression)initializer.Arguments[0]; - var setter = initializer.Arguments[1]; + var propName = ((ParameterExpression) assigment.Left).Name; + var setter = assigment.Right; var setterHql = HqlGeneratorExpressionVisitor.Visit(setter, VisitorParameters).AsExpression(); - _hqlTree.AddSet(_hqlTree.TreeBuilder.Equality(_hqlTree.TreeBuilder.Ident((string)member.Value), - setterHql)); + _hqlTree.AddSet(_hqlTree.TreeBuilder.Equality(_hqlTree.TreeBuilder.Ident(propName), setterHql)); } }