From 135ac824a4ac934c4102568f91e8e9529fb2e081 Mon Sep 17 00:00:00 2001 From: Aidan Haran Date: Tue, 20 May 2025 11:08:07 +0100 Subject: [PATCH 1/2] Enable identity insert on view's base table for fixtures --- .../sqlserver/database_statements.rb | 102 +++++++++--------- test/cases/view_test_sqlserver.rb | 10 +- test/fixtures/sst_customers_view.yml | 6 ++ 3 files changed, 65 insertions(+), 53 deletions(-) create mode 100644 test/fixtures/sst_customers_view.yml diff --git a/lib/active_record/connection_adapters/sqlserver/database_statements.rb b/lib/active_record/connection_adapters/sqlserver/database_statements.rb index bac9e5441..da60515d7 100644 --- a/lib/active_record/connection_adapters/sqlserver/database_statements.rb +++ b/lib/active_record/connection_adapters/sqlserver/database_statements.rb @@ -17,15 +17,12 @@ def perform_query(raw_connection, sql, binds, type_casted_binds, prepare:, notif id_insert_table_name = query_requires_identity_insert?(sql) result, affected_rows = if id_insert_table_name - # If the table name is a view, we need to get the base table name for enabling identity insert. - id_insert_table_name = view_table_name(id_insert_table_name) if view_exists?(id_insert_table_name) - - with_identity_insert_enabled(id_insert_table_name, raw_connection) do - internal_exec_sql_query(sql, raw_connection) - end - else - internal_exec_sql_query(sql, raw_connection) - end + with_identity_insert_enabled(id_insert_table_name, raw_connection) do + internal_exec_sql_query(sql, raw_connection) + end + else + internal_exec_sql_query(sql, raw_connection) + end verified! notification_payload[:affected_rows] = affected_rows @@ -239,10 +236,10 @@ def merge_insert_values_list(insert:, insert_all:) def execute_procedure(proc_name, *variables) vars = if variables.any? && variables.first.is_a?(Hash) - variables.first.map { |k, v| "@#{k} = #{quote(v)}" } - else - variables.map { |v| quote(v) } - end.join(", ") + variables.first.map { |k, v| "@#{k} = #{quote(v)}" } + else + variables.map { |v| quote(v) } + end.join(", ") sql = "EXEC #{proc_name} #{vars}".strip log(sql, "Execute Procedure") do |notification_payload| @@ -264,11 +261,14 @@ def execute_procedure(proc_name, *variables) end def with_identity_insert_enabled(table_name, conn) - table_name = quote_table_name(table_name) - set_identity_insert(table_name, conn, true) + # If the table name is a view, we need to get the base table name for enabling identity insert. + table_name = view_table_name(table_name) if view_exists?(table_name) + quoted_table_name = quote_table_name(table_name) + + set_identity_insert(quoted_table_name, conn, true) yield ensure - set_identity_insert(table_name, conn, false) + set_identity_insert(quoted_table_name, conn, false) end def use_database(database = nil) @@ -345,35 +345,35 @@ def sql_for_insert(sql, pk, binds, returning) end sql = if pk && use_output_inserted? && !database_prefix_remote_server? - table_name ||= get_table_name(sql) - exclude_output_inserted = exclude_output_inserted_table_name?(table_name, sql) - - if exclude_output_inserted - pk_and_types = Array(pk).map do |subkey| - { - quoted: SQLServer::Utils.extract_identifiers(subkey).quoted, - id_sql_type: exclude_output_inserted_id_sql_type(subkey, exclude_output_inserted) - } - end - - <<~SQL.squish + table_name ||= get_table_name(sql) + exclude_output_inserted = exclude_output_inserted_table_name?(table_name, sql) + + if exclude_output_inserted + pk_and_types = Array(pk).map do |subkey| + { + quoted: SQLServer::Utils.extract_identifiers(subkey).quoted, + id_sql_type: exclude_output_inserted_id_sql_type(subkey, exclude_output_inserted) + } + end + + <<~SQL.squish DECLARE @ssaIdInsertTable table (#{pk_and_types.map { |pk_and_type| "#{pk_and_type[:quoted]} #{pk_and_type[:id_sql_type]}" }.join(", ")}); #{sql.dup.insert sql.index(/ (DEFAULT )?VALUES/i), " OUTPUT #{pk_and_types.map { |pk_and_type| "INSERTED.#{pk_and_type[:quoted]}" }.join(", ")} INTO @ssaIdInsertTable"} SELECT #{pk_and_types.map { |pk_and_type| "CAST(#{pk_and_type[:quoted]} AS #{pk_and_type[:id_sql_type]}) #{pk_and_type[:quoted]}" }.join(", ")} FROM @ssaIdInsertTable SQL - else - returning_columns = returning || Array(pk) - - if returning_columns.any? - returning_columns_statements = returning_columns.map { |c| " INSERTED.#{SQLServer::Utils.extract_identifiers(c).quoted}" } - sql.dup.insert sql.index(/ (DEFAULT )?VALUES/i), " OUTPUT" + returning_columns_statements.join(",") - else - sql - end - end - else - "#{sql}; SELECT CAST(SCOPE_IDENTITY() AS bigint) AS Ident" - end + else + returning_columns = returning || Array(pk) + + if returning_columns.any? + returning_columns_statements = returning_columns.map { |c| " INSERTED.#{SQLServer::Utils.extract_identifiers(c).quoted}" } + sql.dup.insert sql.index(/ (DEFAULT )?VALUES/i), " OUTPUT" + returning_columns_statements.join(",") + else + sql + end + end + else + "#{sql}; SELECT CAST(SCOPE_IDENTITY() AS bigint) AS Ident" + end [sql, binds] end @@ -542,16 +542,16 @@ def build_sql_for_returning(insert:, insert_all:) return "" unless insert_all.returning returning_values_sql = if insert_all.returning.is_a?(String) - insert_all.returning - else - Array(insert_all.returning).map do |attribute| - if insert.model.attribute_alias?(attribute) - "INSERTED.#{quote_column_name(insert.model.attribute_alias(attribute))} AS #{quote_column_name(attribute)}" - else - "INSERTED.#{quote_column_name(attribute)}" - end - end.join(",") - end + insert_all.returning + else + Array(insert_all.returning).map do |attribute| + if insert.model.attribute_alias?(attribute) + "INSERTED.#{quote_column_name(insert.model.attribute_alias(attribute))} AS #{quote_column_name(attribute)}" + else + "INSERTED.#{quote_column_name(attribute)}" + end + end.join(",") + end " OUTPUT #{returning_values_sql}" end diff --git a/test/cases/view_test_sqlserver.rb b/test/cases/view_test_sqlserver.rb index 88d195750..ed5d2303c 100644 --- a/test/cases/view_test_sqlserver.rb +++ b/test/cases/view_test_sqlserver.rb @@ -53,10 +53,16 @@ class ViewTestSQLServer < ActiveRecord::TestCase end describe "identity insert" do - it "identity insert works with views" do - assert_difference("SSTestCustomersView.count", 1) do + it "creates table record through a view" do + assert_difference("SSTestCustomersView.count", 2) do SSTestCustomersView.create!(id: 5, name: "Bob") + SSTestCustomersView.create!(id: 6, name: "Tim") end end + + it "creates table records through a view using fixtures" do + ActiveRecord::FixtureSet.create_fixtures(File.join(ARTest::SQLServer.test_root_sqlserver, "fixtures"), ["sst_customers_view"]) + assert_equal SSTestCustomersView.all.count, 2 + end end end diff --git a/test/fixtures/sst_customers_view.yml b/test/fixtures/sst_customers_view.yml new file mode 100644 index 000000000..668ba3763 --- /dev/null +++ b/test/fixtures/sst_customers_view.yml @@ -0,0 +1,6 @@ +david: + name: "David" + balance: 2,004 +aidan: + name: "Aidan" + balance: 10,191 From b455791c47a58b433db05844f5130628c6481248 Mon Sep 17 00:00:00 2001 From: Aidan Haran Date: Tue, 20 May 2025 11:10:10 +0100 Subject: [PATCH 2/2] Update database_statements.rb --- .../sqlserver/database_statements.rb | 90 +++++++++---------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/lib/active_record/connection_adapters/sqlserver/database_statements.rb b/lib/active_record/connection_adapters/sqlserver/database_statements.rb index da60515d7..4aee1e1a2 100644 --- a/lib/active_record/connection_adapters/sqlserver/database_statements.rb +++ b/lib/active_record/connection_adapters/sqlserver/database_statements.rb @@ -17,12 +17,12 @@ def perform_query(raw_connection, sql, binds, type_casted_binds, prepare:, notif id_insert_table_name = query_requires_identity_insert?(sql) result, affected_rows = if id_insert_table_name - with_identity_insert_enabled(id_insert_table_name, raw_connection) do - internal_exec_sql_query(sql, raw_connection) - end - else - internal_exec_sql_query(sql, raw_connection) - end + with_identity_insert_enabled(id_insert_table_name, raw_connection) do + internal_exec_sql_query(sql, raw_connection) + end + else + internal_exec_sql_query(sql, raw_connection) + end verified! notification_payload[:affected_rows] = affected_rows @@ -236,10 +236,10 @@ def merge_insert_values_list(insert:, insert_all:) def execute_procedure(proc_name, *variables) vars = if variables.any? && variables.first.is_a?(Hash) - variables.first.map { |k, v| "@#{k} = #{quote(v)}" } - else - variables.map { |v| quote(v) } - end.join(", ") + variables.first.map { |k, v| "@#{k} = #{quote(v)}" } + else + variables.map { |v| quote(v) } + end.join(", ") sql = "EXEC #{proc_name} #{vars}".strip log(sql, "Execute Procedure") do |notification_payload| @@ -345,35 +345,35 @@ def sql_for_insert(sql, pk, binds, returning) end sql = if pk && use_output_inserted? && !database_prefix_remote_server? - table_name ||= get_table_name(sql) - exclude_output_inserted = exclude_output_inserted_table_name?(table_name, sql) - - if exclude_output_inserted - pk_and_types = Array(pk).map do |subkey| - { - quoted: SQLServer::Utils.extract_identifiers(subkey).quoted, - id_sql_type: exclude_output_inserted_id_sql_type(subkey, exclude_output_inserted) - } - end - - <<~SQL.squish + table_name ||= get_table_name(sql) + exclude_output_inserted = exclude_output_inserted_table_name?(table_name, sql) + + if exclude_output_inserted + pk_and_types = Array(pk).map do |subkey| + { + quoted: SQLServer::Utils.extract_identifiers(subkey).quoted, + id_sql_type: exclude_output_inserted_id_sql_type(subkey, exclude_output_inserted) + } + end + + <<~SQL.squish DECLARE @ssaIdInsertTable table (#{pk_and_types.map { |pk_and_type| "#{pk_and_type[:quoted]} #{pk_and_type[:id_sql_type]}" }.join(", ")}); #{sql.dup.insert sql.index(/ (DEFAULT )?VALUES/i), " OUTPUT #{pk_and_types.map { |pk_and_type| "INSERTED.#{pk_and_type[:quoted]}" }.join(", ")} INTO @ssaIdInsertTable"} SELECT #{pk_and_types.map { |pk_and_type| "CAST(#{pk_and_type[:quoted]} AS #{pk_and_type[:id_sql_type]}) #{pk_and_type[:quoted]}" }.join(", ")} FROM @ssaIdInsertTable SQL - else - returning_columns = returning || Array(pk) - - if returning_columns.any? - returning_columns_statements = returning_columns.map { |c| " INSERTED.#{SQLServer::Utils.extract_identifiers(c).quoted}" } - sql.dup.insert sql.index(/ (DEFAULT )?VALUES/i), " OUTPUT" + returning_columns_statements.join(",") - else - sql - end - end - else - "#{sql}; SELECT CAST(SCOPE_IDENTITY() AS bigint) AS Ident" - end + else + returning_columns = returning || Array(pk) + + if returning_columns.any? + returning_columns_statements = returning_columns.map { |c| " INSERTED.#{SQLServer::Utils.extract_identifiers(c).quoted}" } + sql.dup.insert sql.index(/ (DEFAULT )?VALUES/i), " OUTPUT" + returning_columns_statements.join(",") + else + sql + end + end + else + "#{sql}; SELECT CAST(SCOPE_IDENTITY() AS bigint) AS Ident" + end [sql, binds] end @@ -542,16 +542,16 @@ def build_sql_for_returning(insert:, insert_all:) return "" unless insert_all.returning returning_values_sql = if insert_all.returning.is_a?(String) - insert_all.returning - else - Array(insert_all.returning).map do |attribute| - if insert.model.attribute_alias?(attribute) - "INSERTED.#{quote_column_name(insert.model.attribute_alias(attribute))} AS #{quote_column_name(attribute)}" - else - "INSERTED.#{quote_column_name(attribute)}" - end - end.join(",") - end + insert_all.returning + else + Array(insert_all.returning).map do |attribute| + if insert.model.attribute_alias?(attribute) + "INSERTED.#{quote_column_name(insert.model.attribute_alias(attribute))} AS #{quote_column_name(attribute)}" + else + "INSERTED.#{quote_column_name(attribute)}" + end + end.join(",") + end " OUTPUT #{returning_values_sql}" end