Skip to content

Commit 8ac6fa3

Browse files
bahusoidhazzik
andauthored
Proper support for IN clause for composite values in Criteria (#2158)
Co-authored-by: Alexander Zaytsev <hazzik@gmail.com>
1 parent 7a8d36d commit 8ac6fa3

File tree

5 files changed

+203
-60
lines changed

5 files changed

+203
-60
lines changed

src/NHibernate.Test/Async/CompositeId/ClassWithCompositeIdFixture.cs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,6 @@ protected override string[] Mappings
3838
get { return new string[] {"CompositeId.ClassWithCompositeId.hbm.xml"}; }
3939
}
4040

41-
protected override bool AppliesTo(Dialect.Dialect dialect)
42-
{
43-
return !(dialect is Dialect.FirebirdDialect); // Firebird has no CommandTimeout, and locks up during the tear-down of this fixture
44-
}
45-
4641
protected override void OnSetUp()
4742
{
4843
id = new Id("stringKey", 3, firstDateTime);
@@ -52,9 +47,11 @@ protected override void OnSetUp()
5247
protected override void OnTearDown()
5348
{
5449
using (ISession s = Sfi.OpenSession())
50+
using (var t = s.BeginTransaction())
5551
{
5652
s.Delete("from ClassWithCompositeId");
5753
s.Flush();
54+
t.Commit();
5855
}
5956
}
6057

@@ -396,5 +393,45 @@ public async Task QueryOverOrderByAndWhereWithIdProjectionDoesntThrowAsync()
396393
Assert.That(results.Count, Is.EqualTo(1));
397394
}
398395
}
396+
397+
[Test]
398+
public async Task QueryOverInClauseAsync()
399+
{
400+
// insert the new objects
401+
var id1 = id;
402+
var id2 = secondId;
403+
var id3 = new Id(id1.KeyString, id1.GetKeyShort(), id2.KeyDateTime);
404+
405+
using (ISession s = OpenSession())
406+
using (ITransaction t = s.BeginTransaction())
407+
{
408+
await (s.SaveAsync(new ClassWithCompositeId(id1) {OneProperty = 5}));
409+
await (s.SaveAsync(new ClassWithCompositeId(id2) {OneProperty = 10}));
410+
await (s.SaveAsync(new ClassWithCompositeId(id3)));
411+
412+
await (t.CommitAsync());
413+
}
414+
415+
using (var s = OpenSession())
416+
{
417+
var results1 = await (s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).IsIn(new[] {id1, id2}).ListAsync());
418+
var results2 = await (s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).IsIn(new[] {id1}).ListAsync());
419+
var results3 = await (s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).Not.IsIn(new[] {id1, id2}).ListAsync());
420+
var results4 = await (s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).Not.IsIn(new[] {id1}).ListAsync());
421+
422+
Assert.Multiple(
423+
() =>
424+
{
425+
Assert.That(results1.Count, Is.EqualTo(2), "in multiple ids");
426+
Assert.That(results1.Select(r => r.Id), Is.EquivalentTo(new[] {id1, id2}), "in multiple ids");
427+
Assert.That(results2.Count, Is.EqualTo(1), "in single id");
428+
Assert.That(results2.Select(r => r.Id), Is.EquivalentTo(new[] {id1}), "in single id");
429+
Assert.That(results3.Count, Is.EqualTo(1), "not in multiple ids");
430+
Assert.That(results3.Select(r => r.Id), Is.EquivalentTo(new[] {id3}), "not in multiple ids");
431+
Assert.That(results4.Count, Is.EqualTo(2), "not in single id");
432+
Assert.That(results4.Select(r => r.Id), Is.EquivalentTo(new[] {id2, id3}), "not in single id");
433+
});
434+
}
435+
}
399436
}
400437
}

src/NHibernate.Test/CompositeId/ClassWithCompositeIdFixture.cs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@ protected override string[] Mappings
2727
get { return new string[] {"CompositeId.ClassWithCompositeId.hbm.xml"}; }
2828
}
2929

