diff --git a/lib/arel/visitors/sqlserver.rb b/lib/arel/visitors/sqlserver.rb index 3dfe35bbe..b4cef188d 100644 --- a/lib/arel/visitors/sqlserver.rb +++ b/lib/arel/visitors/sqlserver.rb @@ -30,10 +30,46 @@ def visit_Arel_Nodes_Concat(o, collector) end def visit_Arel_Nodes_UpdateStatement(o, collector) - if o.orders.any? && o.limit.nil? - o.limit = Nodes::Limit.new(9_223_372_036_854_775_807) + if has_join_and_composite_primary_key?(o) + update_statement_using_join(o, collector) + else + o.limit = Nodes::Limit.new(9_223_372_036_854_775_807) if o.orders.any? && o.limit.nil? + + super end - super + end + + def visit_Arel_Nodes_DeleteStatement(o, collector) + if has_join_and_composite_primary_key?(o) + delete_statement_using_join(o, collector) + else + super + end + end + + def has_join_and_composite_primary_key?(o) + has_join_sources?(o) && o.relation.left.instance_variable_get(:@klass).composite_primary_key? + end + + def delete_statement_using_join(o, collector) + collector.retryable = false + + collector << "DELETE " + visit o.relation.left, collector + collector << " FROM " + visit o.relation, collector + collect_nodes_for o.wheres, collector, " WHERE ", " AND " + end + + def update_statement_using_join(o, collector) + collector.retryable = false + + collector << "UPDATE " + visit o.relation.left, collector + collect_nodes_for o.values, collector, " SET " + collector << " FROM " + visit o.relation, collector + collect_nodes_for o.wheres, collector, " WHERE ", " AND " end def visit_Arel_Nodes_Lock(o, collector)