diff --git a/src/NHibernate.Test/Async/CompositeId/ClassWithCompositeIdFixture.cs b/src/NHibernate.Test/Async/CompositeId/ClassWithCompositeIdFixture.cs index e080dbd91eb..dc4c498e440 100644 --- a/src/NHibernate.Test/Async/CompositeId/ClassWithCompositeIdFixture.cs +++ b/src/NHibernate.Test/Async/CompositeId/ClassWithCompositeIdFixture.cs @@ -38,11 +38,6 @@ protected override string[] Mappings get { return new string[] {"CompositeId.ClassWithCompositeId.hbm.xml"}; } } - protected override bool AppliesTo(Dialect.Dialect dialect) - { - return !(dialect is Dialect.FirebirdDialect); // Firebird has no CommandTimeout, and locks up during the tear-down of this fixture - } - protected override void OnSetUp() { id = new Id("stringKey", 3, firstDateTime); @@ -52,9 +47,11 @@ protected override void OnSetUp() protected override void OnTearDown() { using (ISession s = Sfi.OpenSession()) + using (var t = s.BeginTransaction()) { s.Delete("from ClassWithCompositeId"); s.Flush(); + t.Commit(); } } @@ -396,5 +393,45 @@ public async Task QueryOverOrderByAndWhereWithIdProjectionDoesntThrowAsync() Assert.That(results.Count, Is.EqualTo(1)); } } + + [Test] + public async Task QueryOverInClauseAsync() + { + // insert the new objects + var id1 = id; + var id2 = secondId; + var id3 = new Id(id1.KeyString, id1.GetKeyShort(), id2.KeyDateTime); + + using (ISession s = OpenSession()) + using (ITransaction t = s.BeginTransaction()) + { + await (s.SaveAsync(new ClassWithCompositeId(id1) {OneProperty = 5})); + await (s.SaveAsync(new ClassWithCompositeId(id2) {OneProperty = 10})); + await (s.SaveAsync(new ClassWithCompositeId(id3))); + + await (t.CommitAsync()); + } + + using (var s = OpenSession()) + { + var results1 = await (s.QueryOver().WhereRestrictionOn(p => p.Id).IsIn(new[] {id1, id2}).ListAsync()); + var results2 = await (s.QueryOver().WhereRestrictionOn(p => p.Id).IsIn(new[] {id1}).ListAsync()); + var results3 = await (s.QueryOver().WhereRestrictionOn(p => p.Id).Not.IsIn(new[] {id1, id2}).ListAsync()); + var results4 = await (s.QueryOver().WhereRestrictionOn(p => p.Id).Not.IsIn(new[] {id1}).ListAsync()); + + Assert.Multiple( + () => + { + Assert.That(results1.Count, Is.EqualTo(2), "in multiple ids"); + Assert.That(results1.Select(r => r.Id), Is.EquivalentTo(new[] {id1, id2}), "in multiple ids"); + Assert.That(results2.Count, Is.EqualTo(1), "in single id"); + Assert.That(results2.Select(r => r.Id), Is.EquivalentTo(new[] {id1}), "in single id"); + Assert.That(results3.Count, Is.EqualTo(1), "not in multiple ids"); + Assert.That(results3.Select(r => r.Id), Is.EquivalentTo(new[] {id3}), "not in multiple ids"); + Assert.That(results4.Count, Is.EqualTo(2), "not in single id"); + Assert.That(results4.Select(r => r.Id), Is.EquivalentTo(new[] {id2, id3}), "not in single id"); + }); + } + } } } diff --git a/src/NHibernate.Test/CompositeId/ClassWithCompositeIdFixture.cs b/src/NHibernate.Test/CompositeId/ClassWithCompositeIdFixture.cs index 0d2b479fe57..edcec77d3ed 100644 --- a/src/NHibernate.Test/CompositeId/ClassWithCompositeIdFixture.cs +++ b/src/NHibernate.Test/CompositeId/ClassWithCompositeIdFixture.cs @@ -27,11 +27,6 @@ protected override string[] Mappings get { return new string[] {"CompositeId.ClassWithCompositeId.hbm.xml"}; } } - protected override bool AppliesTo(Dialect.Dialect dialect) - { - return !(dialect is Dialect.FirebirdDialect); // Firebird has no CommandTimeout, and locks up during the tear-down of this fixture - } - protected override void OnSetUp() { id = new Id("stringKey", 3, firstDateTime); @@ -41,9 +36,11 @@ protected override void OnSetUp() protected override void OnTearDown() { using (ISession s = Sfi.OpenSession()) + using (var t = s.BeginTransaction()) { s.Delete("from ClassWithCompositeId"); s.Flush(); + t.Commit(); } } @@ -385,5 +382,45 @@ public void QueryOverOrderByAndWhereWithIdProjectionDoesntThrow() Assert.That(results.Count, Is.EqualTo(1)); } } + + [Test] + public void QueryOverInClause() + { + // insert the new objects + var id1 = id; + var id2 = secondId; + var id3 = new Id(id1.KeyString, id1.GetKeyShort(), id2.KeyDateTime); + + using (ISession s = OpenSession()) + using (ITransaction t = s.BeginTransaction()) + { + s.Save(new ClassWithCompositeId(id1) {OneProperty = 5}); + s.Save(new ClassWithCompositeId(id2) {OneProperty = 10}); + s.Save(new ClassWithCompositeId(id3)); + + t.Commit(); + } + + using (var s = OpenSession()) + { + var results1 = s.QueryOver().WhereRestrictionOn(p => p.Id).IsIn(new[] {id1, id2}).List(); + var results2 = s.QueryOver().WhereRestrictionOn(p => p.Id).IsIn(new[] {id1}).List(); + var results3 = s.QueryOver().WhereRestrictionOn(p => p.Id).Not.IsIn(new[] {id1, id2}).List(); + var results4 = s.QueryOver().WhereRestrictionOn(p => p.Id).Not.IsIn(new[] {id1}).List(); + + Assert.Multiple( + () => + { + Assert.That(results1.Count, Is.EqualTo(2), "in multiple ids"); + Assert.That(results1.Select(r => r.Id), Is.EquivalentTo(new[] {id1, id2}), "in multiple ids"); + Assert.That(results2.Count, Is.EqualTo(1), "in single id"); + Assert.That(results2.Select(r => r.Id), Is.EquivalentTo(new[] {id1}), "in single id"); + Assert.That(results3.Count, Is.EqualTo(1), "not in multiple ids"); + Assert.That(results3.Select(r => r.Id), Is.EquivalentTo(new[] {id3}), "not in multiple ids"); + Assert.That(results4.Count, Is.EqualTo(2), "not in single id"); + Assert.That(results4.Select(r => r.Id), Is.EquivalentTo(new[] {id2, id3}), "not in single id"); + }); + } + } } } diff --git a/src/NHibernate/Criterion/InExpression.cs b/src/NHibernate/Criterion/InExpression.cs index d2edd83eba2..54409ba9f31 100644 --- a/src/NHibernate/Criterion/InExpression.cs +++ b/src/NHibernate/Criterion/InExpression.cs @@ -1,5 +1,4 @@ using System; -using System.Collections; using System.Collections.Generic; using System.Linq; using NHibernate.Engine; @@ -13,9 +12,6 @@ namespace NHibernate.Criterion /// An that constrains the property /// to a specified list of values. /// - /// - /// InExpression - should only be used with a Single Value column - no multicolumn properties... - /// [Serializable] public class InExpression : AbstractCriterion { @@ -62,41 +58,52 @@ public override SqlString ToSqlString(ICriteria criteria, ICriteriaQuery criteri return new SqlString("1=0"); } - //TODO: add default capacity - SqlStringBuilder result = new SqlStringBuilder(); - SqlString[] columnNames = - CriterionUtil.GetColumnNames(_propertyName, _projection, criteriaQuery, criteria); - - // Generate SqlString of the form: - // columnName1 in (values) and columnName2 in (values) and ... - Parameter[] parameters = GetParameterTypedValues(criteria, criteriaQuery).SelectMany(t => criteriaQuery.NewQueryParameter(t)).ToArray(); + SqlString[] columns = CriterionUtil.GetColumnNames(_propertyName, _projection, criteriaQuery, criteria); - for (int columnIndex = 0; columnIndex < columnNames.Length; columnIndex++) + var list = new List(columns.Length * Values.Length); + foreach (var typedValue in GetParameterTypedValues(criteria, criteriaQuery)) { - SqlString columnName = columnNames[columnIndex]; - - if (columnIndex > 0) - { - result.Add(" and "); - } + //Must be executed after CriterionUtil.GetColumnNames (as it might add _projection parameters to criteria) + list.AddRange(criteriaQuery.NewQueryParameter(typedValue)); + } - result - .Add(columnName) - .Add(" in ("); + var bogusParam = Parameter.Placeholder; - for (int i = 0; i < _values.Length; i++) - { - if (i > 0) - { - result.Add(StringHelper.CommaSpace); - } - result.Add(parameters[i]); - } + var sqlString = GetSqlString(criteriaQuery, columns, bogusParam); + sqlString.SubstituteBogusParameters(list, bogusParam); + return sqlString; + } - result.Add(")"); + private SqlString GetSqlString(ICriteriaQuery criteriaQuery, SqlString[] columns, Parameter bogusParam) + { + if (columns.Length <= 1 || criteriaQuery.Factory.Dialect.SupportsRowValueConstructorSyntaxInInList) + { + var wrapInParens = columns.Length > 1; + const string comaSeparator = ", "; + var singleValueParam = SqlStringHelper.Repeat(new SqlString(bogusParam), columns.Length, comaSeparator, wrapInParens); + + var parameters = SqlStringHelper.Repeat(singleValueParam, Values.Length, comaSeparator, wrapInParens: false); + + //single column: col1 in (?, ?) + //multi column: (col1, col2) in ((?, ?), (?, ?)) + return new SqlString( + wrapInParens ? StringHelper.OpenParen : string.Empty, + SqlStringHelper.Join(comaSeparator, columns), + wrapInParens ? StringHelper.ClosedParen : string.Empty, + " in (", + parameters, + ")"); } - return result.ToSqlString(); + //((col1 = ? and col2 = ?) or (col1 = ? and col2 = ?)) + var cols = new SqlString( + " ( ", + SqlStringHelper.Join(new SqlString(" = ", bogusParam, " and "), columns), + "= ", + bogusParam, + " ) "); + cols = SqlStringHelper.Repeat(cols, Values.Length, " or ", wrapInParens: Values.Length > 1); + return cols; } private void AssertPropertyIsNotCollection(ICriteriaQuery criteriaQuery, ICriteria criteria) @@ -122,29 +129,24 @@ private List GetParameterTypedValues(ICriteria criteria, ICriteriaQu { IType type = GetElementType(criteria, criteriaQuery); - if (type.IsComponentType) + if (!type.IsComponentType) { - List list = new List(); - IAbstractComponentType actype = (IAbstractComponentType) type; - IType[] types = actype.Subtypes; + return _values.ToList(v => new TypedValue(type, v, false)); + } - for (int i = 0; i < types.Length; i++) + List list = new List(); + IAbstractComponentType actype = (IAbstractComponentType) type; + var types = actype.Subtypes; + foreach (var value in _values) + { + var propertyValues = value != null ? actype.GetPropertyValues(value) : null; + for (int ti = 0; ti < types.Length; ti++) { - for (int j = 0; j < _values.Length; j++) - { - object subval = _values[j] == null - ? null - : actype.GetPropertyValues(_values[j])[i]; - list.Add(new TypedValue(types[i], subval, false)); - } + list.Add(new TypedValue(types[ti], propertyValues?[ti], false)); } - - return list; - } - else - { - return _values.ToList(v => new TypedValue(type, v, false)); } + + return list; } /// diff --git a/src/NHibernate/SqlCommand/SqlString.cs b/src/NHibernate/SqlCommand/SqlString.cs index 52682fd3ecd..33caf5e0096 100644 --- a/src/NHibernate/SqlCommand/SqlString.cs +++ b/src/NHibernate/SqlCommand/SqlString.cs @@ -1024,6 +1024,22 @@ public SqlString GetSubselectString() return new SubselectClauseExtractor(this).GetSqlString(); } + internal void SubstituteBogusParameters(IReadOnlyList actualParams, Parameter bogusParam) + { + int index = 0; + var keys = _parameters.Keys; + // The loop below is technically not altering the keys collection on which we iterate, but + // the underlying implementation still throws on foreach iterations over keys even if we + // have only changed the associated value. + // ReSharper disable once ForCanBeConvertedToForeach + for (var i = 0; i < keys.Count; i++) + { + var key = keys[i]; + if (ReferenceEquals(_parameters[key], bogusParam)) + _parameters[key] = actualParams[index++]; + } + } + [Serializable] private struct Part : IEquatable { diff --git a/src/NHibernate/SqlCommand/SqlStringHelper.cs b/src/NHibernate/SqlCommand/SqlStringHelper.cs index 6fcba3382ee..ccc4e2da9ac 100644 --- a/src/NHibernate/SqlCommand/SqlStringHelper.cs +++ b/src/NHibernate/SqlCommand/SqlStringHelper.cs @@ -30,6 +30,25 @@ public static SqlString Join(SqlString separator, IEnumerable objects) return buf.ToSqlString(); } + internal static SqlString Join(string separator, IList strings) + { + if (strings.Count == 0) + return SqlString.Empty; + + if (strings.Count == 1) + return strings[0]; + + var buf = new SqlStringBuilder(); + + buf.Add(strings[0]); + for (var index = 1; index < strings.Count; index++) + { + buf.Add(separator).Add(strings[index]); + } + + return buf.ToSqlString(); + } + public static SqlString[] Add(SqlString[] x, string sep, SqlString[] y) { SqlString[] result = new SqlString[x.Length]; @@ -85,5 +104,37 @@ internal static SqlString ParametersList(List parameters) return builder.ToSqlString(); } + + internal static SqlString Repeat(SqlString placeholder, int count, string separator, bool wrapInParens) + { + if (count == 0) + return SqlString.Empty; + + if (count == 1) + return wrapInParens + ? new SqlString("(", placeholder, ")") + : placeholder; + + var builder = new SqlStringBuilder((placeholder.Count + 1) * count + 1); + + if (wrapInParens) + { + builder.Add("("); + } + + builder.Add(placeholder); + + for (int i = 1; i < count; i++) + { + builder.Add(separator).Add(placeholder); + } + + if (wrapInParens) + { + builder.Add(")"); + } + + return builder.ToSqlString(); + } } }