diff --git a/src/NHibernate.Test/UtilityTest/SetSnapShotFixture.cs b/src/NHibernate.Test/UtilityTest/SetSnapShotFixture.cs index e00af722111..47b03b50ec9 100644 --- a/src/NHibernate.Test/UtilityTest/SetSnapShotFixture.cs +++ b/src/NHibernate.Test/UtilityTest/SetSnapShotFixture.cs @@ -1,7 +1,10 @@ -using System.Collections.Generic; +using System; +using System.Collections; +using System.Collections.Generic; using System.IO; using System.Runtime.Serialization.Formatters.Binary; using NHibernate.Collection.Generic.SetHelpers; +using NSubstitute.ExceptionExtensions; using NUnit.Framework; namespace NHibernate.Test.UtilityTest @@ -70,6 +73,29 @@ public void TestCopyTo() Assert.That(list, Is.EquivalentTo(array)); } + [Test] + public void TestCopyToObjectArray() + { + var list = new List { "test1", null, "test2" }; + ICollection sn = new SetSnapShot(list); + + var array = new object[3]; + sn.CopyTo(array, 0); + Assert.That(list, Is.EquivalentTo(array)); + } + + [Test] + public void WhenCopyToIsCalledWithIncompatibleArrayTypeThenThrowArgumentOrInvalidCastException() + { + var list = new List { "test1", null, "test2" }; + ICollection sn = new SetSnapShot(list); + + var array = new int[3]; + Assert.That( + () => sn.CopyTo(array, 0), + Throws.ArgumentException.Or.TypeOf()); + } + [Test] public void TestSerialization() { diff --git a/src/NHibernate/Collection/Generic/SetHelpers/SetSnapShot.cs b/src/NHibernate/Collection/Generic/SetHelpers/SetSnapShot.cs index 0c052a67ba8..edafec312ea 100644 --- a/src/NHibernate/Collection/Generic/SetHelpers/SetSnapShot.cs +++ b/src/NHibernate/Collection/Generic/SetHelpers/SetSnapShot.cs @@ -114,12 +114,16 @@ IEnumerator IEnumerable.GetEnumerator() void ICollection.CopyTo(Array array, int index) { - if (!(array is T[] typedArray)) + if (array is T[] typedArray) { - throw new ArgumentException($"Array must be of type {typeof(T[])}.", nameof(array)); + CopyTo(typedArray, index); + return; } - CopyTo(typedArray, index); + if (_hasNull) + array.SetValue(default(T), index); + ICollection keysCollection = _values.Keys; + keysCollection.CopyTo(array, index + (_hasNull ? 1 : 0)); } public int Count => _values.Count + (_hasNull ? 1 : 0); @@ -153,12 +157,26 @@ protected SetSnapShot(SerializationInfo info, StreamingContext context) : base(i void ICollection.CopyTo(Array array, int index) { - if (!(array is T[] typedArray)) + if (array is T[] typedArray) { - throw new ArgumentException($"Array must be of type {typeof(T[])}.", nameof(array)); + CopyTo(typedArray, index); + return; } - CopyTo(typedArray, index); + if (array == null) + throw new ArgumentNullException(nameof(array)); + + if (index < 0) + throw new ArgumentOutOfRangeException(nameof(index), index, "Array index cannot be negative"); + + if (index > array.Length || Count > array.Length - index) + throw new ArgumentException("Provided array is too small", nameof(array)); + + foreach (var value in this) + { + array.SetValue(value, index); + index++; + } } bool ICollection.IsSynchronized => false;