From 89c27e513320fe7789d025d59fabdffd4c2a1020 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Fri, 27 Nov 2020 15:31:24 -0800 Subject: [PATCH 01/10] Allow for mixing parameter styles --- internal/compiler/parse.go | 8 +- internal/endtoend/testdata/mix/go/db.go | 29 +++++++ internal/endtoend/testdata/mix/go/models.go | 11 +++ internal/endtoend/testdata/mix/go/test.sql.go | 75 +++++++++++++++++++ internal/endtoend/testdata/mix/sqlc.json | 11 +++ internal/endtoend/testdata/mix/test.sql | 17 +++++ internal/sql/rewrite/parameters.go | 3 +- internal/sql/validate/param_ref.go | 11 ++- internal/sql/validate/param_style.go | 34 --------- 9 files changed, 154 insertions(+), 45 deletions(-) create mode 100644 internal/endtoend/testdata/mix/go/db.go create mode 100644 internal/endtoend/testdata/mix/go/models.go create mode 100644 internal/endtoend/testdata/mix/go/test.sql.go create mode 100644 internal/endtoend/testdata/mix/sqlc.json create mode 100644 internal/endtoend/testdata/mix/test.sql delete mode 100644 internal/sql/validate/param_style.go diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index b8329b6f5e..0fbd242ea8 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -34,10 +34,8 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if o.Debug.DumpAST { debug.Dump(stmt) } - if err := validate.ParamStyle(stmt); err != nil { - return nil, err - } - if err := validate.ParamRef(stmt); err != nil { + lastNumber, err := validate.ParamRef(stmt) + if err != nil { return nil, err } raw, ok := stmt.(*ast.RawStmt) @@ -75,7 +73,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } - raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw) + raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, lastNumber) rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) if o.UsePositionalParameters { diff --git a/internal/endtoend/testdata/mix/go/db.go b/internal/endtoend/testdata/mix/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/mix/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mix/go/models.go b/internal/endtoend/testdata/mix/go/models.go new file mode 100644 index 0000000000..c66db5f3b2 --- /dev/null +++ b/internal/endtoend/testdata/mix/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type Bar struct { + ID int32 + Name string + Phone string +} diff --git a/internal/endtoend/testdata/mix/go/test.sql.go b/internal/endtoend/testdata/mix/go/test.sql.go new file mode 100644 index 0000000000..f8aec50e5b --- /dev/null +++ b/internal/endtoend/testdata/mix/go/test.sql.go @@ -0,0 +1,75 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: test.sql + +package querytest + +import ( + "context" +) + +const countFour = `-- name: CountFour :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> $3 AND name <> $1 +` + +type CountFourParams struct { + Name string + ID int32 + PhoneParam string +} + +func (q *Queries) CountFour(ctx context.Context, arg CountFourParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countFour, arg.Name, arg.ID, arg.PhoneParam) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countOne = `-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = $2 AND phone < $3 and name <> $1 +` + +type CountOneParams struct { + Name string + ID int32 + PhoneParam string +} + +func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID, arg.PhoneParam) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countThree = `-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND name = $1 +` + +type CountThreeParams struct { + Name string + IDParam int32 +} + +func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countThree, arg.Name, arg.IDParam) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countTwo = `-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $2 AND phone < $3 and name <> $1 +` + +type CountTwoParams struct { + Name string + IDParam int32 + PhoneParam string +} + +func (q *Queries) CountTwo(ctx context.Context, arg CountTwoParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countTwo, arg.Name, arg.IDParam, arg.PhoneParam) + var count int64 + err := row.Scan(&count) + return count, err +} diff --git a/internal/endtoend/testdata/mix/sqlc.json b/internal/endtoend/testdata/mix/sqlc.json new file mode 100644 index 0000000000..6ca0bbcfa4 --- /dev/null +++ b/internal/endtoend/testdata/mix/sqlc.json @@ -0,0 +1,11 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "querytest", + "schema": "test.sql", + "queries": "test.sql" + } + ] +} diff --git a/internal/endtoend/testdata/mix/test.sql b/internal/endtoend/testdata/mix/test.sql new file mode 100644 index 0000000000..5832d0fa09 --- /dev/null +++ b/internal/endtoend/testdata/mix/test.sql @@ -0,0 +1,17 @@ +CREATE TABLE bar ( + id serial not null, + name text not null, + phone text not null +); + +-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = $2 AND phone < @phone_param and name <> $1; + +-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = sqlc.arg(id_param) AND phone < @phone_param and name <> $1; + +-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > sqlc.arg(id_param) AND name = $1; + +-- name: CountFour :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> @phone_param AND name <> $1; diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index 7c79e9b6d9..dc85f42dac 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,7 +41,7 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } -func NamedParameters(engine config.Engine, raw *ast.RawStmt) (*ast.RawStmt, map[int]string, []source.Edit) { +func NamedParameters(engine config.Engine, raw *ast.RawStmt, argn int) (*ast.RawStmt, map[int]string, []source.Edit) { foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) if len(foundFunc.Items)+len(foundSign.Items) == 0 { @@ -51,7 +51,6 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt) (*ast.RawStmt, map[ hasNamedParameterSupport := engine != config.EngineMySQL args := map[string]int{} - argn := 0 var edits []source.Edit node := astutils.Apply(raw, func(cr *astutils.Cursor) bool { node := cr.Node() diff --git a/internal/sql/validate/param_ref.go b/internal/sql/validate/param_ref.go index eacdb64fad..e3332f2981 100644 --- a/internal/sql/validate/param_ref.go +++ b/internal/sql/validate/param_ref.go @@ -8,7 +8,7 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) -func ParamRef(n ast.Node) error { +func ParamRef(n ast.Node) (int, error) { var allrefs []*ast.ParamRef // Find all parameter references @@ -23,14 +23,17 @@ func ParamRef(n ast.Node) error { for _, r := range allrefs { seen[r.Number] = struct{}{} } - + var max int for i := 1; i <= len(seen); i += 1 { + if i > max { + max = i + } if _, ok := seen[i]; !ok { - return &sqlerr.Error{ + return 0, &sqlerr.Error{ Code: "42P18", Message: fmt.Sprintf("could not determine data type of parameter $%d", i), } } } - return nil + return max, nil } diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go deleted file mode 100644 index 870f460f66..0000000000 --- a/internal/sql/validate/param_style.go +++ /dev/null @@ -1,34 +0,0 @@ -package validate - -import ( - "github.com/kyleconroy/sqlc/internal/sql/ast" - "github.com/kyleconroy/sqlc/internal/sql/astutils" - "github.com/kyleconroy/sqlc/internal/sql/named" - "github.com/kyleconroy/sqlc/internal/sql/sqlerr" -) - -// A query can use one (and only one) of the following formats: -// - positional parameters $1 -// - named parameter operator @param -// - named parameter function calls sqlc.arg(param) -func ParamStyle(n ast.Node) error { - positional := astutils.Search(n, func(node ast.Node) bool { - _, ok := node.(*ast.ParamRef) - return ok - }) - namedFunc := astutils.Search(n, named.IsParamFunc) - namedSign := astutils.Search(n, named.IsParamSign) - for _, check := range []bool{ - len(positional.Items) > 0 && len(namedSign.Items)+len(namedFunc.Items) > 0, - len(namedFunc.Items) > 0 && len(namedSign.Items)+len(positional.Items) > 0, - len(namedSign.Items) > 0 && len(positional.Items)+len(namedFunc.Items) > 0, - } { - if check { - return &sqlerr.Error{ - Code: "", // TODO: Pick a new error code - Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)", - } - } - } - return nil -} From 44e60356676b8c160ede8d0cd55113b3034eeb29 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sat, 28 Nov 2020 21:27:08 -0800 Subject: [PATCH 02/10] Adding mysql/postgres engines and renaming dir --- internal/endtoend/testdata/mix/go/test.sql.go | 75 ------------------- internal/endtoend/testdata/mix/test.sql | 17 ----- .../{mix => mix_param_types/mysql}/go/db.go | 0 .../mysql}/go/models.go | 0 .../mix_param_types/mysql/go/test.sql.go | 57 ++++++++++++++ .../{mix => mix_param_types/mysql}/sqlc.json | 3 +- .../testdata/mix_param_types/mysql/test.sql | 14 ++++ .../mix_param_types/postgresql/go/db.go | 29 +++++++ .../mix_param_types/postgresql/go/models.go | 11 +++ .../mix_param_types/postgresql/go/test.sql.go | 57 ++++++++++++++ .../mix_param_types/postgresql/sqlc.json | 12 +++ .../mix_param_types/postgresql/test.sql | 14 ++++ 12 files changed, 196 insertions(+), 93 deletions(-) delete mode 100644 internal/endtoend/testdata/mix/go/test.sql.go delete mode 100644 internal/endtoend/testdata/mix/test.sql rename internal/endtoend/testdata/{mix => mix_param_types/mysql}/go/db.go (100%) rename internal/endtoend/testdata/{mix => mix_param_types/mysql}/go/models.go (100%) create mode 100644 internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go rename internal/endtoend/testdata/{mix => mix_param_types/mysql}/sqlc.json (68%) create mode 100644 internal/endtoend/testdata/mix_param_types/mysql/test.sql create mode 100644 internal/endtoend/testdata/mix_param_types/postgresql/go/db.go create mode 100644 internal/endtoend/testdata/mix_param_types/postgresql/go/models.go create mode 100644 internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go create mode 100644 internal/endtoend/testdata/mix_param_types/postgresql/sqlc.json create mode 100644 internal/endtoend/testdata/mix_param_types/postgresql/test.sql diff --git a/internal/endtoend/testdata/mix/go/test.sql.go b/internal/endtoend/testdata/mix/go/test.sql.go deleted file mode 100644 index f8aec50e5b..0000000000 --- a/internal/endtoend/testdata/mix/go/test.sql.go +++ /dev/null @@ -1,75 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// source: test.sql - -package querytest - -import ( - "context" -) - -const countFour = `-- name: CountFour :one -SELECT count(1) FROM bar WHERE id > $2 AND phone <> $3 AND name <> $1 -` - -type CountFourParams struct { - Name string - ID int32 - PhoneParam string -} - -func (q *Queries) CountFour(ctx context.Context, arg CountFourParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countFour, arg.Name, arg.ID, arg.PhoneParam) - var count int64 - err := row.Scan(&count) - return count, err -} - -const countOne = `-- name: CountOne :one -SELECT count(1) FROM bar WHERE id = $2 AND phone < $3 and name <> $1 -` - -type CountOneParams struct { - Name string - ID int32 - PhoneParam string -} - -func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID, arg.PhoneParam) - var count int64 - err := row.Scan(&count) - return count, err -} - -const countThree = `-- name: CountThree :one -SELECT count(1) FROM bar WHERE id > $2 AND name = $1 -` - -type CountThreeParams struct { - Name string - IDParam int32 -} - -func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countThree, arg.Name, arg.IDParam) - var count int64 - err := row.Scan(&count) - return count, err -} - -const countTwo = `-- name: CountTwo :one -SELECT count(1) FROM bar WHERE id = $2 AND phone < $3 and name <> $1 -` - -type CountTwoParams struct { - Name string - IDParam int32 - PhoneParam string -} - -func (q *Queries) CountTwo(ctx context.Context, arg CountTwoParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countTwo, arg.Name, arg.IDParam, arg.PhoneParam) - var count int64 - err := row.Scan(&count) - return count, err -} diff --git a/internal/endtoend/testdata/mix/test.sql b/internal/endtoend/testdata/mix/test.sql deleted file mode 100644 index 5832d0fa09..0000000000 --- a/internal/endtoend/testdata/mix/test.sql +++ /dev/null @@ -1,17 +0,0 @@ -CREATE TABLE bar ( - id serial not null, - name text not null, - phone text not null -); - --- name: CountOne :one -SELECT count(1) FROM bar WHERE id = $2 AND phone < @phone_param and name <> $1; - --- name: CountTwo :one -SELECT count(1) FROM bar WHERE id = sqlc.arg(id_param) AND phone < @phone_param and name <> $1; - --- name: CountThree :one -SELECT count(1) FROM bar WHERE id > sqlc.arg(id_param) AND name = $1; - --- name: CountFour :one -SELECT count(1) FROM bar WHERE id > $2 AND phone <> @phone_param AND name <> $1; diff --git a/internal/endtoend/testdata/mix/go/db.go b/internal/endtoend/testdata/mix_param_types/mysql/go/db.go similarity index 100% rename from internal/endtoend/testdata/mix/go/db.go rename to internal/endtoend/testdata/mix_param_types/mysql/go/db.go diff --git a/internal/endtoend/testdata/mix/go/models.go b/internal/endtoend/testdata/mix_param_types/mysql/go/models.go similarity index 100% rename from internal/endtoend/testdata/mix/go/models.go rename to internal/endtoend/testdata/mix_param_types/mysql/go/models.go diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go new file mode 100644 index 0000000000..cef2361690 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go @@ -0,0 +1,57 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: test.sql + +package querytest + +import ( + "context" +) + +const countOne = `-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = $2 AND name <> $1 +` + +type CountOneParams struct { + Name string + ID int32 +} + +func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countThree = `-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> $3 AND name <> $1 +` + +type CountThreeParams struct { + Name string + ID int32 + Phone string +} + +func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countThree, arg.Name, arg.ID, arg.Phone) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countTwo = `-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $1 AND name <> $2 +` + +type CountTwoParams struct { + ID int32 + Name string +} + +func (q *Queries) CountTwo(ctx context.Context, arg CountTwoParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countTwo, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} diff --git a/internal/endtoend/testdata/mix/sqlc.json b/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json similarity index 68% rename from internal/endtoend/testdata/mix/sqlc.json rename to internal/endtoend/testdata/mix_param_types/mysql/sqlc.json index 6ca0bbcfa4..dfd2f59a26 100644 --- a/internal/endtoend/testdata/mix/sqlc.json +++ b/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json @@ -5,7 +5,8 @@ "path": "go", "name": "querytest", "schema": "test.sql", - "queries": "test.sql" + "queries": "test.sql", + "engine": "postgresql" } ] } diff --git a/internal/endtoend/testdata/mix_param_types/mysql/test.sql b/internal/endtoend/testdata/mix_param_types/mysql/test.sql new file mode 100644 index 0000000000..0bed5e4381 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/test.sql @@ -0,0 +1,14 @@ +CREATE TABLE bar ( + id serial not null, + name text not null, + phone text not null +); + +-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> $1; + +-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name); + +-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> sqlc.arg(phone) AND name <> $1; diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/db.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/models.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/models.go new file mode 100644 index 0000000000..c66db5f3b2 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type Bar struct { + ID int32 + Name string + Phone string +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go new file mode 100644 index 0000000000..cef2361690 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go @@ -0,0 +1,57 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: test.sql + +package querytest + +import ( + "context" +) + +const countOne = `-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = $2 AND name <> $1 +` + +type CountOneParams struct { + Name string + ID int32 +} + +func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countThree = `-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> $3 AND name <> $1 +` + +type CountThreeParams struct { + Name string + ID int32 + Phone string +} + +func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countThree, arg.Name, arg.ID, arg.Phone) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countTwo = `-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $1 AND name <> $2 +` + +type CountTwoParams struct { + ID int32 + Name string +} + +func (q *Queries) CountTwo(ctx context.Context, arg CountTwoParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countTwo, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/sqlc.json b/internal/endtoend/testdata/mix_param_types/postgresql/sqlc.json new file mode 100644 index 0000000000..dfd2f59a26 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "querytest", + "schema": "test.sql", + "queries": "test.sql", + "engine": "postgresql" + } + ] +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/test.sql b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql new file mode 100644 index 0000000000..0bed5e4381 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql @@ -0,0 +1,14 @@ +CREATE TABLE bar ( + id serial not null, + name text not null, + phone text not null +); + +-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> $1; + +-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name); + +-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> sqlc.arg(phone) AND name <> $1; From bfcb327c47091ed8860c417de1976e8825e715ef Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Mon, 8 Mar 2021 12:16:33 -0800 Subject: [PATCH 03/10] Adding a test --- .../testdata/mix_param_types/postgresql/go/test.sql.go | 9 +++++---- .../testdata/mix_param_types/postgresql/test.sql | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go index cef2361690..27c84b6956 100644 --- a/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go @@ -8,16 +8,17 @@ import ( ) const countOne = `-- name: CountOne :one -SELECT count(1) FROM bar WHERE id = $2 AND name <> $1 +SELECT count(1) FROM bar WHERE id = $2 AND name <> $1 LIMIT $3 ` type CountOneParams struct { - Name string - ID int32 + Name string + ID int32 + Limit int32 } func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID) + row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID, arg.Limit) var count int64 err := row.Scan(&count) return count, err diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/test.sql b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql index 0bed5e4381..411f99829f 100644 --- a/internal/endtoend/testdata/mix_param_types/postgresql/test.sql +++ b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql @@ -5,7 +5,7 @@ CREATE TABLE bar ( ); -- name: CountOne :one -SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> $1; +SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> $1 LIMIT sqlc.arg('limit'); -- name: CountTwo :one SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name); From 90018fca4dad2a6defbc1d2b5342e4b65b43e333 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Wed, 28 Apr 2021 09:32:30 -0700 Subject: [PATCH 04/10] Fix param --- internal/endtoend/testdata/mix_param_types/mysql/sqlc.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json b/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json index dfd2f59a26..145f64ba3f 100644 --- a/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json +++ b/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json @@ -6,7 +6,7 @@ "name": "querytest", "schema": "test.sql", "queries": "test.sql", - "engine": "postgresql" + "engine": "mysql" } ] } From 6f40bc10116feddd41fdb75e6a1fc27cf6ea8183 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sun, 27 Jun 2021 23:18:08 -0700 Subject: [PATCH 05/10] Some changes --- internal/compiler/parse.go | 3 ++ .../mix_param_types/mysql/go/models.go | 2 +- .../mix_param_types/mysql/go/test.sql.go | 16 +++--- .../testdata/mix_param_types/mysql/test.sql | 6 +-- internal/sql/validate/param_style.go | 49 +++++++++++++++++++ 5 files changed, 64 insertions(+), 12 deletions(-) create mode 100644 internal/sql/validate/param_style.go diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 0fbd242ea8..3781bee9f8 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -34,6 +34,9 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if o.Debug.DumpAST { debug.Dump(stmt) } + if err := validate.ParamStyle(stmt); err != nil { + return nil, err + } lastNumber, err := validate.ParamRef(stmt) if err != nil { return nil, err diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/models.go b/internal/endtoend/testdata/mix_param_types/mysql/go/models.go index c66db5f3b2..b10bc44571 100644 --- a/internal/endtoend/testdata/mix_param_types/mysql/go/models.go +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/models.go @@ -5,7 +5,7 @@ package querytest import () type Bar struct { - ID int32 + ID int64 Name string Phone string } diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go index cef2361690..cb72276900 100644 --- a/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go @@ -8,44 +8,44 @@ import ( ) const countOne = `-- name: CountOne :one -SELECT count(1) FROM bar WHERE id = $2 AND name <> $1 +SELECT count(1) FROM bar WHERE id = ? AND name <> ? ` type CountOneParams struct { Name string - ID int32 + ID int64 } func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID) + row := q.db.QueryRowContext(ctx, countOne, arg.ID, arg.Name) var count int64 err := row.Scan(&count) return count, err } const countThree = `-- name: CountThree :one -SELECT count(1) FROM bar WHERE id > $2 AND phone <> $3 AND name <> $1 +SELECT count(1) FROM bar WHERE id > ? AND phone <> ? AND name <> ? ` type CountThreeParams struct { Name string - ID int32 + ID int64 Phone string } func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countThree, arg.Name, arg.ID, arg.Phone) + row := q.db.QueryRowContext(ctx, countThree, arg.ID, arg.Phone, arg.Name) var count int64 err := row.Scan(&count) return count, err } const countTwo = `-- name: CountTwo :one -SELECT count(1) FROM bar WHERE id = $1 AND name <> $2 +SELECT count(1) FROM bar WHERE id = ? AND name <> ? ` type CountTwoParams struct { - ID int32 + ID int64 Name string } diff --git a/internal/endtoend/testdata/mix_param_types/mysql/test.sql b/internal/endtoend/testdata/mix_param_types/mysql/test.sql index 0bed5e4381..b624d3e2ea 100644 --- a/internal/endtoend/testdata/mix_param_types/mysql/test.sql +++ b/internal/endtoend/testdata/mix_param_types/mysql/test.sql @@ -5,10 +5,10 @@ CREATE TABLE bar ( ); -- name: CountOne :one -SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> $1; +SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> ?; -- name: CountTwo :one -SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name); +SELECT count(1) FROM bar WHERE id = ? AND name <> sqlc.arg(name); -- name: CountThree :one -SELECT count(1) FROM bar WHERE id > $2 AND phone <> sqlc.arg(phone) AND name <> $1; +SELECT count(1) FROM bar WHERE id > ? AND phone <> sqlc.arg(phone) AND name <> ?; diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go new file mode 100644 index 0000000000..33c886d467 --- /dev/null +++ b/internal/sql/validate/param_style.go @@ -0,0 +1,49 @@ +package validate + +import ( + "github.com/kyleconroy/sqlc/internal/sql/ast" + "github.com/kyleconroy/sqlc/internal/sql/astutils" + "github.com/kyleconroy/sqlc/internal/sql/named" + "github.com/kyleconroy/sqlc/internal/sql/sqlerr" +) + +// A query can use one (and only one) of the following formats: +// - positional parameters $1 +// - named parameter operator @param +// - named parameter function calls sqlc.arg(param) +func ParamStyle(n ast.Node) error { + namedFunc := astutils.Search(n, named.IsParamFunc) + for _, f := range namedFunc.Items { + fc, ok := f.(*ast.FuncCall) + if ok { + /* + if len(fc.Args.Items) != 1 { + return &sqlerr.Error{ + Code: "", // TODO: Pick a new error code + Message: "Wrong number of arguments to sqlc.arg()", + } + } + */ + switch fc.Args.Items[0].(type) { + case *ast.FuncCall: + return &sqlerr.Error{ + Code: "", // TODO: Pick a new error code + Message: "expected parameter to sqlc.arg to be string or reference; got *ast.FuncCall", + } + case *ast.ParamRef: + return &sqlerr.Error{ + Code: "", // TODO: Pick a new error code + Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)", + } + case *ast.A_Const, *ast.ColumnRef: + default: + return &sqlerr.Error{ + Code: "", // TODO: Pick a new error code + Message: "Invalid argument to sqlc.arg()", + } + + } + } + } + return nil +} From ef9076eee05a9922958699cfa984fe85cabf3cbd Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Wed, 30 Jun 2021 06:23:39 -0700 Subject: [PATCH 06/10] Updates to get it to pass the tests and also deal with mysql challenges --- internal/compiler/parse.go | 4 ++-- .../testdata/invalid_params/pgx/stderr.txt | 2 +- .../testdata/invalid_params/stdlib/stderr.txt | 2 +- .../mix_param_types/mysql/go/test.sql.go | 6 +++--- .../sqlc_arg_invalid/mysql/stderr.txt | 4 ++-- .../sqlc_arg_invalid/postgresql/stderr.txt | 4 ++-- internal/sql/rewrite/parameters.go | 19 ++++++++++++++----- internal/sql/validate/param_ref.go | 16 +++++++--------- internal/sql/validate/param_style.go | 12 ++++++++---- 9 files changed, 40 insertions(+), 29 deletions(-) diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 3781bee9f8..3bac628bd8 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -37,7 +37,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err := validate.ParamStyle(stmt); err != nil { return nil, err } - lastNumber, err := validate.ParamRef(stmt) + numbers, err := validate.ParamRef(stmt) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } - raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, lastNumber) + raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers) rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) if o.UsePositionalParameters { diff --git a/internal/endtoend/testdata/invalid_params/pgx/stderr.txt b/internal/endtoend/testdata/invalid_params/pgx/stderr.txt index f338b8d756..722c3e2408 100644 --- a/internal/endtoend/testdata/invalid_params/pgx/stderr.txt +++ b/internal/endtoend/testdata/invalid_params/pgx/stderr.txt @@ -2,4 +2,4 @@ query.sql:4:1: could not determine data type of parameter $1 query.sql:7:1: could not determine data type of parameter $2 query.sql:10:8: column "foo" does not exist -query.sql:13:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg) +query.sql:13:1: could not determine data type of parameter $2 diff --git a/internal/endtoend/testdata/invalid_params/stdlib/stderr.txt b/internal/endtoend/testdata/invalid_params/stdlib/stderr.txt index f338b8d756..722c3e2408 100644 --- a/internal/endtoend/testdata/invalid_params/stdlib/stderr.txt +++ b/internal/endtoend/testdata/invalid_params/stdlib/stderr.txt @@ -2,4 +2,4 @@ query.sql:4:1: could not determine data type of parameter $1 query.sql:7:1: could not determine data type of parameter $2 query.sql:10:8: column "foo" does not exist -query.sql:13:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg) +query.sql:13:1: could not determine data type of parameter $2 diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go index cb72276900..0fa5b33c52 100644 --- a/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go @@ -17,7 +17,7 @@ type CountOneParams struct { } func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countOne, arg.ID, arg.Name) + row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID) var count int64 err := row.Scan(&count) return count, err @@ -28,13 +28,13 @@ SELECT count(1) FROM bar WHERE id > ? AND phone <> ? AND name <> ? ` type CountThreeParams struct { - Name string ID int64 + Name string Phone string } func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countThree, arg.ID, arg.Phone, arg.Name) + row := q.db.QueryRowContext(ctx, countThree, arg.ID, arg.Name, arg.Phone) var count int64 err := row.Scan(&count) return count, err diff --git a/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt b/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt index ebf473a70c..3f07cbb5ef 100644 --- a/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt +++ b/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt @@ -1,5 +1,5 @@ # package querytest query.sql:7:1: function "sqlc.argh" does not exist query.sql:10:45: expected 1 parameter to sqlc.arg; got 2 -query.sql:13:45: expected parameter to sqlc.arg to be string or reference; got *ast.FuncCall -query.sql:16:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg) +query.sql:13:54: Invalid argument to sqlc.arg() +query.sql:16:54: Invalid argument to sqlc.arg() diff --git a/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt b/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt index ebf473a70c..3f07cbb5ef 100644 --- a/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt +++ b/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt @@ -1,5 +1,5 @@ # package querytest query.sql:7:1: function "sqlc.argh" does not exist query.sql:10:45: expected 1 parameter to sqlc.arg; got 2 -query.sql:13:45: expected parameter to sqlc.arg to be string or reference; got *ast.FuncCall -query.sql:16:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg) +query.sql:13:54: Invalid argument to sqlc.arg() +query.sql:16:54: Invalid argument to sqlc.arg() diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index dc85f42dac..178fdefdc1 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,7 +41,7 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } -func NamedParameters(engine config.Engine, raw *ast.RawStmt, argn int) (*ast.RawStmt, map[int]string, []source.Edit) { +func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool) (*ast.RawStmt, map[int]string, []source.Edit) { foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) if len(foundFunc.Items)+len(foundSign.Items) == 0 { @@ -51,11 +51,11 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, argn int) (*ast.Raw hasNamedParameterSupport := engine != config.EngineMySQL args := map[string]int{} + argn := 0 var edits []source.Edit node := astutils.Apply(raw, func(cr *astutils.Cursor) bool { node := cr.Node() switch { - case named.IsParamFunc(node): fun := node.(*ast.FuncCall) param, isConst := flatten(fun.Args) @@ -65,7 +65,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, argn int) (*ast.Raw Location: fun.Location, }) } else { - argn += 1 + argn++ + for numbs[argn] { + argn++ + } args[param] = argn cr.Replace(&ast.ParamRef{ Number: argn, @@ -102,7 +105,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, argn int) (*ast.Raw } cr.Replace(cast) } else { - argn += 1 + argn++ + for numbs[argn] { + argn++ + } args[param] = argn cast.Arg = &ast.ParamRef{ Number: argn, @@ -127,7 +133,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, argn int) (*ast.Raw Location: expr.Location, }) } else { - argn += 1 + argn++ + for numbs[argn] { + argn++ + } args[param] = argn cr.Replace(&ast.ParamRef{ Number: argn, diff --git a/internal/sql/validate/param_ref.go b/internal/sql/validate/param_ref.go index e3332f2981..1175eccb45 100644 --- a/internal/sql/validate/param_ref.go +++ b/internal/sql/validate/param_ref.go @@ -8,7 +8,7 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) -func ParamRef(n ast.Node) (int, error) { +func ParamRef(n ast.Node) (map[int]bool, error) { var allrefs []*ast.ParamRef // Find all parameter references @@ -19,21 +19,19 @@ func ParamRef(n ast.Node) (int, error) { } }), n) - seen := map[int]struct{}{} + seen := map[int]bool{} for _, r := range allrefs { - seen[r.Number] = struct{}{} + if r.Number > 0 { + seen[r.Number] = true + } } - var max int for i := 1; i <= len(seen); i += 1 { - if i > max { - max = i - } if _, ok := seen[i]; !ok { - return 0, &sqlerr.Error{ + return nil, &sqlerr.Error{ Code: "42P18", Message: fmt.Sprintf("could not determine data type of parameter $%d", i), } } } - return max, nil + return seen, nil } diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go index 33c886d467..b943e4f2ec 100644 --- a/internal/sql/validate/param_style.go +++ b/internal/sql/validate/param_style.go @@ -26,14 +26,18 @@ func ParamStyle(n ast.Node) error { */ switch fc.Args.Items[0].(type) { case *ast.FuncCall: + l := fc.Args.Items[0].(*ast.FuncCall) return &sqlerr.Error{ - Code: "", // TODO: Pick a new error code - Message: "expected parameter to sqlc.arg to be string or reference; got *ast.FuncCall", + Code: "", // TODO: Pick a new error code + Message: "Invalid argument to sqlc.arg()", + Location: l.Location, } case *ast.ParamRef: + l := fc.Args.Items[0].(*ast.ParamRef) return &sqlerr.Error{ - Code: "", // TODO: Pick a new error code - Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)", + Code: "", // TODO: Pick a new error code + Message: "Invalid argument to sqlc.arg()", + Location: l.Location, } case *ast.A_Const, *ast.ColumnRef: default: From 614448b73203c0bc3586b1f86d2ff6529461a235 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Wed, 30 Jun 2021 21:44:17 -0700 Subject: [PATCH 07/10] Updates to allow for mysql --- internal/compiler/parse.go | 8 ++++++-- .../testdata/mix_param_types/mysql/go/test.sql.go | 8 ++++---- internal/sql/validate/param_style.go | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 3bac628bd8..e2c8a94c90 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -6,6 +6,7 @@ import ( "sort" "strings" + "github.com/kyleconroy/sqlc/internal/config" "github.com/kyleconroy/sqlc/internal/debug" "github.com/kyleconroy/sqlc/internal/metadata" "github.com/kyleconroy/sqlc/internal/opts" @@ -86,7 +87,11 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, } } else { refs = uniqueParamRefs(refs) - sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + if c.conf.Engine == config.EngineMySQL { + sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Location < refs[j].ref.Location }) + } else { + sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + } } qc, err := buildQueryCatalog(c.catalog, raw.Stmt) if err != nil { @@ -123,7 +128,6 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err != nil { return nil, err } - return &Query{ Cmd: cmd, Comments: comments, diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go index 0fa5b33c52..379da53185 100644 --- a/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go @@ -12,12 +12,12 @@ SELECT count(1) FROM bar WHERE id = ? AND name <> ? ` type CountOneParams struct { - Name string ID int64 + Name string } func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID) + row := q.db.QueryRowContext(ctx, countOne, arg.ID, arg.Name) var count int64 err := row.Scan(&count) return count, err @@ -29,12 +29,12 @@ SELECT count(1) FROM bar WHERE id > ? AND phone <> ? AND name <> ? type CountThreeParams struct { ID int64 - Name string Phone string + Name string } func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countThree, arg.ID, arg.Name, arg.Phone) + row := q.db.QueryRowContext(ctx, countThree, arg.ID, arg.Phone, arg.Name) var count int64 err := row.Scan(&count) return count, err diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go index b943e4f2ec..6bdc18ff86 100644 --- a/internal/sql/validate/param_style.go +++ b/internal/sql/validate/param_style.go @@ -20,7 +20,7 @@ func ParamStyle(n ast.Node) error { if len(fc.Args.Items) != 1 { return &sqlerr.Error{ Code: "", // TODO: Pick a new error code - Message: "Wrong number of arguments to sqlc.arg()", + Message: fmt.Sprintf("expected 1 parameter to sqlc.arg; got %d", len(fc.Args.Items)), } } */ From a5a5ee19a744f7c22a3dba83d1f0b7abb0234b9f Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Thu, 1 Jul 2021 07:28:50 -0700 Subject: [PATCH 08/10] Changes to allow working with ? in postgresql usefull for when you want to port from mysql --- internal/compiler/parse.go | 32 +++++++++++++------ .../mix_param_types/postgresql/go/test.sql.go | 17 ++++++++++ .../mix_param_types/postgresql/test.sql | 3 ++ internal/engine/postgresql/convert.go | 5 +++ internal/sql/ast/param_ref.go | 1 + internal/sql/rewrite/parameters.go | 20 +++++++++--- internal/sql/validate/param_ref.go | 20 +++++++++--- 7 files changed, 80 insertions(+), 18 deletions(-) diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index e2c8a94c90..2c312f0837 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -38,7 +38,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err := validate.ParamStyle(stmt); err != nil { return nil, err } - numbers, err := validate.ParamRef(stmt) + numbers, dollar, err := validate.ParamRef(stmt) if err != nil { return nil, err } @@ -77,7 +77,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } - raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers) + raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar) rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) if o.UsePositionalParameters { @@ -86,8 +86,8 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } } else { - refs = uniqueParamRefs(refs) - if c.conf.Engine == config.EngineMySQL { + refs = uniqueParamRefs(refs, dollar) + if c.conf.Engine == config.EngineMySQL || !dollar { sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Location < refs[j].ref.Location }) } else { sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) @@ -150,13 +150,27 @@ func rangeVars(root ast.Node) []*ast.RangeVar { return vars } -func uniqueParamRefs(in []paramRef) []paramRef { - m := make(map[int]struct{}, len(in)) +func uniqueParamRefs(in []paramRef, dollar bool) []paramRef { + m := make(map[int]bool, len(in)) o := make([]paramRef, 0, len(in)) for _, v := range in { - if _, ok := m[v.ref.Number]; !ok { - m[v.ref.Number] = struct{}{} - o = append(o, v) + if !m[v.ref.Number] { + m[v.ref.Number] = true + if v.ref.Number != 0 { + o = append(o, v) + } + } + } + if !dollar { + start := 1 + for _, v := range in { + if v.ref.Number == 0 { + for m[start] { + start++ + } + v.ref.Number = start + o = append(o, v) + } } } return o diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go index 27c84b6956..057a2ac26e 100644 --- a/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go @@ -7,6 +7,23 @@ import ( "context" ) +const countFour = `-- name: CountFour :one +SELECT count(1) FROM bar WHERE id > ? AND phone <> ? AND name <> ? +` + +type CountFourParams struct { + ID int32 + Phone string + Name string +} + +func (q *Queries) CountFour(ctx context.Context, arg CountFourParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countFour, arg.ID, arg.Phone, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} + const countOne = `-- name: CountOne :one SELECT count(1) FROM bar WHERE id = $2 AND name <> $1 LIMIT $3 ` diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/test.sql b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql index 411f99829f..9ec77fb270 100644 --- a/internal/endtoend/testdata/mix_param_types/postgresql/test.sql +++ b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql @@ -12,3 +12,6 @@ SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name); -- name: CountThree :one SELECT count(1) FROM bar WHERE id > $2 AND phone <> sqlc.arg(phone) AND name <> $1; + +-- name: CountFour :one +SELECT count(1) FROM bar WHERE id > ? AND phone <> sqlc.arg(phone) AND name <> ?; diff --git a/internal/engine/postgresql/convert.go b/internal/engine/postgresql/convert.go index 73f2631826..ddc9225162 100644 --- a/internal/engine/postgresql/convert.go +++ b/internal/engine/postgresql/convert.go @@ -2015,7 +2015,12 @@ func convertParamRef(n *pg.ParamRef) *ast.ParamRef { if n == nil { return nil } + var dollar bool + if n.Number != 0 { + dollar = true + } return &ast.ParamRef{ + Dollar: dollar, Number: int(n.Number), Location: int(n.Location), } diff --git a/internal/sql/ast/param_ref.go b/internal/sql/ast/param_ref.go index cc17dc065d..d0f486cf85 100644 --- a/internal/sql/ast/param_ref.go +++ b/internal/sql/ast/param_ref.go @@ -3,6 +3,7 @@ package ast type ParamRef struct { Number int Location int + Dollar bool } func (n *ParamRef) Pos() int { diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index 178fdefdc1..446e81abba 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,7 +41,7 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } -func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool) (*ast.RawStmt, map[int]string, []source.Edit) { +func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, map[int]string, []source.Edit) { foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) if len(foundFunc.Items)+len(foundSign.Items) == 0 { @@ -82,7 +82,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool) } else { old = fmt.Sprintf("sqlc.arg(%s)", param) } - if engine == config.EngineMySQL { + if engine == config.EngineMySQL || !dollar { replace = "?" } else { replace = fmt.Sprintf("$%d", args[param]) @@ -117,10 +117,16 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool) cr.Replace(cast) } // TODO: This code assumes that @foo::bool is on a single line + var replace string + if engine == config.EngineMySQL || !dollar { + replace = "?" + } else { + replace = fmt.Sprintf("$%d", args[param]) + } edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, Old: fmt.Sprintf("@%s", param), - New: fmt.Sprintf("$%d", args[param]), + New: replace, }) return false @@ -144,10 +150,16 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool) }) } // TODO: This code assumes that @foo is on a single line + var replace string + if engine == config.EngineMySQL || !dollar { + replace = "?" + } else { + replace = fmt.Sprintf("$%d", args[param]) + } edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, Old: fmt.Sprintf("@%s", param), - New: fmt.Sprintf("$%d", args[param]), + New: replace, }) return false diff --git a/internal/sql/validate/param_ref.go b/internal/sql/validate/param_ref.go index 1175eccb45..85becae81b 100644 --- a/internal/sql/validate/param_ref.go +++ b/internal/sql/validate/param_ref.go @@ -1,23 +1,33 @@ package validate import ( + "errors" "fmt" - "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/astutils" "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) -func ParamRef(n ast.Node) (map[int]bool, error) { +func ParamRef(n ast.Node) (map[int]bool, bool, error) { var allrefs []*ast.ParamRef - + var dollar bool + var nodollar bool // Find all parameter references astutils.Walk(astutils.VisitorFunc(func(node ast.Node) { switch n := node.(type) { case *ast.ParamRef: + ref := node.(*ast.ParamRef) + if ref.Dollar { + dollar = true + } else { + nodollar = true + } allrefs = append(allrefs, n) } }), n) + if dollar && nodollar { + return nil, false, errors.New("Can not mix $1 format with ? format") + } seen := map[int]bool{} for _, r := range allrefs { @@ -27,11 +37,11 @@ func ParamRef(n ast.Node) (map[int]bool, error) { } for i := 1; i <= len(seen); i += 1 { if _, ok := seen[i]; !ok { - return nil, &sqlerr.Error{ + return nil, false, &sqlerr.Error{ Code: "42P18", Message: fmt.Sprintf("could not determine data type of parameter $%d", i), } } } - return seen, nil + return seen, dollar, nil } From 0de4cab192af57349d6a6efb8ec88dda13133d8c Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Thu, 1 Jul 2021 22:08:12 -0700 Subject: [PATCH 09/10] Some fixes to ensure $1 format is used when named params are used --- internal/engine/postgresql/rewrite_test.go | 1 + internal/sql/validate/param_ref.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/engine/postgresql/rewrite_test.go b/internal/engine/postgresql/rewrite_test.go index f416d70750..0aa20c565d 100644 --- a/internal/engine/postgresql/rewrite_test.go +++ b/internal/engine/postgresql/rewrite_test.go @@ -30,6 +30,7 @@ func TestApply(t *testing.T) { } if astutils.Join(fun.Funcname, ".") == "sqlc.arg" { cr.Replace(&ast.ParamRef{ + Dollar: true, Number: 1, Location: fun.Location, }) diff --git a/internal/sql/validate/param_ref.go b/internal/sql/validate/param_ref.go index 85becae81b..fbec8f9066 100644 --- a/internal/sql/validate/param_ref.go +++ b/internal/sql/validate/param_ref.go @@ -43,5 +43,5 @@ func ParamRef(n ast.Node) (map[int]bool, bool, error) { } } } - return seen, dollar, nil + return seen, !nodollar, nil } From 33df90d5d36c7eceb6b85164ec825baa8ff81325 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sun, 12 Sep 2021 17:52:03 -0700 Subject: [PATCH 10/10] Changes per pull request comments --- internal/sql/validate/param_style.go | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go index 6bdc18ff86..5e89601e03 100644 --- a/internal/sql/validate/param_style.go +++ b/internal/sql/validate/param_style.go @@ -16,28 +16,18 @@ func ParamStyle(n ast.Node) error { for _, f := range namedFunc.Items { fc, ok := f.(*ast.FuncCall) if ok { - /* - if len(fc.Args.Items) != 1 { - return &sqlerr.Error{ - Code: "", // TODO: Pick a new error code - Message: fmt.Sprintf("expected 1 parameter to sqlc.arg; got %d", len(fc.Args.Items)), - } - } - */ - switch fc.Args.Items[0].(type) { + switch val := fc.Args.Items[0].(type) { case *ast.FuncCall: - l := fc.Args.Items[0].(*ast.FuncCall) return &sqlerr.Error{ Code: "", // TODO: Pick a new error code Message: "Invalid argument to sqlc.arg()", - Location: l.Location, + Location: val.Location, } case *ast.ParamRef: - l := fc.Args.Items[0].(*ast.ParamRef) return &sqlerr.Error{ Code: "", // TODO: Pick a new error code Message: "Invalid argument to sqlc.arg()", - Location: l.Location, + Location: val.Location, } case *ast.A_Const, *ast.ColumnRef: default: