Skip to content

Commit 50fa0d8

Browse files
authored
Optimize PersistentGenericSet snapshot (#2394)
1 parent 14fef3f commit 50fa0d8

File tree

6 files changed

+218
-57
lines changed

6 files changed

+218
-57
lines changed

src/NHibernate.Test/NHibernate.Test.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
<Compile Include="..\NHibernate\Util\AsyncReaderWriterLock.cs">
5050
<Link>UtilityTest\AsyncReaderWriterLock.cs</Link>
5151
</Compile>
52+
<Compile Include="..\NHibernate\Collection\Generic\SetHelpers\SetSnapShot.cs">
53+
<Link>UtilityTest\SetSnapShot.cs</Link>
54+
</Compile>
5255
</ItemGroup>
5356
<ItemGroup>
5457
<PackageReference Include="log4net" Version="2.0.8" />
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
using System.Collections.Generic;
2+
using System.IO;
3+
using System.Runtime.Serialization.Formatters.Binary;
4+
using NHibernate.Collection.Generic.SetHelpers;
5+
using NUnit.Framework;
6+
7+
namespace NHibernate.Test.UtilityTest
8+
{
9+
[TestFixture]
10+
public class SetSnapShotFixture
11+
{
12+
[Test]
13+
public void TestNullValue()
14+
{
15+
var sn = new SetSnapShot<object>(1);
16+
Assert.That(sn, Has.Count.EqualTo(0));
17+
Assert.That(sn, Is.EquivalentTo(new object[0]));
18+
Assert.That(sn.Contains(null), Is.False);
19+
Assert.That(sn.TryGetValue(null, out _), Is.False);
20+
21+
sn.Add(null);
22+
Assert.That(sn, Has.Count.EqualTo(1));
23+
Assert.That(sn, Is.EquivalentTo(new object[] {null}));
24+
25+
Assert.That(sn.TryGetValue(null, out var value), Is.True);
26+
Assert.That(sn.Contains(null), Is.True);
27+
Assert.That(value, Is.Null);
28+
29+
Assert.That(sn.Remove(null), Is.True);
30+
Assert.That(sn, Has.Count.EqualTo(0));
31+
Assert.That(sn, Is.EquivalentTo(new object[0]));
32+
33+
sn.Add(null);
34+
Assert.That(sn, Has.Count.EqualTo(1));
35+
36+
sn.Clear();
37+
Assert.That(sn, Has.Count.EqualTo(0));
38+
Assert.That(sn, Is.EquivalentTo(new object[0]));
39+
}
40+
41+
[Test]
42+
public void TestInitialization()
43+
{
44+
var list = new List<string> {"test1", null, "test2"};
45+
var sn = new SetSnapShot<string>(list);
46+
Assert.That(sn, Has.Count.EqualTo(list.Count));
47+
Assert.That(sn, Is.EquivalentTo(list));
48+
Assert.That(sn.TryGetValue("test1", out _), Is.True);
49+
Assert.That(sn.TryGetValue(null, out _), Is.True);
50+
}
51+
52+
[Test]
53+
public void TestCopyTo()
54+
{
55+
var list = new List<string> {"test1", null, "test2"};
56+
var sn = new SetSnapShot<string>(list);
57+
58+
var array = new string[3];
59+
sn.CopyTo(array, 0);
60+
Assert.That(list, Is.EquivalentTo(array));
61+
}
62+
63+
[Test]
64+
public void TestSerialization()
65+
{
66+
var list = new List<string> {"test1", null, "test2"};
67+
var sn = new SetSnapShot<string>(list);
68+
69+
sn = Deserialize<SetSnapShot<string>>(Serialize(sn));
70+
Assert.That(sn, Has.Count.EqualTo(list.Count));
71+
Assert.That(sn, Is.EquivalentTo(list));
72+
Assert.That(sn.TryGetValue("test1", out var item1), Is.True);
73+
Assert.That(item1, Is.EqualTo("test1"));
74+
Assert.That(sn.TryGetValue(null, out var nullValue), Is.True);
75+
Assert.That(nullValue, Is.Null);
76+
}
77+
78+
private static byte[] Serialize<T>(T obj)
79+
{
80+
var serializer = new BinaryFormatter();
81+
using (var stream = new MemoryStream())
82+
{
83+
serializer.Serialize(stream, obj);
84+
return stream.ToArray();
85+
}
86+
}
87+
88+
private static T Deserialize<T>(byte[] value)
89+
{
90+
var serializer = new BinaryFormatter();
91+
using (var stream = new MemoryStream(value))
92+
{
93+
return (T) serializer.Deserialize(stream);
94+
}
95+
}
96+
}
97+
}

src/NHibernate/Async/Collection/Generic/PersistentGenericSet.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public override async Task<bool> EqualsSnapshotAsync(ICollectionPersister persis
5454
{
5555
cancellationToken.ThrowIfCancellationRequested();
5656
var elementType = persister.ElementType;
57-
var snapshot = (ISetSnapshot<T>)GetSnapshot();
57+
var snapshot = (SetSnapShot<T>)GetSnapshot();
5858
if (((ICollection)snapshot).Count != WrappedSet.Count)
5959
{
6060
return false;
@@ -122,7 +122,7 @@ public override async Task<IEnumerable> GetDeletesAsync(ICollectionPersister per
122122
{
123123
cancellationToken.ThrowIfCancellationRequested();
124124
IType elementType = persister.ElementType;
125-
var sn = (ISetSnapshot<T>)GetSnapshot();
125+
var sn = (SetSnapShot<T>)GetSnapshot();
126126
var deletes = new List<T>(((ICollection<T>)sn).Count);
127127

128128
deletes.AddRange(sn.Where(obj => !WrappedSet.Contains(obj)));
@@ -140,7 +140,7 @@ public override async Task<IEnumerable> GetDeletesAsync(ICollectionPersister per
140140
public override async Task<bool> NeedsInsertingAsync(object entry, int i, IType elemType, CancellationToken cancellationToken)
141141
{
142142
cancellationToken.ThrowIfCancellationRequested();
143-
var sn = (ISetSnapshot<T>)GetSnapshot();
143+
var sn = (SetSnapShot<T>)GetSnapshot();
144144
T oldKey;
145145

146146
// note that it might be better to iterate the snapshot but this is safe,

src/NHibernate/Collection/Generic/PersistentGenericSet.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ public override ICollection GetOrphans(object snapshot, string entityName)
103103
public override bool EqualsSnapshot(ICollectionPersister persister)
104104
{
105105
var elementType = persister.ElementType;
106-
var snapshot = (ISetSnapshot<T>)GetSnapshot();
106+
var snapshot = (SetSnapShot<T>)GetSnapshot();
107107
if (((ICollection)snapshot).Count != WrappedSet.Count)
108108
{
109109
return false;
@@ -217,7 +217,7 @@ public override object Disassemble(ICollectionPersister persister)
217217
public override IEnumerable GetDeletes(ICollectionPersister persister, bool indexIsFormula)
218218
{
219219
IType elementType = persister.ElementType;
220-
var sn = (ISetSnapshot<T>)GetSnapshot();
220+
var sn = (SetSnapShot<T>)GetSnapshot();
221221
var deletes = new List<T>(((ICollection<T>)sn).Count);
222222

223223
deletes.AddRange(sn.Where(obj => !WrappedSet.Contains(obj)));
@@ -234,7 +234,7 @@ public override IEnumerable GetDeletes(ICollectionPersister persister, bool inde
234234

235235
public override bool NeedsInserting(object entry, int i, IType elemType)
236236
{
237-
var sn = (ISetSnapshot<T>)GetSnapshot();
237+
var sn = (SetSnapShot<T>)GetSnapshot();
238238
T oldKey;
239239

240240
// note that it might be better to iterate the snapshot but this is safe,

src/NHibernate/Collection/Generic/SetHelpers/ISetSnapshot.cs

Lines changed: 0 additions & 10 deletions
This file was deleted.
Lines changed: 112 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,109 +1,180 @@
11
using System;
22
using System.Collections;
33
using System.Collections.Generic;
4+
#if NETCOREAPP2_0
5+
using System.Runtime.Serialization;
6+
using System.Threading;
7+
#endif
48

59
namespace NHibernate.Collection.Generic.SetHelpers
610
{
11+
#if NETFX || NETSTANDARD2_0
12+
// TODO 6.0: Consider removing this class in case we upgrade to .NET 4.7.2 and NET Standard 2.1,
13+
// which have HashSet<T>.TryGetValue
714
[Serializable]
8-
internal class SetSnapShot<T> : ISetSnapshot<T>
15+
internal class SetSnapShot<T> : ICollection<T>, IReadOnlyCollection<T>, ICollection
916
{
10-
private readonly List<T> _elements;
11-
public SetSnapShot()
12-
{
13-
_elements = new List<T>();
14-
}
17+
private readonly Dictionary<T, T> _values;
18+
private bool _hasNull;
1519

1620
public SetSnapShot(int capacity)
1721
{
18-
_elements = new List<T>(capacity);
22+
_values = new Dictionary<T, T>(capacity);
1923
}
2024

2125
public SetSnapShot(IEnumerable<T> collection)
2226
{
23-
_elements = new List<T>(collection);
27+
_values = new Dictionary<T, T>();
28+
foreach (var item in collection)
29+
{
30+
if (item == null)
31+
{
32+
_hasNull = true;
33+
}
34+
else
35+
{
36+
_values.Add(item, item);
37+
}
38+
}
2439
}
2540

26-
public IEnumerator<T> GetEnumerator()
41+
public bool TryGetValue(T equalValue, out T actualValue)
2742
{
28-
return _elements.GetEnumerator();
43+
if (equalValue != null)
44+
{
45+
return _values.TryGetValue(equalValue, out actualValue);
46+
}
47+
48+
actualValue = default(T);
49+
return _hasNull;
2950
}
3051

31-
IEnumerator IEnumerable.GetEnumerator()
52+
public IEnumerator<T> GetEnumerator()
3253
{
33-
return GetEnumerator();
54+
if (_hasNull)
55+
{
56+
yield return default(T);
57+
}
58+
59+
foreach (var item in _values.Keys)
60+
{
61+
yield return item;
62+
}
3463
}
3564

3665
public void Add(T item)
3766
{
38-
_elements.Add(item);
67+
if (item == null)
68+
{
69+
_hasNull = true;
70+
return;
71+
}
72+
73+
_values.Add(item, item);
3974
}
4075

4176
public void Clear()
4277
{
43-
throw new InvalidOperationException();
78+
_values.Clear();
79+
_hasNull = false;
4480
}
4581

4682
public bool Contains(T item)
4783
{
48-
return _elements.Contains(item);
84+
return item == null ? _hasNull : _values.ContainsKey(item);
4985
}
5086

5187
public void CopyTo(T[] array, int arrayIndex)
5288
{
53-
_elements.CopyTo(array, arrayIndex);
89+
if (_hasNull)
90+
array[arrayIndex] = default(T);
91+
_values.Keys.CopyTo(array, arrayIndex + (_hasNull ? 1 : 0));
5492
}
5593

5694
public bool Remove(T item)
5795
{
58-
throw new InvalidOperationException();
59-
}
96+
if (item != null)
97+
{
98+
return _values.Remove(item);
99+
}
60100

61-
public void CopyTo(Array array, int index)
62-
{
63-
((ICollection)_elements).CopyTo(array, index);
101+
if (!_hasNull)
102+
{
103+
return false;
104+
}
105+
106+
_hasNull = false;
107+
return true;
64108
}
65109

66-
int ICollection.Count
110+
IEnumerator IEnumerable.GetEnumerator()
67111
{
68-
get { return _elements.Count; }
112+
return GetEnumerator();
69113
}
70114

71-
public object SyncRoot
115+
void ICollection.CopyTo(Array array, int index)
72116
{
73-
get { return ((ICollection)_elements).SyncRoot; }
117+
if (!(array is T[] typedArray))
118+
{
119+
throw new ArgumentException($"Array must be of type {typeof(T[])}.", nameof(array));
120+
}
121+
122+
CopyTo(typedArray, index);
74123
}
75124

76-
public bool IsSynchronized
125+
public int Count => _values.Count + (_hasNull ? 1 : 0);
126+
127+
public bool IsReadOnly => ((ICollection<KeyValuePair<T, T>>) _values).IsReadOnly;
128+
129+
public object SyncRoot => ((ICollection) _values).SyncRoot;
130+
131+
public bool IsSynchronized => ((ICollection) _values).IsSynchronized;
132+
}
133+
#endif
134+
135+
#if NETCOREAPP2_0
136+
[Serializable]
137+
internal class SetSnapShot<T> : HashSet<T>, ICollection
138+
{
139+
[NonSerialized]
140+
private object _syncRoot;
141+
142+
public SetSnapShot(int capacity) : base(capacity)
77143
{
78-
get { return ((ICollection)_elements).IsSynchronized; }
79144
}
80145

81-
int ICollection<T>.Count
146+
public SetSnapShot(IEnumerable<T> collection) : base(collection)
82147
{
83-
get { return _elements.Count; }
84148
}
85149

86-
int IReadOnlyCollection<T>.Count
150+
protected SetSnapShot(SerializationInfo info, StreamingContext context) : base(info, context)
87151
{
88-
get { return _elements.Count; }
89152
}
90153

91-
public bool IsReadOnly
154+
void ICollection.CopyTo(Array array, int index)
92155
{
93-
get { return ((ICollection<T>)_elements).IsReadOnly; }
156+
if (!(array is T[] typedArray))
157+
{
158+
throw new ArgumentException($"Array must be of type {typeof(T[])}.", nameof(array));
159+
}
160+
161+
CopyTo(typedArray, index);
94162
}
95163

96-
public bool TryGetValue(T element, out T value)
164+
bool ICollection.IsSynchronized => false;
165+
166+
object ICollection.SyncRoot
97167
{
98-
var idx = _elements.IndexOf(element);
99-
if (idx >= 0)
168+
get
100169
{
101-
value = _elements[idx];
102-
return true;
103-
}
170+
if (_syncRoot == null)
171+
{
172+
Interlocked.CompareExchange<object>(ref _syncRoot, new object(), null);
173+
}
104174

105-
value = default(T);
106-
return false;
175+
return _syncRoot;
176+
}
107177
}
108178
}
179+
#endif
109180
}

0 commit comments

Comments
 (0)