30-
protected override bool AppliesTo(Dialect.Dialect dialect)
31-
{
32-
return !(dialect is Dialect.FirebirdDialect); // Firebird has no CommandTimeout, and locks up during the tear-down of this fixture
33-
}
34-
3530
protected override void OnSetUp()
3631
{
3732
id = new Id("stringKey", 3, firstDateTime);
@@ -41,9 +36,11 @@ protected override void OnSetUp()
4136
protected override void OnTearDown()
4237
{
4338
using (ISession s = Sfi.OpenSession())
39+
using (var t = s.BeginTransaction())
4440
{
4541
s.Delete("from ClassWithCompositeId");
4642
s.Flush();
43+
t.Commit();
4744
}
4845
}
4946

@@ -385,5 +382,45 @@ public void QueryOverOrderByAndWhereWithIdProjectionDoesntThrow()
385382
Assert.That(results.Count, Is.EqualTo(1));
386383
}
387384
}
385+
386+
[Test]
387+
public void QueryOverInClause()
388+
{
389+
// insert the new objects
390+
var id1 = id;
391+
var id2 = secondId;
392+
var id3 = new Id(id1.KeyString, id1.GetKeyShort(), id2.KeyDateTime);
393+
394+
using (ISession s = OpenSession())
395+
using (ITransaction t = s.BeginTransaction())
396+
{
397+
s.Save(new ClassWithCompositeId(id1) {OneProperty = 5});
398+
s.Save(new ClassWithCompositeId(id2) {OneProperty = 10});
399+
s.Save(new ClassWithCompositeId(id3));
400+
401+
t.Commit();
402+
}
403+
404+
using (var s = OpenSession())
405+
{
406+
var results1 = s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).IsIn(new[] {id1, id2}).List();
407+
var results2 = s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).IsIn(new[] {id1}).List();
408+
var results3 = s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).Not.IsIn(new[] {id1, id2}).List();
409+
var results4 = s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).Not.IsIn(new[] {id1}).List();
410+
411+
Assert.Multiple(
412+
() =>
413+
{
414+
Assert.That(results1.Count, Is.EqualTo(2), "in multiple ids");
415+
Assert.That(results1.Select(r => r.Id), Is.EquivalentTo(new[] {id1, id2}), "in multiple ids");
416+
Assert.That(results2.Count, Is.EqualTo(1), "in single id");
417+
Assert.That(results2.Select(r => r.Id), Is.EquivalentTo(new[] {id1}), "in single id");
418+
Assert.That(results3.Count, Is.EqualTo(1), "not in multiple ids");
419+
Assert.That(results3.Select(r => r.Id), Is.EquivalentTo(new[] {id3}), "not in multiple ids");
420+
Assert.That(results4.Count, Is.EqualTo(2), "not in single id");
421+
Assert.That(results4.Select(r => r.Id), Is.EquivalentTo(new[] {id2, id3}), "not in single id");
422+
});
423+
}
424+
}
388425
}
389426
}

