From c1326a816217178792d32dde3aa506ba0ce5c7a7 Mon Sep 17 00:00:00 2001 From: Timothy Studd Date: Tue, 16 Nov 2021 05:07:28 -0800 Subject: [PATCH 1/3] fix(compiler): Support references to columns in joined tables in UPDATE statements --- internal/compiler/find_params.go | 8 ++- internal/compiler/output_columns.go | 2 +- .../testdata/update_join/mysql/db/db.go | 29 ++++++++ .../testdata/update_join/mysql/db/models.go | 17 +++++ .../update_join/mysql/db/query.sql.go | 68 +++++++++++++++++++ .../testdata/update_join/mysql/query.sql | 37 ++++++++++ .../testdata/update_join/mysql/sqlc.json | 11 +++ .../testdata/update_join/postgresql/db/db.go | 29 ++++++++ .../update_join/postgresql/db/models.go | 17 +++++ .../update_join/postgresql/db/query.sql.go | 28 ++++++++ .../testdata/update_join/postgresql/query.sql | 19 ++++++ .../testdata/update_join/postgresql/sqlc.json | 11 +++ internal/engine/dolphin/convert.go | 14 ++-- internal/engine/postgresql/convert.go | 14 ++-- internal/sql/ast/update_stmt.go | 2 +- internal/sql/astutils/rewrite.go | 2 +- internal/sql/astutils/walk.go | 4 +- 17 files changed, 297 insertions(+), 15 deletions(-) create mode 100644 internal/endtoend/testdata/update_join/mysql/db/db.go create mode 100644 internal/endtoend/testdata/update_join/mysql/db/models.go create mode 100644 internal/endtoend/testdata/update_join/mysql/db/query.sql.go create mode 100644 internal/endtoend/testdata/update_join/mysql/query.sql create mode 100644 internal/endtoend/testdata/update_join/mysql/sqlc.json create mode 100644 internal/endtoend/testdata/update_join/postgresql/db/db.go create mode 100644 internal/endtoend/testdata/update_join/postgresql/db/models.go create mode 100644 internal/endtoend/testdata/update_join/postgresql/db/query.sql.go create mode 100644 internal/endtoend/testdata/update_join/postgresql/query.sql create mode 100644 internal/endtoend/testdata/update_join/postgresql/sqlc.json diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index e5ee811d2e..8c7a325c3c 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -117,7 +117,13 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { if !ok { continue } - *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: n.Relation}) + for _, relation := range n.Relations.Items { + rv, ok := relation.(*ast.RangeVar) + if !ok { + continue + } + *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv}) + } p.seen[ref.Location] = struct{}{} } diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 793e433673..e81348a596 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -349,7 +349,7 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) { }) case *ast.UpdateStmt: list = &ast.List{ - Items: append(n.FromClause.Items, n.Relation), + Items: append(n.FromClause.Items, n.Relations.Items...), } default: return nil, fmt.Errorf("sourceTables: unsupported node type: %T", n) diff --git a/internal/endtoend/testdata/update_join/mysql/db/db.go b/internal/endtoend/testdata/update_join/mysql/db/db.go new file mode 100644 index 0000000000..c3c034ae37 --- /dev/null +++ b/internal/endtoend/testdata/update_join/mysql/db/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/update_join/mysql/db/models.go b/internal/endtoend/testdata/update_join/mysql/db/models.go new file mode 100644 index 0000000000..616d987a61 --- /dev/null +++ b/internal/endtoend/testdata/update_join/mysql/db/models.go @@ -0,0 +1,17 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import () + +type JoinTable struct { + ID int64 + PrimaryTableID int64 + OtherTableID int64 + IsActive bool +} + +type PrimaryTable struct { + ID int64 + UserID int64 +} diff --git a/internal/endtoend/testdata/update_join/mysql/db/query.sql.go b/internal/endtoend/testdata/update_join/mysql/db/query.sql.go new file mode 100644 index 0000000000..a837b2768f --- /dev/null +++ b/internal/endtoend/testdata/update_join/mysql/db/query.sql.go @@ -0,0 +1,68 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package db + +import ( + "context" +) + +const updateJoin = `-- name: UpdateJoin :exec +UPDATE join_table as jt + JOIN primary_table as pt + ON jt.primary_table_id = pt.id +SET jt.is_active = ? +WHERE jt.id = ? + AND pt.user_id = ? +` + +type UpdateJoinParams struct { + IsActive bool + ID int64 + UserID int64 +} + +func (q *Queries) UpdateJoin(ctx context.Context, arg UpdateJoinParams) error { + _, err := q.db.ExecContext(ctx, updateJoin, arg.IsActive, arg.ID, arg.UserID) + return err +} + +const updateLeftJoin = `-- name: UpdateLeftJoin :exec +UPDATE join_table as jt + LEFT JOIN primary_table as pt + ON jt.primary_table_id = pt.id +SET jt.is_active = ? +WHERE jt.id = ? + AND pt.user_id = ? +` + +type UpdateLeftJoinParams struct { + IsActive bool + ID int64 + UserID int64 +} + +func (q *Queries) UpdateLeftJoin(ctx context.Context, arg UpdateLeftJoinParams) error { + _, err := q.db.ExecContext(ctx, updateLeftJoin, arg.IsActive, arg.ID, arg.UserID) + return err +} + +const updateRightJoin = `-- name: UpdateRightJoin :exec +UPDATE join_table as jt + RIGHT JOIN primary_table as pt + ON jt.primary_table_id = pt.id +SET jt.is_active = ? +WHERE jt.id = ? + AND pt.user_id = ? +` + +type UpdateRightJoinParams struct { + IsActive bool + ID int64 + UserID int64 +} + +func (q *Queries) UpdateRightJoin(ctx context.Context, arg UpdateRightJoinParams) error { + _, err := q.db.ExecContext(ctx, updateRightJoin, arg.IsActive, arg.ID, arg.UserID) + return err +} diff --git a/internal/endtoend/testdata/update_join/mysql/query.sql b/internal/endtoend/testdata/update_join/mysql/query.sql new file mode 100644 index 0000000000..8f702fa453 --- /dev/null +++ b/internal/endtoend/testdata/update_join/mysql/query.sql @@ -0,0 +1,37 @@ +CREATE TABLE primary_table ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + user_id bigint(20) unsigned NOT NULL, + PRIMARY KEY (id) +); + +CREATE TABLE join_table ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + primary_table_id bigint(20) unsigned NOT NULL, + other_table_id bigint(20) unsigned NOT NULL, + is_active tinyint(1) NOT NULL DEFAULT '0', + PRIMARY KEY (id) +); + +-- name: UpdateJoin :exec +UPDATE join_table as jt + JOIN primary_table as pt + ON jt.primary_table_id = pt.id +SET jt.is_active = ? +WHERE jt.id = ? + AND pt.user_id = ?; + +-- name: UpdateLeftJoin :exec +UPDATE join_table as jt + LEFT JOIN primary_table as pt + ON jt.primary_table_id = pt.id +SET jt.is_active = ? +WHERE jt.id = ? + AND pt.user_id = ?; + +-- name: UpdateRightJoin :exec +UPDATE join_table as jt + RIGHT JOIN primary_table as pt + ON jt.primary_table_id = pt.id +SET jt.is_active = ? +WHERE jt.id = ? + AND pt.user_id = ?; diff --git a/internal/endtoend/testdata/update_join/mysql/sqlc.json b/internal/endtoend/testdata/update_join/mysql/sqlc.json new file mode 100644 index 0000000000..b63437627d --- /dev/null +++ b/internal/endtoend/testdata/update_join/mysql/sqlc.json @@ -0,0 +1,11 @@ +{ + "version": "1", + "packages": [ + { + "path": "db", + "engine": "mysql", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/update_join/postgresql/db/db.go b/internal/endtoend/testdata/update_join/postgresql/db/db.go new file mode 100644 index 0000000000..c3c034ae37 --- /dev/null +++ b/internal/endtoend/testdata/update_join/postgresql/db/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/update_join/postgresql/db/models.go b/internal/endtoend/testdata/update_join/postgresql/db/models.go new file mode 100644 index 0000000000..0f0c4122a4 --- /dev/null +++ b/internal/endtoend/testdata/update_join/postgresql/db/models.go @@ -0,0 +1,17 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import () + +type JoinTable struct { + ID int32 + PrimaryTableID int32 + OtherTableID int32 + IsActive bool +} + +type PrimaryTable struct { + ID int32 + UserID int32 +} diff --git a/internal/endtoend/testdata/update_join/postgresql/db/query.sql.go b/internal/endtoend/testdata/update_join/postgresql/db/query.sql.go new file mode 100644 index 0000000000..2da6d0e19b --- /dev/null +++ b/internal/endtoend/testdata/update_join/postgresql/db/query.sql.go @@ -0,0 +1,28 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package db + +import ( + "context" +) + +const updateJoin = `-- name: UpdateJoin :exec +UPDATE join_table +SET is_active = $1 +FROM primary_table +WHERE join_table.id = $2 + AND primary_table.user_id = $3 + AND join_table.primary_table_id = primary_table.id +` + +type UpdateJoinParams struct { + IsActive bool + ID int32 + UserID int32 +} + +func (q *Queries) UpdateJoin(ctx context.Context, arg UpdateJoinParams) error { + _, err := q.db.ExecContext(ctx, updateJoin, arg.IsActive, arg.ID, arg.UserID) + return err +} diff --git a/internal/endtoend/testdata/update_join/postgresql/query.sql b/internal/endtoend/testdata/update_join/postgresql/query.sql new file mode 100644 index 0000000000..bfb5e87457 --- /dev/null +++ b/internal/endtoend/testdata/update_join/postgresql/query.sql @@ -0,0 +1,19 @@ +CREATE TABLE primary_table ( + id INT PRIMARY KEY, + user_id INT NOT NULL +); + +CREATE TABLE join_table ( + id INT PRIMARY KEY, + primary_table_id INT NOT NULL, + other_table_id INT NOT NULL, + is_active BOOLEAN NOT NULL +); + +-- name: UpdateJoin :exec +UPDATE join_table +SET is_active = $1 +FROM primary_table +WHERE join_table.id = $2 + AND primary_table.user_id = $3 + AND join_table.primary_table_id = primary_table.id; diff --git a/internal/endtoend/testdata/update_join/postgresql/sqlc.json b/internal/endtoend/testdata/update_join/postgresql/sqlc.json new file mode 100644 index 0000000000..c9cb1e1fdc --- /dev/null +++ b/internal/endtoend/testdata/update_join/postgresql/sqlc.json @@ -0,0 +1,11 @@ +{ + "version": "1", + "packages": [ + { + "path": "db", + "engine": "postgresql", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 6f8236e1c7..fd28839453 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -528,7 +528,7 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt { panic("expected one range var") } - var rangeVar *ast.RangeVar + relations := &ast.List{} switch rel := rels.Items[0].(type) { // Special case for joins in updates @@ -537,10 +537,16 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt { if !ok { panic("expected range var") } - rangeVar = left + relations.Items = append(relations.Items, left) + + right, ok := rel.Rarg.(*ast.RangeVar) + if !ok { + panic("expected range var") + } + relations.Items = append(relations.Items, right) case *ast.RangeVar: - rangeVar = rel + relations.Items = append(relations.Items, rel) default: panic("expected range var") @@ -552,7 +558,7 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt { list.Items = append(list.Items, c.convertAssignment(a)) } return &ast.UpdateStmt{ - Relation: rangeVar, + Relations: relations, TargetList: list, WhereClause: c.convert(n.Where), FromClause: &ast.List{}, diff --git a/internal/engine/postgresql/convert.go b/internal/engine/postgresql/convert.go index d960156c51..c6e8976b66 100644 --- a/internal/engine/postgresql/convert.go +++ b/internal/engine/postgresql/convert.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package postgresql @@ -2765,8 +2766,11 @@ func convertUpdateStmt(n *pg.UpdateStmt) *ast.UpdateStmt { if n == nil { return nil } + return &ast.UpdateStmt{ - Relation: convertRangeVar(n.Relation), + Relations: &ast.List{ + Items: []ast.Node{convertRangeVar(n.Relation)}, + }, TargetList: convertSlice(n.TargetList), WhereClause: convertNode(n.WhereClause), FromClause: convertSlice(n.FromClause), @@ -2780,10 +2784,10 @@ func convertVacuumStmt(n *pg.VacuumStmt) *ast.VacuumStmt { return nil } return &ast.VacuumStmt{ - // FIXME: The VacuumStmt node has changed quite a bit - // Options: n.Options - // Relation: convertRangeVar(n.Relation), - // VaCols: convertSlice(n.VaCols), + // FIXME: The VacuumStmt node has changed quite a bit + // Options: n.Options + // Relation: convertRangeVar(n.Relation), + // VaCols: convertSlice(n.VaCols), } } diff --git a/internal/sql/ast/update_stmt.go b/internal/sql/ast/update_stmt.go index dd476a3587..517d0b420b 100644 --- a/internal/sql/ast/update_stmt.go +++ b/internal/sql/ast/update_stmt.go @@ -1,7 +1,7 @@ package ast type UpdateStmt struct { - Relation *RangeVar + Relations *List TargetList *List WhereClause Node FromClause *List diff --git a/internal/sql/astutils/rewrite.go b/internal/sql/astutils/rewrite.go index c8473ca14e..ba30f2acfa 100644 --- a/internal/sql/astutils/rewrite.go +++ b/internal/sql/astutils/rewrite.go @@ -1141,7 +1141,7 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. // pass case *ast.UpdateStmt: - a.apply(n, "Relation", nil, n.Relation) + a.apply(n, "Relations", nil, n.Relations) a.apply(n, "TargetList", nil, n.TargetList) a.apply(n, "WhereClause", nil, n.WhereClause) a.apply(n, "FromClause", nil, n.FromClause) diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 632950c8b2..eefad0ac03 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -2008,8 +2008,8 @@ func Walk(f Visitor, node ast.Node) { // pass case *ast.UpdateStmt: - if n.Relation != nil { - Walk(f, n.Relation) + if n.Relations != nil { + Walk(f, n.Relations) } if n.TargetList != nil { Walk(f, n.TargetList) From b971a2824452041bc9860d499a151c9bfb9b50c0 Mon Sep 17 00:00:00 2001 From: Timothy Studd Date: Tue, 16 Nov 2021 05:08:45 -0800 Subject: [PATCH 2/3] Fix formatting --- internal/endtoend/testdata/update_join/postgresql/query.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/endtoend/testdata/update_join/postgresql/query.sql b/internal/endtoend/testdata/update_join/postgresql/query.sql index bfb5e87457..875dd8e0f9 100644 --- a/internal/endtoend/testdata/update_join/postgresql/query.sql +++ b/internal/endtoend/testdata/update_join/postgresql/query.sql @@ -16,4 +16,4 @@ SET is_active = $1 FROM primary_table WHERE join_table.id = $2 AND primary_table.user_id = $3 - AND join_table.primary_table_id = primary_table.id; + AND join_table.primary_table_id = primary_table.id; From d271151f43c364d3b953d7145b404dc69078b475 Mon Sep 17 00:00:00 2001 From: Timothy Studd Date: Tue, 16 Nov 2021 05:12:21 -0800 Subject: [PATCH 3/3] Fix formatting --- .../endtoend/testdata/update_join/postgresql/db/query.sql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/endtoend/testdata/update_join/postgresql/db/query.sql.go b/internal/endtoend/testdata/update_join/postgresql/db/query.sql.go index 2da6d0e19b..6fc919f141 100644 --- a/internal/endtoend/testdata/update_join/postgresql/db/query.sql.go +++ b/internal/endtoend/testdata/update_join/postgresql/db/query.sql.go @@ -13,7 +13,7 @@ SET is_active = $1 FROM primary_table WHERE join_table.id = $2 AND primary_table.user_id = $3 - AND join_table.primary_table_id = primary_table.id + AND join_table.primary_table_id = primary_table.id ` type UpdateJoinParams struct {