diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 551bd08515..7d796b33c4 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -34,10 +34,8 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if o.Debug.DumpAST { debug.Dump(stmt) } - if err := validate.ParamStyle(stmt); err != nil { - return nil, err - } - if err := validate.ParamRef(stmt); err != nil { + lastNumber, err := validate.ParamRef(stmt) + if err != nil { return nil, err } raw, ok := stmt.(*ast.RawStmt) @@ -75,7 +73,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } - raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw) + raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, lastNumber) rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) if o.UsePositionalParameters { diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/db.go b/internal/endtoend/testdata/mix_param_types/mysql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/models.go b/internal/endtoend/testdata/mix_param_types/mysql/go/models.go new file mode 100644 index 0000000000..c66db5f3b2 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type Bar struct { + ID int32 + Name string + Phone string +} diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go new file mode 100644 index 0000000000..cef2361690 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go @@ -0,0 +1,57 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: test.sql + +package querytest + +import ( + "context" +) + +const countOne = `-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = $2 AND name <> $1 +` + +type CountOneParams struct { + Name string + ID int32 +} + +func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countThree = `-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> $3 AND name <> $1 +` + +type CountThreeParams struct { + Name string + ID int32 + Phone string +} + +func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countThree, arg.Name, arg.ID, arg.Phone) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countTwo = `-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $1 AND name <> $2 +` + +type CountTwoParams struct { + ID int32 + Name string +} + +func (q *Queries) CountTwo(ctx context.Context, arg CountTwoParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countTwo, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} diff --git a/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json b/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json new file mode 100644 index 0000000000..145f64ba3f --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "querytest", + "schema": "test.sql", + "queries": "test.sql", + "engine": "mysql" + } + ] +} diff --git a/internal/endtoend/testdata/mix_param_types/mysql/test.sql b/internal/endtoend/testdata/mix_param_types/mysql/test.sql new file mode 100644 index 0000000000..0bed5e4381 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/test.sql @@ -0,0 +1,14 @@ +CREATE TABLE bar ( + id serial not null, + name text not null, + phone text not null +); + +-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> $1; + +-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name); + +-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> sqlc.arg(phone) AND name <> $1; diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/db.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/models.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/models.go new file mode 100644 index 0000000000..c66db5f3b2 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type Bar struct { + ID int32 + Name string + Phone string +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go new file mode 100644 index 0000000000..27c84b6956 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go @@ -0,0 +1,58 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: test.sql + +package querytest + +import ( + "context" +) + +const countOne = `-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = $2 AND name <> $1 LIMIT $3 +` + +type CountOneParams struct { + Name string + ID int32 + Limit int32 +} + +func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID, arg.Limit) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countThree = `-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> $3 AND name <> $1 +` + +type CountThreeParams struct { + Name string + ID int32 + Phone string +} + +func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countThree, arg.Name, arg.ID, arg.Phone) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countTwo = `-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $1 AND name <> $2 +` + +type CountTwoParams struct { + ID int32 + Name string +} + +func (q *Queries) CountTwo(ctx context.Context, arg CountTwoParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countTwo, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/sqlc.json b/internal/endtoend/testdata/mix_param_types/postgresql/sqlc.json new file mode 100644 index 0000000000..dfd2f59a26 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "querytest", + "schema": "test.sql", + "queries": "test.sql", + "engine": "postgresql" + } + ] +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/test.sql b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql new file mode 100644 index 0000000000..411f99829f --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql @@ -0,0 +1,14 @@ +CREATE TABLE bar ( + id serial not null, + name text not null, + phone text not null +); + +-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> $1 LIMIT sqlc.arg('limit'); + +-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name); + +-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> sqlc.arg(phone) AND name <> $1; diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index 7c79e9b6d9..dc85f42dac 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,7 +41,7 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } -func NamedParameters(engine config.Engine, raw *ast.RawStmt) (*ast.RawStmt, map[int]string, []source.Edit) { +func NamedParameters(engine config.Engine, raw *ast.RawStmt, argn int) (*ast.RawStmt, map[int]string, []source.Edit) { foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) if len(foundFunc.Items)+len(foundSign.Items) == 0 { @@ -51,7 +51,6 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt) (*ast.RawStmt, map[ hasNamedParameterSupport := engine != config.EngineMySQL args := map[string]int{} - argn := 0 var edits []source.Edit node := astutils.Apply(raw, func(cr *astutils.Cursor) bool { node := cr.Node() diff --git a/internal/sql/validate/param_ref.go b/internal/sql/validate/param_ref.go index eacdb64fad..e3332f2981 100644 --- a/internal/sql/validate/param_ref.go +++ b/internal/sql/validate/param_ref.go @@ -8,7 +8,7 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) -func ParamRef(n ast.Node) error { +func ParamRef(n ast.Node) (int, error) { var allrefs []*ast.ParamRef // Find all parameter references @@ -23,14 +23,17 @@ func ParamRef(n ast.Node) error { for _, r := range allrefs { seen[r.Number] = struct{}{} } - + var max int for i := 1; i <= len(seen); i += 1 { + if i > max { + max = i + } if _, ok := seen[i]; !ok { - return &sqlerr.Error{ + return 0, &sqlerr.Error{ Code: "42P18", Message: fmt.Sprintf("could not determine data type of parameter $%d", i), } } } - return nil + return max, nil } diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go deleted file mode 100644 index 870f460f66..0000000000 --- a/internal/sql/validate/param_style.go +++ /dev/null @@ -1,34 +0,0 @@ -package validate - -import ( - "github.com/kyleconroy/sqlc/internal/sql/ast" - "github.com/kyleconroy/sqlc/internal/sql/astutils" - "github.com/kyleconroy/sqlc/internal/sql/named" - "github.com/kyleconroy/sqlc/internal/sql/sqlerr" -) - -// A query can use one (and only one) of the following formats: -// - positional parameters $1 -// - named parameter operator @param -// - named parameter function calls sqlc.arg(param) -func ParamStyle(n ast.Node) error { - positional := astutils.Search(n, func(node ast.Node) bool { - _, ok := node.(*ast.ParamRef) - return ok - }) - namedFunc := astutils.Search(n, named.IsParamFunc) - namedSign := astutils.Search(n, named.IsParamSign) - for _, check := range []bool{ - len(positional.Items) > 0 && len(namedSign.Items)+len(namedFunc.Items) > 0, - len(namedFunc.Items) > 0 && len(namedSign.Items)+len(positional.Items) > 0, - len(namedSign.Items) > 0 && len(positional.Items)+len(namedFunc.Items) > 0, - } { - if check { - return &sqlerr.Error{ - Code: "", // TODO: Pick a new error code - Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)", - } - } - } - return nil -}