Skip to content

NH-3675 - Support Bulk Inserts #367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/NHibernate.Test/DriverTest/BulkInsertTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using System.Linq;
using NHibernate.Cfg;
using NHibernate.DomainModel.Northwind.Entities;
using NHibernate.Linq;
using NHibernate.Test.Linq;
using NUnit.Framework;

namespace NHibernate.Test.DriverTest
{
[TestFixture]
public class BulkInsertTests : LinqTestCase
{
protected override void Configure(Configuration configuration)
{
configuration.SetProperty(Environment.Hbm2ddlAuto, SchemaAutoAction.Create.ToString());

base.Configure(configuration);
}

[Test]
public void CanBulkInsertEntitiesWithComponents()
{
//NH-3675
using (var statelessSession = session.SessionFactory.OpenStatelessSession())
using (statelessSession.BeginTransaction())
{
var customers = new Customer[] { new Customer { Address = new Address("street", "city", "region", "postalCode", "country", "phoneNumber", "fax"), CompanyName = "Company", ContactName = "Contact", ContactTitle = "Title", CustomerId = "12345" } };

statelessSession.CreateQuery("delete from Customer").ExecuteUpdate();

statelessSession.BulkInsert(customers);

var count = statelessSession.Query<Customer>().Count();

Assert.AreEqual(customers.Count(), count);
}
}

[Test]
public void CanBulkInsertEntitiesWithComponentsAndAssociations()
{
//NH-3675
using (var statelessSession = session.SessionFactory.OpenStatelessSession())
using (statelessSession.BeginTransaction())
{
var superior = new Employee { Address = new Address("street", "city", "region", "zip", "country", "phone", "fax"), BirthDate = System.DateTime.Now, EmployeeId = 1, Extension = "1", FirstName = "Superior", LastName = "Last" };
var employee = new Employee { Address = new Address("street", "city", "region", "zip", "country", "phone", "fax"), BirthDate = System.DateTime.Now, EmployeeId = 2, Extension = "2", FirstName = "Employee", LastName = "Last", Superior = superior };
var employees = new Employee[] { superior, employee };

statelessSession.CreateQuery("delete from Employee").ExecuteUpdate();

statelessSession.BulkInsert(employees);

var count = statelessSession.Query<Employee>().Count();

Assert.AreEqual(employees.Count(), count);
}
}
}
}
4 changes: 4 additions & 0 deletions src/NHibernate/Cfg/Environment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ public static string Version
}
}

public const String BulkProviderClass = "adonet.bulk_provider_class";
public const String BulkProviderTimeout = "adonet.bulk_provider_timeout";
public const String BulkProviderBatchSize = "adonet.bulk_provider_batch_size";

public const string ConnectionProvider = "connection.provider";
public const string ConnectionDriver = "connection.driver_class";
public const string ConnectionString = "connection.connection_string";
Expand Down
63 changes: 63 additions & 0 deletions src/NHibernate/Driver/BulkProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using System;
using System.Collections.Generic;
using NHibernate.Engine;
using NHibernate.Id;
using NHibernate.Persister.Entity;
using Environment = NHibernate.Cfg.Environment;

namespace NHibernate.Driver
{
public abstract class BulkProvider : IDisposable
{
protected BulkProvider()
{
}

~BulkProvider()
{
this.Dispose(false);
}

public Int32 BatchSize { get; set; }

public Int32 Timeout { get; set; }

public abstract void Insert<T>(ISessionImplementor session, IEnumerable<T> entities) where T : class;

public virtual void Initialize(IDictionary<String, String> properties)
{
var timeout = string.Empty;
var batchSize = string.Empty;

if (properties.TryGetValue(Environment.BulkProviderTimeout, out timeout))
{
this.Timeout = Convert.ToInt32(timeout);
}

if (properties.TryGetValue(Environment.BulkProviderBatchSize, out batchSize))
{
this.BatchSize = Convert.ToInt32(batchSize);
}
}

protected virtual void FillIdentifier(ISessionImplementor session, IEntityPersister persister, Object entity)
{
if (!(persister.IdentifierGenerator is Assigned) && !(persister.IdentifierGenerator is ForeignGenerator))
{
var id = persister.IdentifierGenerator.Generate(session, entity);

persister.SetIdentifier(entity, id, session.EntityMode);
}
}

protected virtual void Dispose(Boolean disposing)
{
}

public void Dispose()
{
this.Dispose(true);
GC.SuppressFinalize(this);
}
}
}
24 changes: 24 additions & 0 deletions src/NHibernate/Driver/DefaultBulkProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using NHibernate.Engine;

namespace NHibernate.Driver
{
sealed class DefaultBulkProvider : BulkProvider
{
public override void Insert<T>(ISessionImplementor session, IEnumerable<T> entities)
{
var statelessSession = session as IStatelessSession;

if (statelessSession == null)
{
throw new InvalidOperationException("Insert can only be called with stateless sessions.");
}

foreach (var entity in entities)
{
statelessSession.Insert(entity);
}
}
}
}
5 changes: 5 additions & 0 deletions src/NHibernate/Driver/DriverBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ public abstract class DriverBase : IDriver, ISqlParameterFormatter
private int commandTimeout;
private bool prepareSql;

