Skip to content

Proper query plan caching for DML LINQ queries #2299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/NHibernate.Test/Async/Linq/ConstantTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,10 @@ public async Task DmlPlansAreCachedAsync()
{
await (db.Customers.Where(c => c.CustomerId == "UNKNOWN").UpdateAsync(x => new Customer {CompanyName = "Constant1"}));
await (db.Customers.Where(c => c.CustomerId == "ALFKI").UpdateAsync(x => new Customer {CompanyName = x.CompanyName}));
await (db.Customers.Where(c => c.CustomerId == "UNKNOWN").UpdateAsync(x => new Customer {ContactName = "Constant1"}));
Assert.That(
cache,
Has.Count.EqualTo(2),
Has.Count.EqualTo(3),
"Query plans should be cached.");

using (var spy = new LogSpy(queryPlanCacheType))
Expand All @@ -264,6 +265,7 @@ public async Task DmlPlansAreCachedAsync()
{
await (db.Customers.Where(c => c.CustomerId == "ANATR").UpdateAsync(x => new Customer {CompanyName = x.CompanyName}));
await (db.Customers.Where(c => c.CustomerId == "UNKNOWN").UpdateAsync(x => new Customer {CompanyName = "Constant2"}));
await (db.Customers.Where(c => c.CustomerId == "UNKNOWN").UpdateAsync(x => new Customer {ContactName = "Constant2"}));

var sqlEvents = sqlSpy.Appender.GetEvents();
Assert.That(
Expand All @@ -272,19 +274,25 @@ public async Task DmlPlansAreCachedAsync()
"Unexpected constant parameter value");
Assert.That(
sqlEvents[1].RenderedMessage,
Does.Contain("UNKNOWN").And.Contain("Constant2").And.Not.Contain("Constant1"),
Does.Contain("UNKNOWN").And.Contain("Constant2").And.Contain("CompanyName").IgnoreCase
.And.Not.Contain("Constant1"),
"Unexpected constant parameter value");
Assert.That(
sqlEvents[2].RenderedMessage,
Does.Contain("UNKNOWN").And.Contain("Constant2").And.Contain("ContactName").IgnoreCase
.And.Not.Contain("Constant1"),
"Unexpected constant parameter value");
}

Assert.That(cache, Has.Count.EqualTo(2), "Additional queries should not cause a plan to be cached.");
Assert.That(cache, Has.Count.EqualTo(3), "Additional queries should not cause a plan to be cached.");
Assert.That(
spy.GetWholeLog(),
Does
.Contain("located HQL query plan in cache")
.And.Not.Contain("unable to locate HQL query plan in cache"));

await (db.Customers.Where(c => c.CustomerId == "ANATR").UpdateAsync(x => new Customer {ContactName = x.ContactName}));
Assert.That(cache, Has.Count.EqualTo(3), "Query should be cached");
Assert.That(cache, Has.Count.EqualTo(4), "Query should be cached");
}
}
}
Expand Down
16 changes: 12 additions & 4 deletions src/NHibernate.Test/Linq/ConstantTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,10 @@ public void DmlPlansAreCached()
{
db.Customers.Where(c => c.CustomerId == "UNKNOWN").Update(x => new Customer {CompanyName = "Constant1"});
db.Customers.Where(c => c.CustomerId == "ALFKI").Update(x => new Customer {CompanyName = x.CompanyName});
db.Customers.Where(c => c.CustomerId == "UNKNOWN").Update(x => new Customer {ContactName = "Constant1"});
Assert.That(
cache,
Has.Count.EqualTo(2),
Has.Count.EqualTo(3),
"Query plans should be cached.");

using (var spy = new LogSpy(queryPlanCacheType))
Expand All @@ -285,6 +286,7 @@ public void DmlPlansAreCached()
{
db.Customers.Where(c => c.CustomerId == "ANATR").Update(x => new Customer {CompanyName = x.CompanyName});
db.Customers.Where(c => c.CustomerId == "UNKNOWN").Update(x => new Customer {CompanyName = "Constant2"});
db.Customers.Where(c => c.CustomerId == "UNKNOWN").Update(x => new Customer {ContactName = "Constant2"});

var sqlEvents = sqlSpy.Appender.GetEvents();
Assert.That(
Expand All @@ -293,19 +295,25 @@ public void DmlPlansAreCached()
"Unexpected constant parameter value");
Assert.That(
sqlEvents[1].RenderedMessage,
Does.Contain("UNKNOWN").And.Contain("Constant2").And.Not.Contain("Constant1"),
Does.Contain("UNKNOWN").And.Contain("Constant2").And.Contain("CompanyName").IgnoreCase
.And.Not.Contain("Constant1"),
"Unexpected constant parameter value");
Assert.That(
sqlEvents[2].RenderedMessage,
Does.Contain("UNKNOWN").And.Contain("Constant2").And.Contain("ContactName").IgnoreCase
.And.Not.Contain("Constant1"),
"Unexpected constant parameter value");
}

Assert.That(cache, Has.Count.EqualTo(2), "Additional queries should not cause a plan to be cached.");
Assert.That(cache, Has.Count.EqualTo(3), "Additional queries should not cause a plan to be cached.");
Assert.That(
spy.GetWholeLog(),
Does
.Contain("located HQL query plan in cache")
.And.Not.Contain("unable to locate HQL query plan in cache"));

db.Customers.Where(c => c.CustomerId == "ANATR").Update(x => new Customer {ContactName = x.ContactName});
Assert.That(cache, Has.Count.EqualTo(3), "Query should be cached");
Assert.That(cache, Has.Count.EqualTo(4), "Query should be cached");
}
}
}
Expand Down
35 changes: 8 additions & 27 deletions src/NHibernate/Linq/DmlExpressionRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,13 @@
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using NHibernate.Linq.Visitors;
using NHibernate.Util;