src/NHibernate/Criterion/InExpression.cs

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using System;
2-
using System.Collections;
32
using System.Collections.Generic;
43
using System.Linq;
54
using NHibernate.Engine;
@@ -13,9 +12,6 @@ namespace NHibernate.Criterion
1312
/// An <see cref="ICriterion"/> that constrains the property
1413
/// to a specified list of values.
1514
/// </summary>
16-
/// <remarks>
17-
/// InExpression - should only be used with a Single Value column - no multicolumn properties...
18-
/// </remarks>
1915
[Serializable]
2016
public class InExpression : AbstractCriterion
2117
{
@@ -62,41 +58,52 @@ public override SqlString ToSqlString(ICriteria criteria, ICriteriaQuery criteri
6258
return new SqlString("1=0");
6359
}
6460

65-
//TODO: add default capacity
66-
SqlStringBuilder result = new SqlStringBuilder();
67-
SqlString[] columnNames =
68-
CriterionUtil.GetColumnNames(_propertyName, _projection, criteriaQuery, criteria);
69-
70-
// Generate SqlString of the form:
71-
// columnName1 in (values) and columnName2 in (values) and ...
72-
Parameter[] parameters = GetParameterTypedValues(criteria, criteriaQuery).SelectMany(t => criteriaQuery.NewQueryParameter(t)).ToArray();
61+
SqlString[] columns = CriterionUtil.GetColumnNames(_propertyName, _projection, criteriaQuery, criteria);
7362

74-
for (int columnIndex = 0; columnIndex < columnNames.Length; columnIndex++)
63+
var list = new List<Parameter>(columns.Length * Values.Length);
64+
foreach (var typedValue in GetParameterTypedValues(criteria, criteriaQuery))
7565
{
76-
SqlString columnName = columnNames[columnIndex];
77-
78-
if (columnIndex > 0)
79-
{
80-
result.Add(" and ");
81-
}
66+
//Must be executed after CriterionUtil.GetColumnNames (as it might add _projection parameters to criteria)
67+
list.AddRange(criteriaQuery.NewQueryParameter(typedValue));
68+
}
8269

83-
result
84-
.Add(columnName)
85-
.Add(" in (");
70+
var bogusParam = Parameter.Placeholder;
8671

87-
for (int i = 0; i < _values.Length; i++)
88-
{
89-
if (i > 0)
90-
{
91-
result.Add(StringHelper.CommaSpace);
92-
}
93-
result.Add(parameters[i]);
94-
}
72+
var sqlString = GetSqlString(criteriaQuery, columns, bogusParam);
73+
sqlString.SubstituteBogusParameters(list, bogusParam);
74+
return sqlString;
75+
}
9576

96-
result.Add(")");
77+
private SqlString GetSqlString(ICriteriaQuery criteriaQuery, SqlString[] columns, Parameter bogusParam)
78+
{
79+
if (columns.Length <= 1 || criteriaQuery.Factory.Dialect.SupportsRowValueConstructorSyntaxInInList)
80+
{
81+
var wrapInParens = columns.Length > 1;
82+
const string comaSeparator = ", ";
83+
var singleValueParam = SqlStringHelper.Repeat(new SqlString(bogusParam), columns.Length, comaSeparator, wrapInParens);
84+
85+
var parameters = SqlStringHelper.Repeat(singleValueParam, Values.Length, comaSeparator, wrapInParens: false);
86+
87+
//single column: col1 in (?, ?)
88+
//multi column: (col1, col2) in ((?, ?), (?, ?))
89+
return new SqlString(
90+
wrapInParens ? StringHelper.OpenParen : string.Empty,
91+
SqlStringHelper.Join(comaSeparator, columns),
92+
wrapInParens ? StringHelper.ClosedParen : string.Empty,
93+
" in (",
94+
parameters,
95+
")");
9796
}
9897

99-
return result.ToSqlString();
98+
//((col1 = ? and col2 = ?) or (col1 = ? and col2 = ?))
99+
var cols = new SqlString(
100+
" ( ",
101+
SqlStringHelper.Join(new SqlString(" = ", bogusParam, " and "), columns),
102+
"= ",
103+
bogusParam,
104+
" ) ");
105+
cols = SqlStringHelper.Repeat(cols, Values.Length, " or ", wrapInParens: Values.Length > 1);
106+
return cols;
100107
}
101108

