Skip to content

Commit 466c3e1

Browse files
authored
fix(compiler): Support references to columns in joined tables in UPDATE statements (#1289)
* fix(compiler): Support references to columns in joined tables in UPDATE statements
1 parent 5eb649d commit 466c3e1

File tree

17 files changed

+297
-15
lines changed

17 files changed

+297
-15
lines changed

internal/compiler/find_params.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,13 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
117117
if !ok {
118118
continue
119119
}
120-
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: n.Relation})
120+
for _, relation := range n.Relations.Items {
121+
rv, ok := relation.(*ast.RangeVar)
122+
if !ok {
123+
continue
124+
}
125+
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv})
126+
}
121127
p.seen[ref.Location] = struct{}{}
122128
}
123129

internal/compiler/output_columns.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
349349
})
350350
case *ast.UpdateStmt:
351351
list = &ast.List{
352-
Items: append(n.FromClause.Items, n.Relation),
352+
Items: append(n.FromClause.Items, n.Relations.Items...),
353353
}
354354
default:
355355
return nil, fmt.Errorf("sourceTables: unsupported node type: %T", n)

internal/endtoend/testdata/update_join/mysql/db/db.go

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/update_join/mysql/db/models.go

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/update_join/mysql/db/query.sql.go

Lines changed: 68 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
CREATE TABLE primary_table (
2+
id bigint(20) unsigned NOT NULL AUTO_INCREMENT,
3+
user_id bigint(20) unsigned NOT NULL,
4+
PRIMARY KEY (id)
5+
);
6+
7+
CREATE TABLE join_table (
8+
id bigint(20) unsigned NOT NULL AUTO_INCREMENT,
9+
primary_table_id bigint(20) unsigned NOT NULL,
10+
other_table_id bigint(20) unsigned NOT NULL,
11+
is_active tinyint(1) NOT NULL DEFAULT '0',
12+
PRIMARY KEY (id)
13+
);
14+
15+
-- name: UpdateJoin :exec
16+
UPDATE join_table as jt
17+
JOIN primary_table as pt
18+
ON jt.primary_table_id = pt.id
19+
SET jt.is_active = ?
20+
WHERE jt.id = ?
21+
AND pt.user_id = ?;
22+
23+
-- name: UpdateLeftJoin :exec
24+
UPDATE join_table as jt
25+
LEFT JOIN primary_table as pt
26+
ON jt.primary_table_id = pt.id
27+
SET jt.is_active = ?
28+
WHERE jt.id = ?
29+
AND pt.user_id = ?;
30+
31+
-- name: UpdateRightJoin :exec
32+
UPDATE join_table as jt
33+
RIGHT JOIN primary_table as pt
34+
ON jt.primary_table_id = pt.id
35+
SET jt.is_active = ?
36+
WHERE jt.id = ?
37+
AND pt.user_id = ?;
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"path": "db",
6+
"engine": "mysql",
7+
"schema": "query.sql",
8+
"queries": "query.sql"
9+
}
10+
]
11+
}

internal/endtoend/testdata/update_join/postgresql/db/db.go

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/update_join/postgresql/db/models.go

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/update_join/postgresql/db/query.sql.go

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
CREATE TABLE primary_table (
2+
id INT PRIMARY KEY,
3+
user_id INT NOT NULL
4+
);
5+
6+
CREATE TABLE join_table (
7+
id INT PRIMARY KEY,
8+
primary_table_id INT NOT NULL,
9+
other_table_id INT NOT NULL,
10+
is_active BOOLEAN NOT NULL
11+
);
12+
13+
-- name: UpdateJoin :exec
14+
UPDATE join_table
15+
SET is_active = $1
16+
FROM primary_table
17+
WHERE join_table.id = $2
18+
AND primary_table.user_id = $3
19+
AND join_table.primary_table_id = primary_table.id;
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"path": "db",
6+
"engine": "postgresql",
7+
"schema": "query.sql",
8+
"queries": "query.sql"
9+
}
10+
]
11+
}

