diff --git a/lib/arel/visitors/sqlserver.rb b/lib/arel/visitors/sqlserver.rb index 5890b7fc3..a9e8609ba 100644 --- a/lib/arel/visitors/sqlserver.rb +++ b/lib/arel/visitors/sqlserver.rb @@ -29,9 +29,56 @@ def visit_Arel_Nodes_Concat(o, collector) visit o.right, collector end + # Same as SQLite and PostgreSQL. def visit_Arel_Nodes_UpdateStatement(o, collector) - if has_join_and_composite_primary_key?(o) - update_statement_using_join(o, collector) + collector.retryable = false + o = prepare_update_statement(o) + + collector << "UPDATE " + + # UPDATE with JOIN is in the form of: + # + # UPDATE t1 + # SET .. + # FROM t2 + # WHERE t1.join_id = t2.join_id + # + # Or if more than one join is present: + # + # UPDATE t1 + # SET .. + # FROM t2 + # JOIN t3 ON t2.join_id = t3.join_id + # WHERE t1.join_id = t2.join_id + if has_join_sources?(o) + visit o.relation.left, collector + collect_nodes_for o.values, collector, " SET " + collector << " FROM " + first_join, *remaining_joins = o.relation.right + visit first_join.left, collector + + if remaining_joins && !remaining_joins.empty? + collector << " " + remaining_joins.each do |join| + visit join, collector + end + end + + collect_nodes_for [first_join.right.expr] + o.wheres, collector, " WHERE ", " AND " + else + collector = visit o.relation, collector + collect_nodes_for o.values, collector, " SET " + collect_nodes_for o.wheres, collector, " WHERE ", " AND " + end + + collect_nodes_for o.orders, collector, " ORDER BY " + maybe_visit o.limit, collector + end + + # Same as PostgreSQL except we need to add limit if using subquery. + def prepare_update_statement(o) + if has_join_sources?(o) && !has_limit_or_offset_or_orders?(o) && !has_group_by_and_having?(o) + o else o.limit = Nodes::Limit.new(9_223_372_036_854_775_807) if o.orders.any? && o.limit.nil? @@ -39,6 +86,7 @@ def visit_Arel_Nodes_UpdateStatement(o, collector) end end + def visit_Arel_Nodes_DeleteStatement(o, collector) if has_join_and_composite_primary_key?(o) delete_statement_using_join(o, collector) @@ -61,17 +109,6 @@ def delete_statement_using_join(o, collector) collect_nodes_for o.wheres, collector, " WHERE ", " AND " end - def update_statement_using_join(o, collector) - collector.retryable = false - - collector << "UPDATE " - visit o.relation.left, collector - collect_nodes_for o.values, collector, " SET " - collector << " FROM " - visit o.relation, collector - collect_nodes_for o.wheres, collector, " WHERE ", " AND " - end - def visit_Arel_Nodes_Lock(o, collector) o.expr = Arel.sql("WITH(UPDLOCK)") if o.expr.to_s =~ /FOR UPDATE/ collector << " " diff --git a/test/cases/coerced_tests.rb b/test/cases/coerced_tests.rb index 58d760341..4da87a0cf 100644 --- a/test/cases/coerced_tests.rb +++ b/test/cases/coerced_tests.rb @@ -1303,18 +1303,25 @@ def test_update_coerced require "models/author" class UpdateAllTest < ActiveRecord::TestCase - # Rails test required updating a identity column. + # Regular expression slightly different. coerce_tests! :test_update_all_doesnt_ignore_order def test_update_all_doesnt_ignore_order_coerced - david, mary = authors(:david), authors(:mary) - _(david.id).must_equal 1 - _(mary.id).must_equal 2 - _(david.name).wont_equal mary.name - assert_queries_match(/UPDATE.*\(SELECT \[authors\].\[id\] FROM \[authors\].*ORDER BY \[authors\].\[id\]/i) do - Author.where("[id] > 1").order(:id).update_all(name: "Test") + assert_equal authors(:david).id + 1, authors(:mary).id # make sure there is going to be a duplicate PK error + test_update_with_order_succeeds = lambda do |order| + Author.order(order).update_all("id = id + 1") + rescue ActiveRecord::ActiveRecordError + false + end + + if test_update_with_order_succeeds.call("id DESC") + # test that this wasn't a fluke and using an incorrect order results in an exception + assert_not test_update_with_order_succeeds.call("id ASC") + else + # test that we're failing because the current Arel's engine doesn't support UPDATE ORDER BY queries is using subselects instead + assert_queries_match(/\AUPDATE .+ \(SELECT .* ORDER BY id DESC.*\)/i) do + test_update_with_order_succeeds.call("id DESC") + end end - _(david.reload.name).must_equal "David" - _(mary.reload.name).must_equal "Test" end # SELECT columns must be in the GROUP clause. @@ -1971,7 +1978,7 @@ def with_marshable_time_defaults # Revert changes @connection.change_column_default(:sst_datatypes, :datetime, current_default) if current_default.present? end - + # We need to give the full paths for this to work. undef_method :schema_dump_5_1_path def schema_dump_5_1_path