102109
private void AssertPropertyIsNotCollection(ICriteriaQuery criteriaQuery, ICriteria criteria)
@@ -122,29 +129,24 @@ private List<TypedValue> GetParameterTypedValues(ICriteria criteria, ICriteriaQu
122129
{
123130
IType type = GetElementType(criteria, criteriaQuery);
124131

125-
if (type.IsComponentType)
132+
if (!type.IsComponentType)
126133
{
127-
List<TypedValue> list = new List<TypedValue>();
128-
IAbstractComponentType actype = (IAbstractComponentType) type;
129-
IType[] types = actype.Subtypes;
134+
return _values.ToList(v => new TypedValue(type, v, false));
135+
}
130136

131-
for (int i = 0; i < types.Length; i++)
137+
List<TypedValue> list = new List<TypedValue>();
138+
IAbstractComponentType actype = (IAbstractComponentType) type;
139+
var types = actype.Subtypes;
140+
foreach (var value in _values)
141+
{
142+
var propertyValues = value != null ? actype.GetPropertyValues(value) : null;
143+
for (int ti = 0; ti < types.Length; ti++)
132144
{
133-
for (int j = 0; j < _values.Length; j++)
134-
{
135-
object subval = _values[j] == null
136-
? null
137-
: actype.GetPropertyValues(_values[j])[i];
138-
list.Add(new TypedValue(types[i], subval, false));
139-
}
145+
list.Add(new TypedValue(types[ti], propertyValues?[ti], false));
140146
}
141-
142-
return list;
143-
}
144-
else
145-
{
146-
return _values.ToList(v => new TypedValue(type, v, false));
147147
}
148+
149+
return list;
148150
}
149151

150152
/// <summary>

src/NHibernate/SqlCommand/SqlString.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,22 @@ public SqlString GetSubselectString()
10241024
return new SubselectClauseExtractor(this).GetSqlString();
10251025
}
10261026

1027+
internal void SubstituteBogusParameters(IReadOnlyList<Parameter> actualParams, Parameter bogusParam)
1028+
{
1029+
int index = 0;
1030+
var keys = _parameters.Keys;
1031+
// The loop below is technically not altering the keys collection on which we iterate, but
1032+
// the underlying implementation still throws on foreach iterations over keys even if we
1033+
// have only changed the associated value.
1034+
// ReSharper disable once ForCanBeConvertedToForeach
1035+
for (var i = 0; i < keys.Count; i++)
1036+
{
1037+
var key = keys[i];
1038+
if (ReferenceEquals(_parameters[key], bogusParam))
1039+
_parameters[key] = actualParams[index++];
1040+
}
1041+
}
1042+
10271043
[Serializable]
10281044
private struct Part : IEquatable<Part>
10291045
{

src/NHibernate/SqlCommand/SqlStringHelper.cs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,25 @@ public static SqlString Join(SqlString separator, IEnumerable objects)
3030
return buf.ToSqlString();
3131
}
3232

33+
internal static SqlString Join(string separator, IList<SqlString> strings)
34+
{
35+
if (strings.Count == 0)
36+
return SqlString.Empty;
37+
38+
if (strings.Count == 1)
39+
return strings[0];
40+
41+
var buf = new SqlStringBuilder();
42+
43+
buf.Add(strings[0]);
44+
for (var index = 1; index < strings.Count; index++)
45+
{
46+
buf.Add(separator).Add(strings[index]);
47+
}
48+
49+
return buf.ToSqlString();
50+
}
51+
3352
public static SqlString[] Add(SqlString[] x, string sep, SqlString[] y)
3453
{
3554
SqlString[] result = new SqlString[x.Length];
@@ -85,5 +104,37 @@ internal static SqlString ParametersList(List<Parameter> parameters)
85104

86105
return builder.ToSqlString();
87106
}
107+
108+
internal static SqlString Repeat(SqlString placeholder, int count, string separator, bool wrapInParens)
109+
{
110+
if (count == 0)
111+
return SqlString.Empty;
112+
113+
if (count == 1)
114+
return wrapInParens
115+
? new SqlString("(", placeholder, ")")
116+
: placeholder;
117+
118+
var builder = new SqlStringBuilder((placeholder.Count + 1) * count + 1);
119+
120+
if (wrapInParens)
121+
{
122+
builder.Add("(");
123+
}
124+
125+
builder.Add(placeholder);
126+
127+
for (int i = 1; i < count; i++)
128+
{
129+
builder.Add(separator).Add(placeholder);
130+
}
131+
132+
if (wrapInParens)
133+
{
134+
builder.Add(")");
135+
}
136+
137+
return builder.ToSqlString();
138+
}
88139
}
89140
}

0 commit comments

Comments
 (0)