namespace NHibernate.Linq
{
public class DmlExpressionRewriter
{
static readonly ConstructorInfo DictionaryConstructorInfo = typeof(Dictionary<string, object>).GetConstructor(new[] {typeof(int)});

static readonly MethodInfo DictionaryAddMethodInfo = ReflectHelper.GetMethod<Dictionary<string, object>>(d => d.Add(null, null));

readonly IReadOnlyCollection<ParameterExpression> _parameters;
readonly Dictionary<string, Expression> _assignments = new Dictionary<string, Expression>();

Expand Down Expand Up @@ -80,39 +75,25 @@ void AddSettersFromAssignment(MemberAssignment assignment, string path)
}

/// <summary>
/// Converts the assignments into a lambda expression, which creates a Dictionary&lt;string,object%gt;.
/// Converts the assignments into block of assignments
/// </summary>
/// <param name="assignments"></param>
/// <returns>A lambda expression representing the assignments.</returns>
static LambdaExpression ConvertAssignmentsToDictionaryExpression<TSource>(IReadOnlyDictionary<string, Expression> assignments)
static LambdaExpression ConvertAssignmentsToBlockExpression<TSource>(IReadOnlyDictionary<string, Expression> assignments)
{
var param = Expression.Parameter(typeof(TSource));
var inits = new List<ElementInit>();
var variableAndAssignmentDic = new Dictionary<ParameterExpression, Expression>(assignments.Count);
foreach (var set in assignments)
{
var setter = set.Value;
if (setter is LambdaExpression setterLambda)
setter = setterLambda.Body.Replace(setterLambda.Parameters.First(), param);
inits.Add(
Expression.ElementInit(
DictionaryAddMethodInfo,
Expression.Constant(set.Key),
Expression.Convert(setter, typeof(object))));

var var = Expression.Variable(typeof(object), set.Key);
variableAndAssignmentDic[var] = Expression.Assign(var, Expression.Convert(setter, typeof(object)));
}

//The ListInit is intentionally "infected" with the lambda parameter (param), in the form of an IIF.
//The only relevance is to make sure that the ListInit is not evaluated by the PartialEvaluatingExpressionTreeVisitor,
//which could turn it into a Constant
var listInit = Expression.ListInit(
Expression.New(
DictionaryConstructorInfo,
Expression.Condition(
Expression.Equal(param, Expression.Constant(null, typeof(TSource))),
Expression.Constant(assignments.Count),
Expression.Constant(assignments.Count))),
inits);

return Expression.Lambda(listInit, param);
return Expression.Lambda(Expression.Block(variableAndAssignmentDic.Keys, variableAndAssignmentDic.Values), param);
}

public static Expression PrepareExpression<TSource, TTarget>(Expression sourceExpression, Expression<Func<TSource, TTarget>> expression)
Expand Down Expand Up @@ -151,7 +132,7 @@ public static Expression PrepareExpressionFromAnonymous<TSource>(Expression sour

public static Expression PrepareExpression<TSource>(Expression sourceExpression, IReadOnlyDictionary<string, Expression> assignments)
{
var lambda = ConvertAssignmentsToDictionaryExpression<TSource>(assignments);
var lambda = ConvertAssignmentsToBlockExpression<TSource>(assignments);

return Expression.Call(
ReflectionCache.QueryableMethods.SelectDefinition.MakeGenericMethod(typeof(TSource), lambda.Body.Type),
Expand Down
11 changes: 4 additions & 7 deletions src/NHibernate/Linq/NhLinqExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,13 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter

ParameterDescriptors = requiredHqlParameters.AsReadOnly();

if (QueryMode == QueryMode.Select && CanCachePlan)
{
CanCachePlan =
// If some constants do not have matching HQL parameters, their values from first query will
// be embedded in the plan and reused for subsequent queries: do not cache the plan.
!ParameterValuesByName
CanCachePlan = CanCachePlan &&
// If some constants do not have matching HQL parameters, their values from first query will
// be embedded in the plan and reused for subsequent queries: do not cache the plan.
!ParameterValuesByName
.Keys
.Except(requiredHqlParameters.Select(p => p.Name))
.Any();
}

// The ast node may be altered by caller, duplicate it for preserving the original one.
return DuplicateTree(ExpressionToHqlTranslationResults.Statement.AstNode);
Expand Down
21 changes: 10 additions & 11 deletions src/NHibernate/Linq/Visitors/QueryModelVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -438,20 +438,20 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que

private void VisitInsertClause(Expression expression)
{
var listInit = expression as ListInitExpression
var assignments = expression as BlockExpression
?? throw new QueryException("Malformed insert expression");
var insertedType = VisitorParameters.TargetEntityType;
var idents = new List<HqlIdent>();
var selectColumns = new List<HqlExpression>();

//Extract the insert clause from the projected ListInit
foreach (var assignment in listInit.Initializers)
foreach (BinaryExpression assignment in assignments.Expressions)
{
var member = (ConstantExpression)assignment.Arguments[0];
var value = assignment.Arguments[1];
var propName = ((ParameterExpression) assignment.Left).Name;
var value = assignment.Right;

//The target property
idents.Add(_hqlTree.TreeBuilder.Ident((string)member.Value));
idents.Add(_hqlTree.TreeBuilder.Ident(propName));

var valueHql = HqlGeneratorExpressionVisitor.Visit(value, VisitorParameters).AsExpression();
selectColumns.Add(valueHql);
Expand All @@ -467,16 +467,15 @@ private void VisitInsertClause(Expression expression)

private void VisitUpdateClause(Expression expression)
{
var listInit = expression as ListInitExpression
var assignments = expression as BlockExpression
?? throw new QueryException("Malformed update expression");
foreach (var initializer in listInit.Initializers)
foreach (BinaryExpression assigment in assignments.Expressions)
{
var member = (ConstantExpression)initializer.Arguments[0];
var setter = initializer.Arguments[1];
var propName = ((ParameterExpression) assigment.Left).Name;
var setter = assigment.Right;
var setterHql = HqlGeneratorExpressionVisitor.Visit(setter, VisitorParameters).AsExpression();

_hqlTree.AddSet(_hqlTree.TreeBuilder.Equality(_hqlTree.TreeBuilder.Ident((string)member.Value),
setterHql));
_hqlTree.AddSet(_hqlTree.TreeBuilder.Equality(_hqlTree.TreeBuilder.Ident(propName), setterHql));
}
}

Expand Down