internal/engine/dolphin/convert.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt {
540540
panic("expected one range var")
541541
}
542542

543-
var rangeVar *ast.RangeVar
543+
relations := &ast.List{}
544544
switch rel := rels.Items[0].(type) {
545545

546546
// Special case for joins in updates
@@ -549,10 +549,16 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt {
549549
if !ok {
550550
panic("expected range var")
551551
}
552-
rangeVar = left
552+
relations.Items = append(relations.Items, left)
553+
554+
right, ok := rel.Rarg.(*ast.RangeVar)
555+
if !ok {
556+
panic("expected range var")
557+
}
558+
relations.Items = append(relations.Items, right)
553559

554560
case *ast.RangeVar:
555-
rangeVar = rel
561+
relations.Items = append(relations.Items, rel)
556562

557563
default:
558564
panic("expected range var")
@@ -564,7 +570,7 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt {
564570
list.Items = append(list.Items, c.convertAssignment(a))
565571
}
566572
return &ast.UpdateStmt{
567-
Relation: rangeVar,
573+
Relations: relations,
568574
TargetList: list,
569575
WhereClause: c.convert(n.Where),
570576
FromClause: &ast.List{},

internal/engine/postgresql/convert.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
//go:build !windows
12
// +build !windows
23

34
package postgresql
@@ -2765,8 +2766,11 @@ func convertUpdateStmt(n *pg.UpdateStmt) *ast.UpdateStmt {
27652766
if n == nil {
27662767
return nil
27672768
}
2769+
27682770
return &ast.UpdateStmt{
2769-
Relation: convertRangeVar(n.Relation),
2771+
Relations: &ast.List{
2772+
Items: []ast.Node{convertRangeVar(n.Relation)},
2773+
},
27702774
TargetList: convertSlice(n.TargetList),
27712775
WhereClause: convertNode(n.WhereClause),
27722776
FromClause: convertSlice(n.FromClause),
@@ -2780,10 +2784,10 @@ func convertVacuumStmt(n *pg.VacuumStmt) *ast.VacuumStmt {
27802784
return nil
27812785
}
27822786
return &ast.VacuumStmt{
2783-
// FIXME: The VacuumStmt node has changed quite a bit
2784-
// Options: n.Options
2785-
// Relation: convertRangeVar(n.Relation),
2786-
// VaCols: convertSlice(n.VaCols),
2787+
// FIXME: The VacuumStmt node has changed quite a bit
2788+
// Options: n.Options
2789+
// Relation: convertRangeVar(n.Relation),
2790+
// VaCols: convertSlice(n.VaCols),
27872791
}
27882792
}
27892793

internal/sql/ast/update_stmt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package ast
22

33
type UpdateStmt struct {
4-
Relation *RangeVar
4+
Relations *List
55
TargetList *List
66
WhereClause Node
77
FromClause *List

internal/sql/astutils/rewrite.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,7 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.
11411141
// pass
11421142

11431143
case *ast.UpdateStmt:
1144-
a.apply(n, "Relation", nil, n.Relation)
1144+
a.apply(n, "Relations", nil, n.Relations)
11451145
a.apply(n, "TargetList", nil, n.TargetList)
11461146
a.apply(n, "WhereClause", nil, n.WhereClause)
11471147
a.apply(n, "FromClause", nil, n.FromClause)

internal/sql/astutils/walk.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2008,8 +2008,8 @@ func Walk(f Visitor, node ast.Node) {
20082008
// pass
20092009

20102010
case *ast.UpdateStmt:
2011-
if n.Relation != nil {
2012-
Walk(f, n.Relation)
2011+
if n.Relations != nil {
2012+
Walk(f, n.Relations)
20132013
}
20142014
if n.TargetList != nil {
20152015
Walk(f, n.TargetList)

0 commit comments

Comments
 (0)