diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index 7a0130d47e..cab17fdb10 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -104,6 +104,20 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { } } + case *ast.UpdateStmt: + for _, item := range n.TargetList.Items { + target, ok := item.(*ast.ResTarget) + if !ok { + continue + } + ref, ok := target.Val.(*ast.ParamRef) + if !ok { + continue + } + *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: n.Relation}) + p.seen[ref.Location] = struct{}{} + } + case *ast.RangeVar: p.rangeVar = n diff --git a/internal/endtoend/testdata/update_cte/pgx/go/db.go b/internal/endtoend/testdata/update_cte/pgx/go/db.go new file mode 100644 index 0000000000..4559f50a4f --- /dev/null +++ b/internal/endtoend/testdata/update_cte/pgx/go/db.go @@ -0,0 +1,30 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/update_cte/pgx/go/models.go b/internal/endtoend/testdata/update_cte/pgx/go/models.go new file mode 100644 index 0000000000..d514e4feaa --- /dev/null +++ b/internal/endtoend/testdata/update_cte/pgx/go/models.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" + "time" +) + +type Td3Code struct { + ID int32 + TsCreated time.Time + TsUpdated time.Time + CreatedBy string + UpdatedBy string + Code sql.NullString + Hash sql.NullString + IsPrivate sql.NullBool +} + +type Td3TestCode struct { + ID int32 + TsCreated time.Time + TsUpdated time.Time + CreatedBy string + UpdatedBy string + TestID int32 + CodeHash string +} diff --git a/internal/endtoend/testdata/update_cte/pgx/go/query.sql.go b/internal/endtoend/testdata/update_cte/pgx/go/query.sql.go new file mode 100644 index 0000000000..be391665e3 --- /dev/null +++ b/internal/endtoend/testdata/update_cte/pgx/go/query.sql.go @@ -0,0 +1,72 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" + "time" +) + +const updateCode = `-- name: UpdateCode :one + +WITH cc AS ( + UPDATE td3.codes + SET + created_by = $1, + updated_by = $1, + code = $2, + hash = $3, + is_private = false + RETURNING hash +) +UPDATE td3.test_codes +SET + created_by = $1, + updated_by = $1, + test_id = $4, + code_hash = cc.hash + FROM cc +RETURNING hash, id, ts_created, ts_updated, created_by, updated_by, test_id, code_hash +` + +type UpdateCodeParams struct { + CreatedBy string + Code sql.NullString + Hash sql.NullString + TestID int32 +} + +type UpdateCodeRow struct { + Hash sql.NullString + ID int32 + TsCreated time.Time + TsUpdated time.Time + CreatedBy string + UpdatedBy string + TestID int32 + CodeHash string +} + +// FILE: query.sql +func (q *Queries) UpdateCode(ctx context.Context, arg UpdateCodeParams) (UpdateCodeRow, error) { + row := q.db.QueryRow(ctx, updateCode, + arg.CreatedBy, + arg.Code, + arg.Hash, + arg.TestID, + ) + var i UpdateCodeRow + err := row.Scan( + &i.Hash, + &i.ID, + &i.TsCreated, + &i.TsUpdated, + &i.CreatedBy, + &i.UpdatedBy, + &i.TestID, + &i.CodeHash, + ) + return i, err +} diff --git a/internal/endtoend/testdata/update_cte/pgx/query.sql b/internal/endtoend/testdata/update_cte/pgx/query.sql new file mode 100644 index 0000000000..dc53e4f427 --- /dev/null +++ b/internal/endtoend/testdata/update_cte/pgx/query.sql @@ -0,0 +1,50 @@ +-- FILE: schema.sql + +DROP SCHEMA IF EXISTS td3 CASCADE; +CREATE SCHEMA td3; + +CREATE TABLE td3.codes ( + id SERIAL PRIMARY KEY, + ts_created timestamptz DEFAULT now() NOT NULL, + ts_updated timestamptz DEFAULT now() NOT NULL, + created_by text NOT NULL, + updated_by text NOT NULL, + + code text, + hash text, + is_private boolean +); + + +CREATE TABLE td3.test_codes ( + id SERIAL PRIMARY KEY, + ts_created timestamptz DEFAULT now() NOT NULL, + ts_updated timestamptz DEFAULT now() NOT NULL, + created_by text NOT NULL, + updated_by text NOT NULL, + + test_id integer NOT NULL, + code_hash text NOT NULL +); + +-- FILE: query.sql + +-- name: UpdateCode :one +WITH cc AS ( + UPDATE td3.codes + SET + created_by = $1, + updated_by = $1, + code = $2, + hash = $3, + is_private = false + RETURNING hash +) +UPDATE td3.test_codes +SET + created_by = $1, + updated_by = $1, + test_id = $4, + code_hash = cc.hash + FROM cc +RETURNING *; diff --git a/internal/endtoend/testdata/update_cte/pgx/sqlc.json b/internal/endtoend/testdata/update_cte/pgx/sqlc.json new file mode 100644 index 0000000000..9403bd0279 --- /dev/null +++ b/internal/endtoend/testdata/update_cte/pgx/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/update_cte/stdlib/go/db.go b/internal/endtoend/testdata/update_cte/stdlib/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/update_cte/stdlib/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/update_cte/stdlib/go/models.go b/internal/endtoend/testdata/update_cte/stdlib/go/models.go new file mode 100644 index 0000000000..d514e4feaa --- /dev/null +++ b/internal/endtoend/testdata/update_cte/stdlib/go/models.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" + "time" +) + +type Td3Code struct { + ID int32 + TsCreated time.Time + TsUpdated time.Time + CreatedBy string + UpdatedBy string + Code sql.NullString + Hash sql.NullString + IsPrivate sql.NullBool +} + +type Td3TestCode struct { + ID int32 + TsCreated time.Time + TsUpdated time.Time + CreatedBy string + UpdatedBy string + TestID int32 + CodeHash string +} diff --git a/internal/endtoend/testdata/update_cte/stdlib/go/query.sql.go b/internal/endtoend/testdata/update_cte/stdlib/go/query.sql.go new file mode 100644 index 0000000000..59746dbbef --- /dev/null +++ b/internal/endtoend/testdata/update_cte/stdlib/go/query.sql.go @@ -0,0 +1,72 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" + "time" +) + +const updateCode = `-- name: UpdateCode :one + +WITH cc AS ( + UPDATE td3.codes + SET + created_by = $1, + updated_by = $1, + code = $2, + hash = $3, + is_private = false + RETURNING hash +) +UPDATE td3.test_codes +SET + created_by = $1, + updated_by = $1, + test_id = $4, + code_hash = cc.hash + FROM cc +RETURNING hash, id, ts_created, ts_updated, created_by, updated_by, test_id, code_hash +` + +type UpdateCodeParams struct { + CreatedBy string + Code sql.NullString + Hash sql.NullString + TestID int32 +} + +type UpdateCodeRow struct { + Hash sql.NullString + ID int32 + TsCreated time.Time + TsUpdated time.Time + CreatedBy string + UpdatedBy string + TestID int32 + CodeHash string +} + +// FILE: query.sql +func (q *Queries) UpdateCode(ctx context.Context, arg UpdateCodeParams) (UpdateCodeRow, error) { + row := q.db.QueryRowContext(ctx, updateCode, + arg.CreatedBy, + arg.Code, + arg.Hash, + arg.TestID, + ) + var i UpdateCodeRow + err := row.Scan( + &i.Hash, + &i.ID, + &i.TsCreated, + &i.TsUpdated, + &i.CreatedBy, + &i.UpdatedBy, + &i.TestID, + &i.CodeHash, + ) + return i, err +} diff --git a/internal/endtoend/testdata/update_cte/stdlib/query.sql b/internal/endtoend/testdata/update_cte/stdlib/query.sql new file mode 100644 index 0000000000..dc53e4f427 --- /dev/null +++ b/internal/endtoend/testdata/update_cte/stdlib/query.sql @@ -0,0 +1,50 @@ +-- FILE: schema.sql + +DROP SCHEMA IF EXISTS td3 CASCADE; +CREATE SCHEMA td3; + +CREATE TABLE td3.codes ( + id SERIAL PRIMARY KEY, + ts_created timestamptz DEFAULT now() NOT NULL, + ts_updated timestamptz DEFAULT now() NOT NULL, + created_by text NOT NULL, + updated_by text NOT NULL, + + code text, + hash text, + is_private boolean +); + + +CREATE TABLE td3.test_codes ( + id SERIAL PRIMARY KEY, + ts_created timestamptz DEFAULT now() NOT NULL, + ts_updated timestamptz DEFAULT now() NOT NULL, + created_by text NOT NULL, + updated_by text NOT NULL, + + test_id integer NOT NULL, + code_hash text NOT NULL +); + +-- FILE: query.sql + +-- name: UpdateCode :one +WITH cc AS ( + UPDATE td3.codes + SET + created_by = $1, + updated_by = $1, + code = $2, + hash = $3, + is_private = false + RETURNING hash +) +UPDATE td3.test_codes +SET + created_by = $1, + updated_by = $1, + test_id = $4, + code_hash = cc.hash + FROM cc +RETURNING *; diff --git a/internal/endtoend/testdata/update_cte/stdlib/sqlc.json b/internal/endtoend/testdata/update_cte/stdlib/sqlc.json new file mode 100644 index 0000000000..ac7c2ed829 --- /dev/null +++ b/internal/endtoend/testdata/update_cte/stdlib/sqlc.json @@ -0,0 +1,11 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +}