From a59d9512af60404c0225dbea81391b43654236b8 Mon Sep 17 00:00:00 2001 From: Paul Cameron Date: Thu, 23 Sep 2021 01:19:41 +1000 Subject: [PATCH] Add sqlc.slice() new function type (#695) This feature (currently MySQL-specific) allows passing in a slice to an IN clause. Adding the new function sqlc.slice() as opposed to overloading the parsing of "IN (?)" was chosen to guarantee backwards compatibility. SELECT * FROM tab WHERE col IN (sqlc.slice("go_param_name")) The MySQL FLOAT datatype mapping has been added too. --- docs/howto/select.md | 110 ++++++++++++++ internal/codegen/golang/field.go | 6 + internal/codegen/golang/gen.go | 98 ++++++++++--- internal/codegen/golang/go_type.go | 2 +- internal/codegen/golang/imports.go | 14 +- internal/codegen/golang/mysql_type.go | 6 + internal/codegen/golang/query.go | 23 ++- internal/codegen/golang/result.go | 8 +- .../golang/templates/stdlib/queryCode.tmpl | 110 +++++++------- internal/compiler/parse.go | 4 +- internal/compiler/query.go | 9 ++ internal/compiler/resolve.go | 15 +- internal/endtoend/endtoend_test.go | 3 + .../testdata/sqlc_slice/mysql/go/db.go | 29 ++++ .../testdata/sqlc_slice/mysql/go/models.go | 10 ++ .../testdata/sqlc_slice/mysql/go/query.sql.go | 135 ++++++++++++++++++ .../testdata/sqlc_slice/mysql/query.sql | 15 ++ .../testdata/sqlc_slice/mysql/sqlc.json | 12 ++ .../sqlc_slice/postgresql/pgx/go/db.go | 30 ++++ .../sqlc_slice/postgresql/pgx/go/models.go | 9 ++ .../sqlc_slice/postgresql/pgx/go/query.sql.go | 56 ++++++++ .../sqlc_slice/postgresql/pgx/query.sql | 7 + .../sqlc_slice/postgresql/pgx/sqlc.json | 13 ++ .../sqlc_slice/postgresql/stdlib/go/db.go | 29 ++++ .../sqlc_slice/postgresql/stdlib/go/models.go | 10 ++ .../postgresql/stdlib/go/query.sql.go | 76 ++++++++++ .../sqlc_slice/postgresql/stdlib/query.sql | 13 ++ .../sqlc_slice/postgresql/stdlib/sqlc.json | 12 ++ internal/sql/named/is.go | 2 +- internal/sql/rewrite/parameters.go | 38 +++-- internal/sql/validate/func_call.go | 8 +- internal/sql/validate/in.go | 86 +++++++++++ 32 files changed, 892 insertions(+), 106 deletions(-) create mode 100644 internal/endtoend/testdata/sqlc_slice/mysql/go/db.go create mode 100644 internal/endtoend/testdata/sqlc_slice/mysql/go/models.go create mode 100644 internal/endtoend/testdata/sqlc_slice/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/sqlc_slice/mysql/query.sql create mode 100644 internal/endtoend/testdata/sqlc_slice/mysql/sqlc.json create mode 100644 internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/db.go create mode 100644 internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/models.go create mode 100644 internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/query.sql.go create mode 100644 internal/endtoend/testdata/sqlc_slice/postgresql/pgx/query.sql create mode 100644 internal/endtoend/testdata/sqlc_slice/postgresql/pgx/sqlc.json create mode 100644 internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/db.go create mode 100644 internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/models.go create mode 100644 internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/query.sql.go create mode 100644 internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/query.sql create mode 100644 internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/sqlc.json create mode 100644 internal/sql/validate/in.go diff --git a/docs/howto/select.md b/docs/howto/select.md index 41c72c7885..097f951d57 100644 --- a/docs/howto/select.md +++ b/docs/howto/select.md @@ -185,6 +185,8 @@ func (q *Queries) GetInfoForAuthor(ctx context.Context, id int) (GetInfoForAutho ## Passing a slice as a parameter to a query +### PostgreSQL + In PostgreSQL, [ANY](https://www.postgresql.org/docs/current/functions-comparisons.html#id-1.5.8.28.16) allows you to check if a value exists in an array expression. Queries using ANY @@ -262,3 +264,111 @@ func (q *Queries) ListAuthorsByIDs(ctx context.Context, ids []int) ([]Author, er return items, nil } ``` + +### MySQL + +MySQL differs from PostgreSQL in that placeholders must be generated based on +the number of elements in the slice you pass in. Though trivial it is still +something of a nuisance. The passed in slice must not be nil or empty or an +error will be returned (ie not a panic). The placeholder insertion location is +marked by the meta-function `sqlc.slice()` (which is similar to `sqlc.arg()` +that you see documented under [Naming parameters](named_parameters.md)). + +To rephrase, the `sqlc.slice('param')` behaves identically to `sqlc.arg()` it +terms of how it maps the explicit argument to the function signature, eg: + + * `sqlc.slice('ids')` maps to `ids []GoType` in the function signature + * `sqlc.slice(cust_ids)` maps to `custIds []GoType` in the function signature + (like `sqlc.arg()`, the parameter does not have to be quoted) + +This feature is not compatible with `emit_prepared_queries` statement found in the +[Configuration file](../reference/config.md). + +```sql +CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + bio text NOT NULL, + birth_year int NOT NULL +); + +-- name: ListAuthorsByIDs :many +SELECT * FROM authors +WHERE id IN (sqlc.slice('ids')); +``` + +The above SQL will generate the following code: + +```go +package db + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +type Author struct { + ID int + Bio string + BirthYear int +} + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} + +const listAuthorsByIDs = `-- name: ListAuthorsByIDs :many +SELECT id, bio, birth_year FROM authors +WHERE id IN (/*REPLACE:ids*/?) +` + +func (q *Queries) ListAuthorsByIDs(ctx context.Context, ids []int64) ([]Author, error) { + sql := listAuthorsByIDs + var queryParams []interface{} + if len(ids) == 0 { + return nil, fmt.Errorf("slice ids must have at least one element") + } + for _, v := range ids { + queryParams = append(queryParams, v) + } + sql = strings.Replace(sql, "/*REPLACE:ids*/?", strings.Repeat(",?", len(ids))[1:], 1) + rows, err := q.db.QueryContext(ctx, sql, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Bio, &i.BirthYear); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} +``` \ No newline at end of file diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index e036b0e041..f9ce534b1b 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -5,6 +5,7 @@ import ( "sort" "strings" + "github.com/kyleconroy/sqlc/internal/compiler" "github.com/kyleconroy/sqlc/internal/config" ) @@ -13,6 +14,7 @@ type Field struct { Type string Tags map[string]string Comment string + Column *compiler.Column } func (gf Field) Tag() string { @@ -27,6 +29,10 @@ func (gf Field) Tag() string { return strings.Join(tags, " ") } +func (gf Field) HasSlice() bool { + return gf.Column.IsSlice +} + func JSONTagName(name string, settings config.CombinedSettings) string { style := settings.Go.JSONTagsCaseStyle if style == "" || style == "none" { diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index ca548e3e58..c0eb7913b6 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -43,6 +43,63 @@ func (t *tmplCtx) OutputQuery(sourceName string) bool { return t.SourceName == sourceName } +func (t *tmplCtx) codegenDbarg() string { + if t.EmitMethodsWithDBArgument { + return "db DBTX, " + } + return "" +} + +// Called as a global method since subtemplate queryCodeStdExec does not have +// access to the toplevel tmplCtx +func (t *tmplCtx) codegenEmitPreparedQueries() bool { + return t.EmitPreparedQueries +} + +func (t *tmplCtx) codegenQueryMethod(q Query) string { + db := "q.db" + if t.EmitMethodsWithDBArgument { + db = "db" + } + + switch q.Cmd { + case ":one": + if t.EmitPreparedQueries { + return "q.queryRow" + } + return db + ".QueryRowContext" + + case ":many": + if t.EmitPreparedQueries { + return "q.query" + } + return db + ".QueryContext" + + default: + if t.EmitPreparedQueries { + return "q.exec" + } + return db + ".ExecContext" + } +} + +func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) { + switch q.Cmd { + case ":one": + return "row :=", nil + case ":many": + return "rows, err :=", nil + case ":exec": + return "_, err :=", nil + case ":execrows": + return "result, err :=", nil + case ":execresult": + return "return", nil + default: + return "", fmt.Errorf("unhandled q.Cmd case %q", q.Cmd) + } +} + func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string]string, error) { enums := buildEnums(r, settings) structs := buildStructs(r, settings) @@ -61,23 +118,6 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, Structs: structs, } - funcMap := template.FuncMap{ - "lowerTitle": codegen.LowerTitle, - "comment": codegen.DoubleSlashComment, - "escape": codegen.EscapeBacktick, - "imports": i.Imports, - } - - tmpl := template.Must( - template.New("table"). - Funcs(funcMap). - ParseFS( - templates, - "templates/*.tmpl", - "templates/*/*.tmpl", - ), - ) - golang := settings.Go tctx := tmplCtx{ Settings: settings.Global, @@ -95,6 +135,30 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, Structs: structs, } + funcMap := template.FuncMap{ + "lowerTitle": codegen.LowerTitle, + "comment": codegen.DoubleSlashComment, + "escape": codegen.EscapeBacktick, + "imports": i.Imports, + + // These methods are Go specific, they do not belong in the codegen package + // (as that is language independent) + "dbarg": tctx.codegenDbarg, + "emitPreparedQueries": tctx.codegenEmitPreparedQueries, + "queryMethod": tctx.codegenQueryMethod, + "queryRetval": tctx.codegenQueryRetval, + } + + tmpl := template.Must( + template.New("table"). + Funcs(funcMap). + ParseFS( + templates, + "templates/*.tmpl", + "templates/*/*.tmpl", + ), + ) + output := map[string]string{} execute := func(name, templateName string) error { diff --git a/internal/codegen/golang/go_type.go b/internal/codegen/golang/go_type.go index 718a8dca1c..61fb87c5b5 100644 --- a/internal/codegen/golang/go_type.go +++ b/internal/codegen/golang/go_type.go @@ -17,7 +17,7 @@ func goType(r *compiler.Result, col *compiler.Column, settings config.CombinedSe } } typ := goInnerType(r, col, settings) - if col.IsArray { + if col.IsArray || col.IsSlice { return "[]" + typ } return typ diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 2c38a3c48e..0e3d780945 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -349,10 +349,22 @@ func (i *importer) queryImports(filename string) fileImports { return false } + mysqlSliceScan := func() bool { + for _, q := range gq { + if q.Arg.HasSlices() { + return true + } + } + return false + } + std["context"] = struct{}{} sqlpkg := SQLPackageFromString(i.Settings.Go.SQLPackage) - if sliceScan() && sqlpkg != SQLPackagePGX { + if mysqlSliceScan() { + std["fmt"] = struct{}{} + std["strings"] = struct{}{} + } else if sliceScan() && sqlpkg != SQLPackagePGX { pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } diff --git a/internal/codegen/golang/mysql_type.go b/internal/codegen/golang/mysql_type.go index 4a56d64f23..8fa919fba0 100644 --- a/internal/codegen/golang/mysql_type.go +++ b/internal/codegen/golang/mysql_type.go @@ -55,6 +55,12 @@ func mysqlType(r *compiler.Result, col *compiler.Column, settings config.Combine } return "sql.NullFloat64" + case "float": + if notNull { + return "float32" + } + return "sql.NullFloat64" + case "decimal", "dec", "fixed": if notNull { return "string" diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 171125adc0..4b24d72974 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -3,6 +3,7 @@ package golang import ( "strings" + "github.com/kyleconroy/sqlc/internal/compiler" "github.com/kyleconroy/sqlc/internal/metadata" ) @@ -13,6 +14,10 @@ type QueryValue struct { Struct *Struct Typ string SQLPackage SQLPackage + + // Column is kept so late in the generation process around to differentiate + // between mysql slices and pg arrays + Column *compiler.Column } func (v QueryValue) EmitStruct() bool { @@ -84,14 +89,14 @@ func (v QueryValue) Params() string { } var out []string if v.Struct == nil { - if strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && v.SQLPackage != SQLPackagePGX { + if !v.Column.IsSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && v.SQLPackage != SQLPackagePGX { out = append(out, "pq.Array("+v.Name+")") } else { out = append(out, v.Name) } } else { for _, f := range v.Struct.Fields { - if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && v.SQLPackage != SQLPackagePGX { + if !f.HasSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && v.SQLPackage != SQLPackagePGX { out = append(out, "pq.Array("+v.Name+"."+f.Name+")") } else { out = append(out, v.Name+"."+f.Name) @@ -105,6 +110,20 @@ func (v QueryValue) Params() string { return "\n" + strings.Join(out, ",\n") } +// When true, we have to build the arguments to q.db.QueryContext in addition to +// munging the SQL +func (v QueryValue) HasSlices() bool { + if v.Struct == nil { + return v.Column != nil && v.Column.IsSlice + } + for _, v := range v.Struct.Fields { + if v.Column.IsSlice { + return true + } + } + return false +} + func (v QueryValue) Scan() string { var out []string if v.Struct == nil { diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index b213826eb9..79e4a88522 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -169,6 +169,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs Name: paramName(p), Typ: goType(r, p.Column, settings), SQLPackage: sqlpkg, + Column: p.Column, } } else if len(query.Params) > 1 { var cols []goColumn @@ -291,9 +292,10 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin tags["json:"] = JSONTagName(tagName, settings) } gs.Fields = append(gs.Fields, Field{ - Name: fieldName, - Type: goType(r, c.Column, settings), - Tags: tags, + Name: fieldName, + Type: goType(r, c.Column, settings), + Tags: tags, + Column: c.Column, }) if _, found := seen[baseFieldName]; !found { seen[baseFieldName] = []int{i} diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index 421cf7958f..ea1309be64 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -22,18 +22,8 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{if eq .Cmd ":one"}} {{range .Comments}}//{{.}} {{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { -{{- end -}} - {{- if $.EmitPreparedQueries}} - row := q.queryRow(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else if $.EmitMethodsWithDBArgument}} - row := db.QueryRowContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} - row := q.db.QueryRowContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- end}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { + {{- template "queryCodeStdExec" . }} {{- if ne .Arg.Pair .Ret.Pair }} var {{.Ret.Name}} {{.Ret.Type}} {{- end}} @@ -45,18 +35,8 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De {{if eq .Cmd ":many"}} {{range .Comments}}//{{.}} {{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { -{{- end -}} - {{- if $.EmitPreparedQueries}} - rows, err := q.query(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else if $.EmitMethodsWithDBArgument}} - rows, err := db.QueryContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} - rows, err := q.db.QueryContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- end}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { + {{- template "queryCodeStdExec" . }} if err != nil { return nil, err } @@ -86,18 +66,8 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret. {{if eq .Cmd ":exec"}} {{range .Comments}}//{{.}} {{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error { -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { -{{- end -}} - {{- if $.EmitPreparedQueries}} - _, err := q.exec(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else if $.EmitMethodsWithDBArgument}} - _, err := db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} - _, err := q.db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- end}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) error { + {{- template "queryCodeStdExec" . }} return err } {{end}} @@ -105,18 +75,8 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { {{if eq .Cmd ":execrows"}} {{range .Comments}}//{{.}} {{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) { -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) { -{{- end -}} - {{- if $.EmitPreparedQueries}} - result, err := q.exec(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else if $.EmitMethodsWithDBArgument}} - result, err := db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} - result, err := q.db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- end}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { + {{- template "queryCodeStdExec" . }} if err != nil { return 0, err } @@ -127,21 +87,51 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er {{if eq .Cmd ":execresult"}} {{range .Comments}}//{{.}} {{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (sql.Result, error) { -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (sql.Result, error) { -{{- end -}} - {{- if $.EmitPreparedQueries}} - return q.exec(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else if $.EmitMethodsWithDBArgument}} - return db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} - return q.db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- end}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (sql.Result, error) { + {{- template "queryCodeStdExec" . }} } {{end}} {{end}} {{end}} {{end}} + +{{define "queryCodeStdExec"}} + {{- if .Arg.HasSlices }} + sql := {{.ConstantName}} + var queryParams []interface{} + {{- if .Arg.Struct }} + {{- $arg := .Arg }} + {{- range .Arg.Struct.Fields }} + {{- if .HasSlice }} + if len({{$arg.Name}}.{{.Name}}) == 0 { + return nil, fmt.Errorf("slice {{$arg.DefineType}}.{{.Name}} must have at least one element") + } + for _, v := range {{$arg.Name}}.{{.Name}} { + queryParams = append(queryParams, v) + } + sql = strings.Replace(sql, {{.Column.InterpolatedMagic}}, strings.Repeat(",?", len({{$arg.Name}}.{{.Name}}))[1:], 1) + {{- else }} + queryParams = append(queryParams, {{$arg.Name}}.{{.Name}}) + {{- end }} + {{- end }} + {{- else }} + {{- /* Single argument parameter to this goroutine (they are not packed + in a struct), because .Arg.HasSlices further up above was true, + this section is 100% a slice (impossible to get here otherwise). + */ -}}{{/* need a newline */}} + if len({{.Arg.Name}}) == 0 { + return nil, fmt.Errorf("slice {{.Arg.Name}} must have at least one element") + } + for _, v := range {{.Arg.Name}} { + queryParams = append(queryParams, v) + } + sql = strings.Replace(sql, {{.Arg.Column.InterpolatedMagic}}, strings.Repeat(",?", len({{.Arg.Name}}))[1:], 1) + {{- end }} + {{ queryRetval . }} {{ queryMethod . }}(ctx, sql, queryParams...) + {{- else if emitPreparedQueries }} + {{- queryRetval . }} {{ queryMethod . }}(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) + {{- else}} + {{- queryRetval . }} {{ queryMethod . }}(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end -}} +{{end}} diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index de47b9cd68..54a9ee4c03 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -69,6 +69,9 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err := validate.FuncCall(c.catalog, raw); err != nil { return nil, err } + if err := validate.In(c.catalog, raw); err != nil { + return nil, err + } name, cmd, err := metadata.Parse(strings.TrimSpace(rawSQL), c.parser.CommentSyntax()) if err != nil { return nil, err @@ -114,7 +117,6 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } edits = append(edits, expandEdits...) - expanded, err := source.Mutate(rawSQL, edits) if err != nil { return nil, err diff --git a/internal/compiler/query.go b/internal/compiler/query.go index d2eb1d2fd7..d373d565e4 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -1,6 +1,8 @@ package compiler import ( + "fmt" + "github.com/kyleconroy/sqlc/internal/sql/ast" ) @@ -29,9 +31,16 @@ type Column struct { TableAlias string Type *ast.TypeName + IsSlice bool // is this sqlc.slice + skipTableRequiredCheck bool } +// Named with "...Magic" because of the fixed string to be replaced +func (c *Column) InterpolatedMagic() string { + return fmt.Sprintf(`"/*REPLACE:%s*/?"`, c.Name) +} + type Query struct { SQL string Name string diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 3086a703f3..e3cb18d62a 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -7,6 +7,7 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/astutils" "github.com/kyleconroy/sqlc/internal/sql/catalog" + "github.com/kyleconroy/sqlc/internal/sql/rewrite" "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) @@ -18,7 +19,7 @@ func dataType(n *ast.TypeName) string { } } -func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, names map[int]string) ([]Parameter, error) { +func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, names map[int]rewrite.NamedParam) ([]Parameter, error) { aliasMap := map[string]*ast.TableName{} // TODO: Deprecate defaultTable var defaultTable *ast.TableName @@ -26,7 +27,7 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa parameterName := func(n int, defaultName string) string { if n, ok := names[n]; ok { - return n + return n.Name } return defaultName } @@ -498,15 +499,23 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa if ref.name != "" { key = ref.name } + + paramName := key + var isSlice bool + if n, ok := names[ref.ref.Number]; ok { + paramName, isSlice = n.Name, n.Slice + } + a = append(a, Parameter{ Number: number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), + Name: paramName, DataType: dataType(&c.Type), NotNull: c.IsNotNull, IsArray: c.IsArray, Table: table, IsNamedParam: isNamedParam(ref.ref.Number), + IsSlice: isSlice, }, }) } diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index e3756b22f3..e2f79aeee4 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -80,6 +80,9 @@ func TestReplay(t *testing.T) { return err } if info.Name() == "sqlc.json" || info.Name() == "sqlc.yaml" { + // if filepath.Dir(path) != "testdata/params_duplicate/mysql" { + // return filepath.SkipDir + // } dirs = append(dirs, filepath.Dir(path)) return filepath.SkipDir } diff --git a/internal/endtoend/testdata/sqlc_slice/mysql/go/db.go b/internal/endtoend/testdata/sqlc_slice/mysql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/mysql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_slice/mysql/go/models.go b/internal/endtoend/testdata/sqlc_slice/mysql/go/models.go new file mode 100644 index 0000000000..ae9587385d --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/mysql/go/models.go @@ -0,0 +1,10 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type Foo struct { + ID int32 + Name string +} diff --git a/internal/endtoend/testdata/sqlc_slice/mysql/go/query.sql.go b/internal/endtoend/testdata/sqlc_slice/mysql/go/query.sql.go new file mode 100644 index 0000000000..973c443c23 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/mysql/go/query.sql.go @@ -0,0 +1,135 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" + "fmt" + "strings" +) + +const funcParamIdent = `-- name: FuncParamIdent :many +SELECT name FROM foo +WHERE name = ? + AND id IN (/*REPLACE:favourites*/?) +` + +type FuncParamIdentParams struct { + Slug string + Favourites []int32 +} + +func (q *Queries) FuncParamIdent(ctx context.Context, arg FuncParamIdentParams) ([]string, error) { + sql := funcParamIdent + var queryParams []interface{} + queryParams = append(queryParams, arg.Slug) + if len(arg.Favourites) == 0 { + return nil, fmt.Errorf("slice FuncParamIdentParams.Favourites must have at least one element") + } + for _, v := range arg.Favourites { + queryParams = append(queryParams, v) + } + sql = strings.Replace(sql, "/*REPLACE:favourites*/?", strings.Repeat(",?", len(arg.Favourites))[1:], 1) + rows, err := q.db.QueryContext(ctx, sql, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const funcParamSoloArg = `-- name: FuncParamSoloArg :many +SELECT name FROM foo +WHERE id IN (/*REPLACE:favourites*/?) +` + +func (q *Queries) FuncParamSoloArg(ctx context.Context, favourites []int32) ([]string, error) { + sql := funcParamSoloArg + var queryParams []interface{} + if len(favourites) == 0 { + return nil, fmt.Errorf("slice favourites must have at least one element") + } + for _, v := range favourites { + queryParams = append(queryParams, v) + } + sql = strings.Replace(sql, "/*REPLACE:favourites*/?", strings.Repeat(",?", len(favourites))[1:], 1) + rows, err := q.db.QueryContext(ctx, sql, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const funcParamString = `-- name: FuncParamString :many +SELECT name FROM foo +WHERE name = ? + AND id IN (/*REPLACE:favourites*/?) +` + +type FuncParamStringParams struct { + Slug string + Favourites []int32 +} + +func (q *Queries) FuncParamString(ctx context.Context, arg FuncParamStringParams) ([]string, error) { + sql := funcParamString + var queryParams []interface{} + queryParams = append(queryParams, arg.Slug) + if len(arg.Favourites) == 0 { + return nil, fmt.Errorf("slice FuncParamStringParams.Favourites must have at least one element") + } + for _, v := range arg.Favourites { + queryParams = append(queryParams, v) + } + sql = strings.Replace(sql, "/*REPLACE:favourites*/?", strings.Repeat(",?", len(arg.Favourites))[1:], 1) + rows, err := q.db.QueryContext(ctx, sql, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_slice/mysql/query.sql b/internal/endtoend/testdata/sqlc_slice/mysql/query.sql new file mode 100644 index 0000000000..3d0a78458e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/mysql/query.sql @@ -0,0 +1,15 @@ +CREATE TABLE foo (id int not null, name text not null); + +/* name: FuncParamIdent :many */ +SELECT name FROM foo +WHERE name = sqlc.arg(slug) + AND id IN (sqlc.slice(favourites)); + +/* name: FuncParamString :many */ +SELECT name FROM foo +WHERE name = sqlc.arg('slug') + AND id IN (sqlc.slice('favourites')); + +/* name: FuncParamSoloArg :many */ +SELECT name FROM foo +WHERE id IN (sqlc.slice('favourites')); diff --git a/internal/endtoend/testdata/sqlc_slice/mysql/sqlc.json b/internal/endtoend/testdata/sqlc_slice/mysql/sqlc.json new file mode 100644 index 0000000000..0657f4db83 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "mysql", + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/db.go b/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/db.go new file mode 100644 index 0000000000..4559f50a4f --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/db.go @@ -0,0 +1,30 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/models.go b/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/models.go new file mode 100644 index 0000000000..d21739e2dd --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/models.go @@ -0,0 +1,9 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type Foo struct { + Name string +} diff --git a/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/query.sql.go new file mode 100644 index 0000000000..983b1520b6 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/go/query.sql.go @@ -0,0 +1,56 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" +) + +const funcParamIdent = `-- name: FuncParamIdent :many +SELECT name FROM foo WHERE name = $1 +` + +func (q *Queries) FuncParamIdent(ctx context.Context, slug string) ([]string, error) { + rows, err := q.db.Query(ctx, funcParamIdent, slug) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const funcParamString = `-- name: FuncParamString :many +SELECT name FROM foo WHERE name = $1 +` + +func (q *Queries) FuncParamString(ctx context.Context, slug string) ([]string, error) { + rows, err := q.db.Query(ctx, funcParamString, slug) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/query.sql b/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/query.sql new file mode 100644 index 0000000000..9a8e98e223 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/query.sql @@ -0,0 +1,7 @@ +CREATE TABLE foo (name text not null); + +-- name: FuncParamIdent :many +SELECT name FROM foo WHERE name = sqlc.arg(slug); + +-- name: FuncParamString :many +SELECT name FROM foo WHERE name = sqlc.arg('slug'); diff --git a/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/sqlc.json new file mode 100644 index 0000000000..9403bd0279 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/postgresql/pgx/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..ae9587385d --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/models.go @@ -0,0 +1,10 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type Foo struct { + ID int32 + Name string +} diff --git a/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..ca2b56bd0e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,76 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" +) + +const funcParamIdent = `-- name: FuncParamIdent :many +SELECT name FROM foo +WHERE name = $1 + AND id IN ($2) +` + +type FuncParamIdentParams struct { + Slug string + Favourites int32 +} + +func (q *Queries) FuncParamIdent(ctx context.Context, arg FuncParamIdentParams) ([]string, error) { + rows, err := q.db.QueryContext(ctx, funcParamIdent, arg.Slug, arg.Favourites) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const funcParamString = `-- name: FuncParamString :many +SELECT name FROM foo +WHERE name = $1 + AND id IN ($2) +` + +type FuncParamStringParams struct { + Slug string + Favourites int32 +} + +func (q *Queries) FuncParamString(ctx context.Context, arg FuncParamStringParams) ([]string, error) { + rows, err := q.db.QueryContext(ctx, funcParamString, arg.Slug, arg.Favourites) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/query.sql b/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..ca10765339 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/query.sql @@ -0,0 +1,13 @@ +CREATE TABLE foo (id int not null, name text not null); + +/* name: FuncParamIdent :many */ +SELECT name FROM foo +WHERE name = sqlc.arg(slug) + AND id IN (sqlc.slice(favourites)); + + + +/* name: FuncParamString :many */ +SELECT name FROM foo +WHERE name = sqlc.arg('slug') + AND id IN (sqlc.slice('favourites')); diff --git a/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/sqlc.json b/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/sqlc.json new file mode 100644 index 0000000000..de427d069f --- /dev/null +++ b/internal/endtoend/testdata/sqlc_slice/postgresql/stdlib/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "postgresql", + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/sql/named/is.go b/internal/sql/named/is.go index 5421a85bb1..91b1d5108e 100644 --- a/internal/sql/named/is.go +++ b/internal/sql/named/is.go @@ -13,7 +13,7 @@ func IsParamFunc(node ast.Node) bool { if call.Func == nil { return false } - return call.Func.Schema == "sqlc" && call.Func.Name == "arg" + return call.Func.Schema == "sqlc" && (call.Func.Name == "arg" || call.Func.Name == "slice") } func IsParamSign(node ast.Node) bool { diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index b9ba52001e..bb0fef894b 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,16 +41,22 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } -func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, map[int]string, []source.Edit) { +type NamedParam struct { + Name string + Slice bool +} + +func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, map[int]NamedParam, []source.Edit) { foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) if len(foundFunc.Items)+len(foundSign.Items) == 0 { - return raw, map[int]string{}, nil + return raw, map[int]NamedParam{}, nil } hasNamedParameterSupport := engine != config.EngineMySQL args := map[string][]int{} + argsSlice := map[string]bool{} argn := 0 var edits []source.Edit node := astutils.Apply(raw, func(cr *astutils.Cursor) bool { @@ -59,6 +65,8 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case named.IsParamFunc(node): fun := node.(*ast.FuncCall) param, isConst := flatten(fun.Args) + sqlcFunc := fun.Func.Name // "arg" or "slice" + isSlice := sqlcFunc == "slice" if nums, ok := args[param]; ok && hasNamedParameterSupport { cr.Replace(&ast.ParamRef{ Number: nums[0], @@ -69,11 +77,8 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, for numbs[argn] { argn++ } - if _, found := args[param]; !found { - args[param] = []int{argn} - } else { - args[param] = append(args[param], argn) - } + args[param] = append(args[param], argn) + argsSlice[param] = isSlice cr.Replace(&ast.ParamRef{ Number: argn, Location: fun.Location, @@ -82,12 +87,18 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, // TODO: This code assumes that sqlc.arg(name) is on a single line var old, replace string if isConst { - old = fmt.Sprintf("sqlc.arg('%s')", param) + old = fmt.Sprintf("sqlc.%s('%s')", sqlcFunc, param) } else { - old = fmt.Sprintf("sqlc.arg(%s)", param) + old = fmt.Sprintf("sqlc.%s(%s)", sqlcFunc, param) } if engine == config.EngineMySQL || !dollar { - replace = "?" + if isSlice { + // This sequence is also replicated in internal/codegen/golang.Field + // since it's needed during template generation for replacement + replace = fmt.Sprintf(`/*REPLACE:%s*/?`, param) + } else { + replace = "?" + } } else { replace = fmt.Sprintf("$%d", args[param][0]) } @@ -180,11 +191,12 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, } }, nil) - named := map[int]string{} + namedPos := make(map[int]NamedParam, len(args)) for k, vs := range args { for _, v := range vs { - named[v] = k + namedPos[v] = NamedParam{Name: k, Slice: argsSlice[k]} } } - return node.(*ast.RawStmt), named, edits + + return node.(*ast.RawStmt), namedPos, edits } diff --git a/internal/sql/validate/func_call.go b/internal/sql/validate/func_call.go index 157f835f28..cf75c11ddb 100644 --- a/internal/sql/validate/func_call.go +++ b/internal/sql/validate/func_call.go @@ -29,10 +29,10 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { return v } - // Custom validation for sqlc.arg + // Custom validation for sqlc.arg and sqlc.slice // TODO: Replace this once type-checking is implemented if fn.Schema == "sqlc" { - if fn.Name != "arg" { + if fn.Name != "arg" && fn.Name != "slice" { v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name) return nil } @@ -41,7 +41,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { } if len(call.Args.Items) > 1 { v.err = &sqlerr.Error{ - Message: fmt.Sprintf("expected 1 parameter to sqlc.arg; got %d", len(call.Args.Items)), + Message: fmt.Sprintf("expected 1 parameter to sqlc.%s; got %d", fn.Name, len(call.Args.Items)), Location: call.Pos(), } return nil @@ -51,7 +51,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { case *ast.ColumnRef: default: v.err = &sqlerr.Error{ - Message: fmt.Sprintf("expected parameter to sqlc.arg to be string or reference; got %T", n), + Message: fmt.Sprintf("expected parameter to sqlc.%s to be string or reference; got %T", fn.Name, n), Location: call.Pos(), } return nil diff --git a/internal/sql/validate/in.go b/internal/sql/validate/in.go new file mode 100644 index 0000000000..990819e899 --- /dev/null +++ b/internal/sql/validate/in.go @@ -0,0 +1,86 @@ +package validate + +import ( + "fmt" + + "github.com/kyleconroy/sqlc/internal/sql/ast" + "github.com/kyleconroy/sqlc/internal/sql/astutils" + "github.com/kyleconroy/sqlc/internal/sql/catalog" + "github.com/kyleconroy/sqlc/internal/sql/sqlerr" +) + +type inVisitor struct { + catalog *catalog.Catalog + err error +} + +func (v *inVisitor) Visit(node ast.Node) astutils.Visitor { + if v.err != nil { + return nil + } + + in, ok := node.(*ast.In) + if !ok { + return v + } + + // Validate that sqlc.slice in an IN statement is the only arg, eg: + // id IN (sqlc.slice("ids")) -- GOOD + // id in (0, 1, sqlc.slice("ids")) -- BAD + + if len(in.List) <= 1 { + return v + } + + for _, n := range in.List { + call, ok := n.(*ast.FuncCall) + if !ok { + continue + } + fn := call.Func + if fn == nil { + continue + } + + if fn.Schema == "sqlc" && fn.Name == "slice" { + var inExpr, multiArg string + + // determine inExpr + switch n := in.Expr.(type) { + case *ast.ColumnRef: + inExpr = n.Name + default: + inExpr = "..." + } + + // determine multiArg + if len(call.Args.Items) == 1 { + switch n := call.Args.Items[0].(type) { + case *ast.A_Const: + if str, ok := n.Val.(*ast.String); ok { + multiArg = "\"" + str.Str + "\"" + } else { + multiArg = "?" + } + case *ast.ColumnRef: + multiArg = n.Name + default: + // impossible, validate.FuncCall should have caught this + multiArg = "..." + } + } + v.err = &sqlerr.Error{ + Message: fmt.Sprintf("expected '%s IN' expr to consist only of sqlc.multi(%s); eg ", inExpr, multiArg), + Location: call.Pos(), + } + } + } + + return v +} + +func In(c *catalog.Catalog, n ast.Node) error { + visitor := inVisitor{catalog: c} + astutils.Walk(&visitor, n) + return visitor.err +}