diff --git a/src/NHibernate.Test/Async/Linq/ConstantTest.cs b/src/NHibernate.Test/Async/Linq/ConstantTest.cs index b5f0f05037d..17465eea544 100644 --- a/src/NHibernate.Test/Async/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Async/Linq/ConstantTest.cs @@ -10,7 +10,11 @@ using System.Collections.Generic; using System.Linq; +using System.Reflection; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Engine.Query; +using NHibernate.Linq.Visitors; +using NHibernate.Util; using NUnit.Framework; using NHibernate.Linq; @@ -130,10 +134,10 @@ public async Task ConstantNonCachedInMemberInitExpressionAsync() public async Task ConstantInNewArrayExpressionAsync() { var c1 = await ((from c in db.Categories - select new [] { c.Name, "category1" }).ToListAsync()); + select new[] { c.Name, "category1" }).ToListAsync()); var c2 = await ((from c in db.Categories - select new [] { c.Name, "category2" }).ToListAsync()); + select new[] { c.Name, "category2" }).ToListAsync()); Assert.That(c1, Has.Count.GreaterThan(0), "c1 Count"); Assert.That(c2, Has.Count.GreaterThan(0), "c2 Count"); @@ -175,13 +179,19 @@ public int GetItemValue(Product p) { return _value; } + + // Workaround for having a different key per different instances. + public override string ToString() + { + return base.ToString() + _value; + } } // Adapted from NH-2500 first test case by Andrey Titov (file NHTest3.zip) [Test] - [Ignore("Not fixed yet")] public async Task ObjectConstantsAsync() { + // Fixed with a workaround, see InfoBuilder above. var builder = new InfoBuilder(1); var v1 = await ((from p in db.Products select builder.GetItemValue(p)).FirstAsync()); @@ -200,7 +210,6 @@ private int TestFunc(Product item, int closureValue) // Adapted from NH-3673 [Test] - [Ignore("Not fixed yet")] public async Task ConstantsInFuncCallAsync() { var closureVariable = 1; @@ -213,5 +222,64 @@ public async Task ConstantsInFuncCallAsync() Assert.That(v1, Is.EqualTo(1), "v1"); Assert.That(v2, Is.EqualTo(2), "v2"); } + + [Test] + public async Task PlansAreCachedAsync() + { + var queryPlanCacheType = typeof(QueryPlanCache); + + var cache = (SoftLimitMRUCache) queryPlanCacheType + .GetField("planCache", BindingFlags.Instance | BindingFlags.NonPublic) + .GetValue(Sfi.QueryPlanCache); + cache.Clear(); + + await ((from c in db.Customers + where c.CustomerId == "ALFKI" + select new { c.CustomerId, c.ContactName, Constant = 1 }).FirstAsync()); + Assert.That( + cache, + Has.Count.EqualTo(2), + "First query plan should be cached with a non-refined key and a refined one."); + + using (var spy = new LogSpy(queryPlanCacheType)) + { + // Should hit non-refined key but miss refined key. + await ((from c in db.Customers + where c.CustomerId == "ANATR" + select new { c.CustomerId, c.ContactName, Constant = 2 }).FirstAsync()); + Assert.That(cache, Has.Count.EqualTo(3), "Second query plan should be cached only with its refined key."); + Assert.That( + spy.GetWholeLog(), + Does + .Contain("located HQL query plan in cache") + .And.Contain("Key was refined and is no more matching") + .And.Contain("unable to locate HQL query plan in cache")); + + spy.Appender.Clear(); + // Should hit non-refined key entry directly. + await ((from c in db.Customers + where c.CustomerId == "ANATR" + select new { c.CustomerId, c.ContactName, Constant = 1 }).FirstAsync()); + Assert.That(cache, Has.Count.EqualTo(3), "Third query plan should not be additionnaly cached."); + Assert.That( + spy.GetWholeLog(), + Does + .Contain("located HQL query plan in cache") + .And.Not.Contain("Key was refined and is no more matching")); + + spy.Appender.Clear(); + // Should hit non-refined key then hit refined key. + await ((from c in db.Customers + where c.CustomerId == "ALFKI" + select new { c.CustomerId, c.ContactName, Constant = 2 }).FirstAsync()); + Assert.That(cache, Has.Count.EqualTo(3), "Fourth query plan should not be additionnaly cached."); + Assert.That( + spy.GetWholeLog(), + Does + .Contain("located HQL query plan in cache") + .And.Contain("Key was refined and is no more matching") + .And.Not.Contain("unable to locate HQL query plan in cache")); + } + } } } diff --git a/src/NHibernate.Test/Async/NHSpecificTest/NH2658/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/NH2658/Fixture.cs new file mode 100644 index 00000000000..dec6e7b3d15 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/NH2658/Fixture.cs @@ -0,0 +1,101 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Collections; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Driver; +using NHibernate.Engine; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.NHSpecificTest.NH2658 +{ + using System.Threading.Tasks; + + [TestFixture] + public class FixtureAsync : TestCase + { + public class DynamicPropertyGenerator : BaseHqlGeneratorForMethod + { + public DynamicPropertyGenerator() + { + //just registering for string here, but in a real implementation we'd be doing a runtime generator + SupportedMethods = new[] + { + ReflectHelper.GetMethodDefinition(() => ObjectExtensions.GetProperty(null, null)) + }; + } + + public override HqlTreeNode BuildHql( + MethodInfo method, + Expression targetObject, + ReadOnlyCollection arguments, + HqlTreeBuilder treeBuilder, + IHqlExpressionVisitor visitor) + { + var propertyName = (string) ((ConstantExpression) arguments[1]).Value; + + return treeBuilder.Dot( + visitor.Visit(arguments[0]).AsExpression(), + treeBuilder.Ident(propertyName)).AsExpression(); + } + } + + protected override string MappingsAssembly => "NHibernate.Test"; + + protected override IList Mappings => new[] { "NHSpecificTest.NH2658.Mappings.hbm.xml" }; + + protected override DebugSessionFactory BuildSessionFactory() + { + var sfi = base.BuildSessionFactory(); + + //add our linq extension + ((ISessionFactoryImplementor)sfi).Settings.LinqToHqlGeneratorsRegistry.Merge(new DynamicPropertyGenerator()); + return sfi; + } + + [Test] + public async Task Does_Not_Cache_NonParametersAsync() + { + using (var session = OpenSession()) + { + //PASSES + using (var spy = new SqlLogSpy()) + { + //Query by name + await ((from p in session.Query() where p.GetProperty("Name") == "Value" select p).ToListAsync()); + + var paramName = ((ISqlParameterFormatter) Sfi.ConnectionProvider.Driver).GetParameterName(0); + Assert.That(spy.GetWholeLog(), Does.Contain("Name=" + paramName)); + } + + //FAILS + //Because this query is considered the same as the top query the hql will be reused from the top statement + //Even though GetProperty has a parameter that never get passed to sql or hql + using (var spy = new SqlLogSpy()) + { + //Query by description + await ((from p in session.Query() where p.GetProperty("Description") == "Value" select p).ToListAsync()); + + var paramName = ((ISqlParameterFormatter) Sfi.ConnectionProvider.Driver).GetParameterName(0); + Assert.That(spy.GetWholeLog(), Does.Contain("Description=" + paramName)); + } + } + } + } +} diff --git a/src/NHibernate.Test/Linq/ConstantTest.cs b/src/NHibernate.Test/Linq/ConstantTest.cs index a30118e7283..0f2a0769908 100644 --- a/src/NHibernate.Test/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Linq/ConstantTest.cs @@ -1,6 +1,10 @@ using System.Collections.Generic; using System.Linq; +using System.Reflection; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Engine.Query; +using NHibernate.Linq.Visitors; +using NHibernate.Util; using NUnit.Framework; namespace NHibernate.Test.Linq @@ -118,10 +122,10 @@ public void ConstantNonCachedInMemberInitExpression() public void ConstantInNewArrayExpression() { var c1 = (from c in db.Categories - select new [] { c.Name, "category1" }).ToList(); + select new[] { c.Name, "category1" }).ToList(); var c2 = (from c in db.Categories - select new [] { c.Name, "category2" }).ToList(); + select new[] { c.Name, "category2" }).ToList(); Assert.That(c1, Has.Count.GreaterThan(0), "c1 Count"); Assert.That(c2, Has.Count.GreaterThan(0), "c2 Count"); @@ -163,13 +167,19 @@ public int GetItemValue(Product p) { return _value; } + + // Workaround for having a different key per different instances. + public override string ToString() + { + return base.ToString() + _value; + } } // Adapted from NH-2500 first test case by Andrey Titov (file NHTest3.zip) [Test] - [Ignore("Not fixed yet")] public void ObjectConstants() { + // Fixed with a workaround, see InfoBuilder above. var builder = new InfoBuilder(1); var v1 = (from p in db.Products select builder.GetItemValue(p)).First(); @@ -188,7 +198,6 @@ private int TestFunc(Product item, int closureValue) // Adapted from NH-3673 [Test] - [Ignore("Not fixed yet")] public void ConstantsInFuncCall() { var closureVariable = 1; @@ -201,5 +210,83 @@ public void ConstantsInFuncCall() Assert.That(v1, Is.EqualTo(1), "v1"); Assert.That(v2, Is.EqualTo(2), "v2"); } + + [Test] + public void ConstantInWhereDoesNotCauseManyKeys() + { + var q1 = (from c in db.Customers + where c.CustomerId == "ALFKI" + select c); + var q2 = (from c in db.Customers + where c.CustomerId == "ANATR" + select c); + var parameters1 = ExpressionParameterVisitor.Visit(q1.Expression, Sfi); + var k1 = ExpressionKeyVisitor.Visit(q1.Expression, parameters1); + var parameters2 = ExpressionParameterVisitor.Visit(q2.Expression, Sfi); + var k2 = ExpressionKeyVisitor.Visit(q2.Expression, parameters2); + + Assert.That(parameters1, Has.Count.GreaterThan(0), "parameters1"); + Assert.That(parameters2, Has.Count.GreaterThan(0), "parameters2"); + Assert.That(k2, Is.EqualTo(k1)); + } + + [Test] + public void PlansAreCached() + { + var queryPlanCacheType = typeof(QueryPlanCache); + + var cache = (SoftLimitMRUCache) queryPlanCacheType + .GetField("planCache", BindingFlags.Instance | BindingFlags.NonPublic) + .GetValue(Sfi.QueryPlanCache); + cache.Clear(); + + (from c in db.Customers + where c.CustomerId == "ALFKI" + select new { c.CustomerId, c.ContactName, Constant = 1 }).First(); + Assert.That( + cache, + Has.Count.EqualTo(2), + "First query plan should be cached with a non-refined key and a refined one."); + + using (var spy = new LogSpy(queryPlanCacheType)) + { + // Should hit non-refined key but miss refined key. + (from c in db.Customers + where c.CustomerId == "ANATR" + select new { c.CustomerId, c.ContactName, Constant = 2 }).First(); + Assert.That(cache, Has.Count.EqualTo(3), "Second query plan should be cached only with its refined key."); + Assert.That( + spy.GetWholeLog(), + Does + .Contain("located HQL query plan in cache") + .And.Contain("Key was refined and is no more matching") + .And.Contain("unable to locate HQL query plan in cache")); + + spy.Appender.Clear(); + // Should hit non-refined key entry directly. + (from c in db.Customers + where c.CustomerId == "ANATR" + select new { c.CustomerId, c.ContactName, Constant = 1 }).First(); + Assert.That(cache, Has.Count.EqualTo(3), "Third query plan should not be additionnaly cached."); + Assert.That( + spy.GetWholeLog(), + Does + .Contain("located HQL query plan in cache") + .And.Not.Contain("Key was refined and is no more matching")); + + spy.Appender.Clear(); + // Should hit non-refined key then hit refined key. + (from c in db.Customers + where c.CustomerId == "ALFKI" + select new { c.CustomerId, c.ContactName, Constant = 2 }).First(); + Assert.That(cache, Has.Count.EqualTo(3), "Fourth query plan should not be additionnaly cached."); + Assert.That( + spy.GetWholeLog(), + Does + .Contain("located HQL query plan in cache") + .And.Contain("Key was refined and is no more matching") + .And.Not.Contain("unable to locate HQL query plan in cache")); + } + } } } diff --git a/src/NHibernate.Test/NHSpecificTest/NH2658/Entity.cs b/src/NHibernate.Test/NHSpecificTest/NH2658/Entity.cs new file mode 100644 index 00000000000..3a80f5fff8a --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH2658/Entity.cs @@ -0,0 +1,11 @@ +namespace NHibernate.Test.NHSpecificTest.NH2658 +{ + public class Product + { + public virtual string ProductId { get; set; } + + public virtual string Name { get; set; } + + public virtual string Description { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH2658/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/NH2658/Fixture.cs new file mode 100644 index 00000000000..b01c09e2d3b --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH2658/Fixture.cs @@ -0,0 +1,97 @@ +using System; +using System.Collections; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Driver; +using NHibernate.Engine; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.NH2658 +{ + public static class ObjectExtensions + { + public static T GetProperty(this object o, string propertyName) + { + //no implementation for this test + throw new NotImplementedException(); + } + } + + [TestFixture] + public class Fixture : TestCase + { + public class DynamicPropertyGenerator : BaseHqlGeneratorForMethod + { + public DynamicPropertyGenerator() + { + //just registering for string here, but in a real implementation we'd be doing a runtime generator + SupportedMethods = new[] + { + ReflectHelper.GetMethodDefinition(() => ObjectExtensions.GetProperty(null, null)) + }; + } + + public override HqlTreeNode BuildHql( + MethodInfo method, + Expression targetObject, + ReadOnlyCollection arguments, + HqlTreeBuilder treeBuilder, + IHqlExpressionVisitor visitor) + { + var propertyName = (string) ((ConstantExpression) arguments[1]).Value; + + return treeBuilder.Dot( + visitor.Visit(arguments[0]).AsExpression(), + treeBuilder.Ident(propertyName)).AsExpression(); + } + } + + protected override string MappingsAssembly => "NHibernate.Test"; + + protected override IList Mappings => new[] { "NHSpecificTest.NH2658.Mappings.hbm.xml" }; + + protected override DebugSessionFactory BuildSessionFactory() + { + var sfi = base.BuildSessionFactory(); + + //add our linq extension + ((ISessionFactoryImplementor)sfi).Settings.LinqToHqlGeneratorsRegistry.Merge(new DynamicPropertyGenerator()); + return sfi; + } + + [Test] + public void Does_Not_Cache_NonParameters() + { + using (var session = OpenSession()) + { + //PASSES + using (var spy = new SqlLogSpy()) + { + //Query by name + (from p in session.Query() where p.GetProperty("Name") == "Value" select p).ToList(); + + var paramName = ((ISqlParameterFormatter) Sfi.ConnectionProvider.Driver).GetParameterName(0); + Assert.That(spy.GetWholeLog(), Does.Contain("Name=" + paramName)); + } + + //FAILS + //Because this query is considered the same as the top query the hql will be reused from the top statement + //Even though GetProperty has a parameter that never get passed to sql or hql + using (var spy = new SqlLogSpy()) + { + //Query by description + (from p in session.Query() where p.GetProperty("Description") == "Value" select p).ToList(); + + var paramName = ((ISqlParameterFormatter) Sfi.ConnectionProvider.Driver).GetParameterName(0); + Assert.That(spy.GetWholeLog(), Does.Contain("Description=" + paramName)); + } + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH2658/Mappings.hbm.xml b/src/NHibernate.Test/NHSpecificTest/NH2658/Mappings.hbm.xml new file mode 100644 index 00000000000..8e4c8736769 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH2658/Mappings.hbm.xml @@ -0,0 +1,11 @@ + + + + + + + + + + diff --git a/src/NHibernate/Engine/Query/QueryPlanCache.cs b/src/NHibernate/Engine/Query/QueryPlanCache.cs index 09556fa5a64..3d228326ac8 100644 --- a/src/NHibernate/Engine/Query/QueryPlanCache.cs +++ b/src/NHibernate/Engine/Query/QueryPlanCache.cs @@ -58,8 +58,19 @@ public IQueryExpressionPlan GetHQLQueryPlan(IQueryExpression queryExpression, bo { log.Debug("unable to locate HQL query plan in cache; generating ({0})", queryExpression.Key); } + + var refinableQuery = queryExpression as IRefinableKeyQueryExpression; + var wasAlreadyRefined = refinableQuery?.RefinedKey; + plan = new QueryExpressionPlan(queryExpression, shallow, enabledFilters, factory); planCache.Put(key, plan); + + if (wasAlreadyRefined == false && refinableQuery.RefinedKey) + { + // Additionally cache with the refined key. The cached entry with the non refined key is still + // needed for allowing cache hit with non-refined key, which then refines the key. + planCache.Put(new HQLQueryPlanKey(queryExpression, shallow, enabledFilters), plan); + } } else { @@ -67,6 +78,18 @@ public IQueryExpressionPlan GetHQLQueryPlan(IQueryExpression queryExpression, bo { log.Debug("located HQL query plan in cache ({0})", queryExpression.Key); } + + if (plan.QueryExpression is IRefinableKeyQueryExpression cachedRefinableQuery && cachedRefinableQuery.RefinedKey && + queryExpression is IRefinableKeyQueryExpression refinableQuery && !refinableQuery.RefinedKey) + { + refinableQuery.RefineKey(cachedRefinableQuery.ParametersRefiningKey); + if (refinableQuery.Key != cachedRefinableQuery.Key) + { + log.Debug("Key was refined and is no more matching, querying cache again."); + return GetHQLQueryPlan(queryExpression, shallow, enabledFilters); + } + } + plan = CopyIfRequired(plan, queryExpression); } @@ -109,12 +132,35 @@ public IQueryExpressionPlan GetFilterQueryPlan(IQueryExpression queryExpression, if (plan == null) { log.Debug("unable to locate collection-filter query plan in cache; generating ({0} : {1})", collectionRole, queryExpression.Key); + + var refinableQuery = queryExpression as IRefinableKeyQueryExpression; + var wasAlreadyRefined = refinableQuery?.RefinedKey; + plan = new FilterQueryPlan(queryExpression, collectionRole, shallow, enabledFilters, factory); planCache.Put(key, plan); + + if (wasAlreadyRefined == false && refinableQuery.RefinedKey) + { + // Additionally cache with the refined key. The cached entry with the non refined key is still + // needed for allowing cache hit with non-refined key, which then refines the key. + planCache.Put(new FilterQueryPlanKey(queryExpression.Key, collectionRole, shallow, enabledFilters), plan); + } } else { log.Debug("located collection-filter query plan in cache ({0} : {1})", collectionRole, queryExpression.Key); + + if (plan.QueryExpression is IRefinableKeyQueryExpression cachedRefinableQuery && cachedRefinableQuery.RefinedKey && + queryExpression is IRefinableKeyQueryExpression refinableQuery && !refinableQuery.RefinedKey) + { + refinableQuery.RefineKey(cachedRefinableQuery.ParametersRefiningKey); + if (refinableQuery.Key != cachedRefinableQuery.Key) + { + log.Debug("Key was refined and is no more matching, querying cache again."); + return GetFilterQueryPlan(queryExpression, collectionRole, shallow, enabledFilters); + } + } + plan = CopyIfRequired(plan, queryExpression); } diff --git a/src/NHibernate/IRefinableKeyQueryExpression.cs b/src/NHibernate/IRefinableKeyQueryExpression.cs new file mode 100644 index 00000000000..430e44722a9 --- /dev/null +++ b/src/NHibernate/IRefinableKeyQueryExpression.cs @@ -0,0 +1,11 @@ +using System.Collections.Generic; + +namespace NHibernate +{ + public interface IRefinableKeyQueryExpression : IQueryExpression + { + bool RefinedKey { get; } + void RefineKey(ISet parametersRefiningKey); + ISet ParametersRefiningKey { get; } + } +} diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index fbca270693f..fe2fc32cc29 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -11,10 +11,12 @@ namespace NHibernate.Linq { - public class NhLinqExpression : IQueryExpression + public class NhLinqExpression : IRefinableKeyQueryExpression { public string Key { get; protected set; } + public bool RefinedKey { get; private set; } + public System.Type Type { get; private set; } /// @@ -23,6 +25,7 @@ public class NhLinqExpression : IQueryExpression protected virtual System.Type TargetType => Type; public IList ParameterDescriptors { get; private set; } + public ISet ParametersRefiningKey { get; private set; } public NhLinqExpressionReturnType ReturnType { get; } @@ -78,6 +81,13 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter ParameterDescriptors = requiredHqlParameters.AsReadOnly(); + var parametersRefiningKey = new HashSet( + _constantToParameterMap + .Values.Select(p => p.Name) + .Except(requiredHqlParameters.Select(p => p.Name))); + + RefineKey(parametersRefiningKey); + return ExpressionToHqlTranslationResults.Statement.AstNode; } @@ -88,5 +98,19 @@ internal void CopyExpressionTranslation(NhLinqExpression other) // Type could have been overridden by translation. Type = other.Type; } + + public void RefineKey(ISet parametersRefiningKey) + { + if (RefinedKey) + // Already done. + return; + ParametersRefiningKey = parametersRefiningKey; + var refinedKey = ExpressionKeyVisitor.RefineKey(Key, parametersRefiningKey, _constantToParameterMap); + if (refinedKey == null) + // No changes to key + return; + Key = refinedKey; + RefinedKey = true; + } } } diff --git a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs index 1789357cfb0..0decbdf7b0d 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs @@ -1,7 +1,6 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -37,6 +36,28 @@ public static string Visit(Expression expression, IDictionary + /// Refines the provided key according to parameter values which needs to be included in the key. + /// + /// The key. + /// The set of parameter names which need to have their value included in the key. + /// The map of constants to candidate parameters. + /// A refined key, or if no refining occured. + public static string RefineKey(string key, ISet parametersRefiningKey, IDictionary parameters) + { + if (parametersRefiningKey.Count == 0) + return null; + + var writer = new StringBuilder(key); + foreach (var param in parameters.Where(p => parametersRefiningKey.Contains(p.Value.Name))) + { + writer.Append(';').Append(param.Value.Name).Append('='); + WriteConstantValue(param.Key, writer); + } + + return writer.Length == key.Length ? null : writer.ToString(); + } + public override string ToString() { return _string.ToString(); @@ -83,7 +104,7 @@ protected override Expression VisitConstant(ConstantExpression expression) if (_constantToParameterMap == null) throw new InvalidOperationException("Cannot visit a constant without a constant to parameter map."); - if (_constantToParameterMap.TryGetValue(expression, out param) && insideSelectClause == false) + if (_constantToParameterMap.TryGetValue(expression, out param)) { // Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not. if (param.Value == null) @@ -105,27 +126,31 @@ protected override Expression VisitConstant(ConstantExpression expression) } else { - if (expression.Value == null) + WriteConstantValue(expression, _string); + } + + return base.VisitConstant(expression); + } + + private static void WriteConstantValue(ConstantExpression expression, StringBuilder writer) + { + if (expression.Value == null) + { + writer.Append("NULL"); + } + else + { + if (expression.Value is IEnumerable value && !(value is string) && !(value is IQueryable)) { - _string.Append("NULL"); + writer.Append("{"); + writer.Append(String.Join(",", value.Cast())); + writer.Append("}"); } else { - var value = expression.Value as IEnumerable; - if (value != null && !(value is string) && !(value is IQueryable)) - { - _string.Append("{"); - _string.Append(String.Join(",", value.Cast())); - _string.Append("}"); - } - else - { - _string.Append(expression.Value); - } + writer.Append(expression.Value); } } - - return base.VisitConstant(expression); } private T AppendCommas(T expression) where T : Expression @@ -158,26 +183,8 @@ protected override Expression VisitMember(MemberExpression expression) return expression; } - private bool insideSelectClause; protected override Expression VisitMethodCall(MethodCallExpression expression) { - var old = insideSelectClause; - - switch (expression.Method.Name) - { - case "First": - case "FirstOrDefault": - case "Single": - case "SingleOrDefault": - case "Select": - case "GroupBy": - insideSelectClause = true; - break; - default: - insideSelectClause = false; - break; - } - Visit(expression.Object); _string.Append('.'); VisitMethod(expression.Method); @@ -185,7 +192,6 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) ExpressionVisitor.Visit(expression.Arguments, AppendCommas); _string.Append(')'); - insideSelectClause = old; return expression; }