From 5b6f31380ab5fa6044bf55275dcd5c75cd03dea1 Mon Sep 17 00:00:00 2001 From: Aidan Haran Date: Fri, 21 Feb 2025 20:47:28 +0000 Subject: [PATCH] Update all subquery fixes --- lib/arel/visitors/sqlserver.rb | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/lib/arel/visitors/sqlserver.rb b/lib/arel/visitors/sqlserver.rb index 69fad62fa..e9e9e9357 100644 --- a/lib/arel/visitors/sqlserver.rb +++ b/lib/arel/visitors/sqlserver.rb @@ -57,16 +57,23 @@ def visit_Arel_Nodes_UpdateStatement(o, collector) collect_nodes_for o.values, collector, " SET " collector << " FROM " first_join, *remaining_joins = o.relation.right - visit first_join.left, collector + from_items = remaining_joins.extract! do |join| + join.right.expr.right.relation == o.relation.left + end + + from_where = [first_join.left] + from_items.map(&:left) + collect_nodes_for from_where, collector, " ", ", " if remaining_joins && !remaining_joins.empty? collector << " " remaining_joins.each do |join| visit join, collector + collector << " " end end - collect_nodes_for [first_join.right.expr] + o.wheres, collector, " WHERE ", " AND " + from_where = [first_join.right.expr] + from_items.map { |i| i.right.expr } + collect_nodes_for from_where + o.wheres, collector, " WHERE ", " AND " else collector = visit o.relation, collector collect_nodes_for o.values, collector, " SET " @@ -77,9 +84,13 @@ def visit_Arel_Nodes_UpdateStatement(o, collector) maybe_visit o.limit, collector end - # Same as PostgreSQL except we need to add limit if using subquery. + # Same as PostgreSQL and SQLite except we need to add limit if using subquery. def prepare_update_statement(o) - if has_join_sources?(o) && !has_limit_or_offset_or_orders?(o) && !has_group_by_and_having?(o) && o.relation.right.first.is_a?(Arel::Nodes::InnerJoin) + if has_join_sources?(o) && !has_limit_or_offset_or_orders?(o) && !has_group_by_and_having?(o) && + # The dialect isn't flexible enough to allow anything other than a inner join + # for the first join: + # UPDATE table SET .. FROM joined_table WHERE ... + (o.relation.right.all? { |join| join.is_a?(Arel::Nodes::InnerJoin) || join.right.expr.right.relation != o.relation.left }) o else o.limit = Nodes::Limit.new(9_223_372_036_854_775_807) if o.orders.any? && o.limit.nil?