diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index b8329b6f5e..e2c8a94c90 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -6,6 +6,7 @@ import ( "sort" "strings" + "github.com/kyleconroy/sqlc/internal/config" "github.com/kyleconroy/sqlc/internal/debug" "github.com/kyleconroy/sqlc/internal/metadata" "github.com/kyleconroy/sqlc/internal/opts" @@ -37,7 +38,8 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err := validate.ParamStyle(stmt); err != nil { return nil, err } - if err := validate.ParamRef(stmt); err != nil { + numbers, err := validate.ParamRef(stmt) + if err != nil { return nil, err } raw, ok := stmt.(*ast.RawStmt) @@ -75,7 +77,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } - raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw) + raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers) rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) if o.UsePositionalParameters { @@ -85,7 +87,11 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, } } else { refs = uniqueParamRefs(refs) - sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + if c.conf.Engine == config.EngineMySQL { + sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Location < refs[j].ref.Location }) + } else { + sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + } } qc, err := buildQueryCatalog(c.catalog, raw.Stmt) if err != nil { @@ -122,7 +128,6 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err != nil { return nil, err } - return &Query{ Cmd: cmd, Comments: comments, diff --git a/internal/endtoend/testdata/invalid_params/pgx/stderr.txt b/internal/endtoend/testdata/invalid_params/pgx/stderr.txt index f338b8d756..722c3e2408 100644 --- a/internal/endtoend/testdata/invalid_params/pgx/stderr.txt +++ b/internal/endtoend/testdata/invalid_params/pgx/stderr.txt @@ -2,4 +2,4 @@ query.sql:4:1: could not determine data type of parameter $1 query.sql:7:1: could not determine data type of parameter $2 query.sql:10:8: column "foo" does not exist -query.sql:13:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg) +query.sql:13:1: could not determine data type of parameter $2 diff --git a/internal/endtoend/testdata/invalid_params/stdlib/stderr.txt b/internal/endtoend/testdata/invalid_params/stdlib/stderr.txt index f338b8d756..722c3e2408 100644 --- a/internal/endtoend/testdata/invalid_params/stdlib/stderr.txt +++ b/internal/endtoend/testdata/invalid_params/stdlib/stderr.txt @@ -2,4 +2,4 @@ query.sql:4:1: could not determine data type of parameter $1 query.sql:7:1: could not determine data type of parameter $2 query.sql:10:8: column "foo" does not exist -query.sql:13:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg) +query.sql:13:1: could not determine data type of parameter $2 diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/db.go b/internal/endtoend/testdata/mix_param_types/mysql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/models.go b/internal/endtoend/testdata/mix_param_types/mysql/go/models.go new file mode 100644 index 0000000000..b10bc44571 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type Bar struct { + ID int64 + Name string + Phone string +} diff --git a/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go new file mode 100644 index 0000000000..379da53185 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go @@ -0,0 +1,57 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: test.sql + +package querytest + +import ( + "context" +) + +const countOne = `-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = ? AND name <> ? +` + +type CountOneParams struct { + ID int64 + Name string +} + +func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countOne, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countThree = `-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > ? AND phone <> ? AND name <> ? +` + +type CountThreeParams struct { + ID int64 + Phone string + Name string +} + +func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countThree, arg.ID, arg.Phone, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countTwo = `-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = ? AND name <> ? +` + +type CountTwoParams struct { + ID int64 + Name string +} + +func (q *Queries) CountTwo(ctx context.Context, arg CountTwoParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countTwo, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} diff --git a/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json b/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json new file mode 100644 index 0000000000..145f64ba3f --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "querytest", + "schema": "test.sql", + "queries": "test.sql", + "engine": "mysql" + } + ] +} diff --git a/internal/endtoend/testdata/mix_param_types/mysql/test.sql b/internal/endtoend/testdata/mix_param_types/mysql/test.sql new file mode 100644 index 0000000000..b624d3e2ea --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/mysql/test.sql @@ -0,0 +1,14 @@ +CREATE TABLE bar ( + id serial not null, + name text not null, + phone text not null +); + +-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> ?; + +-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = ? AND name <> sqlc.arg(name); + +-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > ? AND phone <> sqlc.arg(phone) AND name <> ?; diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/db.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/models.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/models.go new file mode 100644 index 0000000000..c66db5f3b2 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type Bar struct { + ID int32 + Name string + Phone string +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go new file mode 100644 index 0000000000..27c84b6956 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go @@ -0,0 +1,58 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: test.sql + +package querytest + +import ( + "context" +) + +const countOne = `-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = $2 AND name <> $1 LIMIT $3 +` + +type CountOneParams struct { + Name string + ID int32 + Limit int32 +} + +func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countOne, arg.Name, arg.ID, arg.Limit) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countThree = `-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> $3 AND name <> $1 +` + +type CountThreeParams struct { + Name string + ID int32 + Phone string +} + +func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countThree, arg.Name, arg.ID, arg.Phone) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countTwo = `-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $1 AND name <> $2 +` + +type CountTwoParams struct { + ID int32 + Name string +} + +func (q *Queries) CountTwo(ctx context.Context, arg CountTwoParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countTwo, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/sqlc.json b/internal/endtoend/testdata/mix_param_types/postgresql/sqlc.json new file mode 100644 index 0000000000..dfd2f59a26 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "querytest", + "schema": "test.sql", + "queries": "test.sql", + "engine": "postgresql" + } + ] +} diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/test.sql b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql new file mode 100644 index 0000000000..411f99829f --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql @@ -0,0 +1,14 @@ +CREATE TABLE bar ( + id serial not null, + name text not null, + phone text not null +); + +-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> $1 LIMIT sqlc.arg('limit'); + +-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name); + +-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > $2 AND phone <> sqlc.arg(phone) AND name <> $1; diff --git a/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt b/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt index ebf473a70c..3f07cbb5ef 100644 --- a/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt +++ b/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt @@ -1,5 +1,5 @@ # package querytest query.sql:7:1: function "sqlc.argh" does not exist query.sql:10:45: expected 1 parameter to sqlc.arg; got 2 -query.sql:13:45: expected parameter to sqlc.arg to be string or reference; got *ast.FuncCall -query.sql:16:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg) +query.sql:13:54: Invalid argument to sqlc.arg() +query.sql:16:54: Invalid argument to sqlc.arg() diff --git a/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt b/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt index ebf473a70c..3f07cbb5ef 100644 --- a/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt +++ b/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt @@ -1,5 +1,5 @@ # package querytest query.sql:7:1: function "sqlc.argh" does not exist query.sql:10:45: expected 1 parameter to sqlc.arg; got 2 -query.sql:13:45: expected parameter to sqlc.arg to be string or reference; got *ast.FuncCall -query.sql:16:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg) +query.sql:13:54: Invalid argument to sqlc.arg() +query.sql:16:54: Invalid argument to sqlc.arg() diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index 7c79e9b6d9..178fdefdc1 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,7 +41,7 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } -func NamedParameters(engine config.Engine, raw *ast.RawStmt) (*ast.RawStmt, map[int]string, []source.Edit) { +func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool) (*ast.RawStmt, map[int]string, []source.Edit) { foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) if len(foundFunc.Items)+len(foundSign.Items) == 0 { @@ -56,7 +56,6 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt) (*ast.RawStmt, map[ node := astutils.Apply(raw, func(cr *astutils.Cursor) bool { node := cr.Node() switch { - case named.IsParamFunc(node): fun := node.(*ast.FuncCall) param, isConst := flatten(fun.Args) @@ -66,7 +65,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt) (*ast.RawStmt, map[ Location: fun.Location, }) } else { - argn += 1 + argn++ + for numbs[argn] { + argn++ + } args[param] = argn cr.Replace(&ast.ParamRef{ Number: argn, @@ -103,7 +105,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt) (*ast.RawStmt, map[ } cr.Replace(cast) } else { - argn += 1 + argn++ + for numbs[argn] { + argn++ + } args[param] = argn cast.Arg = &ast.ParamRef{ Number: argn, @@ -128,7 +133,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt) (*ast.RawStmt, map[ Location: expr.Location, }) } else { - argn += 1 + argn++ + for numbs[argn] { + argn++ + } args[param] = argn cr.Replace(&ast.ParamRef{ Number: argn, diff --git a/internal/sql/validate/param_ref.go b/internal/sql/validate/param_ref.go index eacdb64fad..1175eccb45 100644 --- a/internal/sql/validate/param_ref.go +++ b/internal/sql/validate/param_ref.go @@ -8,7 +8,7 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) -func ParamRef(n ast.Node) error { +func ParamRef(n ast.Node) (map[int]bool, error) { var allrefs []*ast.ParamRef // Find all parameter references @@ -19,18 +19,19 @@ func ParamRef(n ast.Node) error { } }), n) - seen := map[int]struct{}{} + seen := map[int]bool{} for _, r := range allrefs { - seen[r.Number] = struct{}{} + if r.Number > 0 { + seen[r.Number] = true + } } - for i := 1; i <= len(seen); i += 1 { if _, ok := seen[i]; !ok { - return &sqlerr.Error{ + return nil, &sqlerr.Error{ Code: "42P18", Message: fmt.Sprintf("could not determine data type of parameter $%d", i), } } } - return nil + return seen, nil } diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go index 870f460f66..6bdc18ff86 100644 --- a/internal/sql/validate/param_style.go +++ b/internal/sql/validate/param_style.go @@ -12,21 +12,40 @@ import ( // - named parameter operator @param // - named parameter function calls sqlc.arg(param) func ParamStyle(n ast.Node) error { - positional := astutils.Search(n, func(node ast.Node) bool { - _, ok := node.(*ast.ParamRef) - return ok - }) namedFunc := astutils.Search(n, named.IsParamFunc) - namedSign := astutils.Search(n, named.IsParamSign) - for _, check := range []bool{ - len(positional.Items) > 0 && len(namedSign.Items)+len(namedFunc.Items) > 0, - len(namedFunc.Items) > 0 && len(namedSign.Items)+len(positional.Items) > 0, - len(namedSign.Items) > 0 && len(positional.Items)+len(namedFunc.Items) > 0, - } { - if check { - return &sqlerr.Error{ - Code: "", // TODO: Pick a new error code - Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)", + for _, f := range namedFunc.Items { + fc, ok := f.(*ast.FuncCall) + if ok { + /* + if len(fc.Args.Items) != 1 { + return &sqlerr.Error{ + Code: "", // TODO: Pick a new error code + Message: fmt.Sprintf("expected 1 parameter to sqlc.arg; got %d", len(fc.Args.Items)), + } + } + */ + switch fc.Args.Items[0].(type) { + case *ast.FuncCall: + l := fc.Args.Items[0].(*ast.FuncCall) + return &sqlerr.Error{ + Code: "", // TODO: Pick a new error code + Message: "Invalid argument to sqlc.arg()", + Location: l.Location, + } + case *ast.ParamRef: + l := fc.Args.Items[0].(*ast.ParamRef) + return &sqlerr.Error{ + Code: "", // TODO: Pick a new error code + Message: "Invalid argument to sqlc.arg()", + Location: l.Location, + } + case *ast.A_Const, *ast.ColumnRef: + default: + return &sqlerr.Error{ + Code: "", // TODO: Pick a new error code + Message: "Invalid argument to sqlc.arg()", + } + } } }