public virtual BulkProvider GetBulkProvider()
{
return new DefaultBulkProvider();
}

public virtual void Configure(IDictionary<string, string> settings)
{
// Command timeout
Expand Down
5 changes: 5 additions & 0 deletions src/NHibernate/Driver/IDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ namespace NHibernate.Driver
/// </remarks>
public interface IDriver
{
/// <summary>
/// Returns a bulk provider for the current driver.
/// </summary>
BulkProvider GetBulkProvider();

/// <summary>
/// Configure the driver using <paramref name="settings"/>.
/// </summary>
Expand Down
55 changes: 55 additions & 0 deletions src/NHibernate/Driver/OracleDataClientBulkProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Reflection;
using NHibernate.Engine;

namespace NHibernate.Driver
{
public class OracleDataClientBulkProvider : TableBasedBulkProvider
{
public const String BulkProviderOptions = "adonet.bulk_provider_options";

private static readonly System.Type bulkCopyOptionsType = System.Type.GetType("Oracle.DataAccess.Client.OracleBulkCopyOptions, Oracle.DataAccess");
private static readonly System.Type bulkCopyType = System.Type.GetType("Oracle.DataAccess.Client.OracleBulkCopy, Oracle.DataAccess");
private static readonly PropertyInfo batchSizeProperty = bulkCopyType.GetProperty("BatchSize");
private static readonly PropertyInfo bulkCopyTimeoutProperty = bulkCopyType.GetProperty("BulkCopyTimeout");
private static readonly PropertyInfo destinationTableNameProperty = bulkCopyType.GetProperty("DestinationTableName");
private static readonly MethodInfo writeToServerMethod = bulkCopyType.GetMethod("WriteToServer", new System.Type[] { typeof(DataTable) });

public Int32 Options { get; set; }

public Int32 NotifyAfter { get; set; }

public override void Initialize(IDictionary<String, String> properties)
{
base.Initialize(properties);

var bulkProviderOptions = String.Empty;

if (properties.TryGetValue(BulkProviderOptions, out bulkProviderOptions))
{
this.Options = Convert.ToInt32(bulkProviderOptions);
}
}

public override void Insert<T>(ISessionImplementor session, IEnumerable<T> entities)
{
if (entities.Any() == true)
{
foreach (var table in this.GetTables(session, entities))
{
using (var copy = Activator.CreateInstance(bulkCopyType, session.Connection, Enum.ToObject(bulkCopyOptionsType, this.Options)) as IDisposable)
{
batchSizeProperty.SetValue(copy, this.BatchSize, null);
bulkCopyTimeoutProperty.SetValue(copy, this.Timeout, null);
destinationTableNameProperty.SetValue(copy, table.TableName, null);

writeToServerMethod.Invoke(copy, new Object[] { table });
}
}
}
}
}
}
5 changes: 5 additions & 0 deletions src/NHibernate/Driver/OracleDataClientDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ public OracleDataClientDriver()
oracleDbTypeXmlType = Enum.Parse(oracleDbTypeEnum, "XmlType");
}

public override BulkProvider GetBulkProvider()
{
return new OracleDataClientBulkProvider();
}

/// <summary></summary>
public override bool UseNamedPrefixInSql
{
Expand Down
46 changes: 46 additions & 0 deletions src/NHibernate/Driver/SqlBulkProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using System;
using System.Collections.Generic;
using System.Data.SqlClient;
using System.Linq;
using NHibernate.Engine;
using NHibernate.Transaction;

namespace NHibernate.Driver
{
public class SqlBulkProvider : TableBasedBulkProvider
{
public const String BulkProviderOptions = "adonet.bulk_provider_options";

public SqlBulkCopyOptions Options { get; set; }


public override void Initialize(IDictionary<String, String> properties)
{
base.Initialize(properties);

var bulkProviderOptions = String.Empty;

if (properties.TryGetValue(BulkProviderOptions, out bulkProviderOptions))
{
this.Options = (SqlBulkCopyOptions)Enum.Parse(typeof(SqlBulkCopyOptions), bulkProviderOptions, true);
}
}

public override void Insert<T>(ISessionImplementor session, IEnumerable<T> entities)
{
if (entities.Any() == true)
{
var con = session.Connection as SqlConnection;
var tx = (session.ConnectionManager.Transaction as AdoTransaction).GetNativeTransaction() as SqlTransaction;

foreach (var table in this.GetTables(session, entities))
{
using (var copy = new SqlBulkCopy(con, this.Options, tx) { BatchSize = this.BatchSize, BulkCopyTimeout = this.Timeout, DestinationTableName = table.TableName })
{
copy.WriteToServer(table);
}
}
}
}
}
}
5 changes: 5 additions & 0 deletions src/NHibernate/Driver/SqlClientDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ public class SqlClientDriver : DriverBase, IEmbeddedBatcherFactoryProvider
public const byte MaxDateTime2 = 8;
public const byte MaxDateTimeOffset = 10;

public override BulkProvider GetBulkProvider()
{
return new SqlBulkProvider();
}

/// <summary>
/// Creates an uninitialized <see cref="DbConnection" /> object for
/// the SqlClientDriver.
Expand Down
Loading