diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index cf5ef2f3e3..ca548e3e58 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -46,7 +46,10 @@ func (t *tmplCtx) OutputQuery(sourceName string) bool { func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string]string, error) { enums := buildEnums(r, settings) structs := buildStructs(r, settings) - queries := buildQueries(r, settings, structs) + queries, err := buildQueries(r, settings, structs) + if err != nil { + return nil, err + } return generate(settings, enums, structs, queries) } diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 63ee7f68de..171125adc0 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -63,6 +63,21 @@ func (v *QueryValue) ReturnName() string { return v.Name } +func (v QueryValue) UniqueFields() []Field { + seen := map[string]struct{}{} + fields := make([]Field, 0, len(v.Struct.Fields)) + + for _, field := range v.Struct.Fields { + if _, found := seen[field.Name]; found { + continue + } + seen[field.Name] = struct{}{} + fields = append(fields, field) + } + + return fields +} + func (v QueryValue) Params() string { if v.isEmpty() { return "" diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 65069151ef..b213826eb9 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -135,7 +135,7 @@ func argName(name string) string { return out } -func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) []Query { +func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) ([]Query, error) { qs := make([]Query, 0, len(r.Queries)) for _, query := range r.Queries { if query.Name == "" { @@ -178,11 +178,15 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs Column: p.Column, }) } + s, err := columnsToStruct(r, gq.MethodName+"Params", cols, settings, false) + if err != nil { + return nil, err + } gq.Arg = QueryValue{ - Emit: true, - Name: "arg", - Struct: columnsToStruct(r, gq.MethodName+"Params", cols, settings, false), - SQLPackage: sqlpkg, + Emit: true, + Name: "arg", + Struct: s, + SQLPackage: sqlpkg, EmitPointer: settings.Go.EmitParamsStructPointers, } } @@ -226,7 +230,11 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs Column: c, }) } - gs = columnsToStruct(r, gq.MethodName+"Row", columns, settings, true) + var err error + gs, err = columnsToStruct(r, gq.MethodName+"Row", columns, settings, true) + if err != nil { + return nil, err + } emit = true } gq.Ret = QueryValue{ @@ -241,7 +249,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs qs = append(qs, gq) } sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName }) - return qs + return qs, nil } // It's possible that this method will generate duplicate JSON tag values @@ -251,11 +259,11 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs // JSON tags: count, count_2, count_2 // // This is unlikely to happen, so don't fix it yet -func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settings config.CombinedSettings, useID bool) *Struct { +func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settings config.CombinedSettings, useID bool) (*Struct, error) { gs := Struct{ Name: name, } - seen := map[string]int{} + seen := map[string][]int{} suffixes := map[int]int{} for i, c := range columns { colName := columnName(c.Column, i) @@ -267,7 +275,7 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin suffix := 0 if o, ok := suffixes[c.id]; ok && useID { suffix = o - } else if v := seen[fieldName]; v > 0 { + } else if v := len(seen[fieldName]); v > 0 && !c.IsNamedParam { suffix = v + 1 } suffixes[c.id] = suffix @@ -287,8 +295,47 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin Type: goType(r, c.Column, settings), Tags: tags, }) - seen[baseFieldName]++ + if _, found := seen[baseFieldName]; !found { + seen[baseFieldName] = []int{i} + } else { + seen[baseFieldName] = append(seen[baseFieldName], i) + } } - return &gs + // If a field does not have a known type, but another + // field with the same name has a known type, assign + // the known type to the field without a known type + for i, field := range gs.Fields { + if len(seen[field.Name]) > 1 && field.Type == "interface{}" { + for _, j := range seen[field.Name] { + if i == j { + continue + } + otherField := gs.Fields[j] + if otherField.Type != field.Type { + field.Type = otherField.Type + } + gs.Fields[i] = field + } + } + } + + err := checkIncompatibleFieldTypes(gs.Fields) + if err != nil { + return nil, err + } + + return &gs, nil +} + +func checkIncompatibleFieldTypes(fields []Field) error { + fieldTypes := map[string]string{} + for _, field := range fields { + if fieldType, found := fieldTypes[field.Name]; !found { + fieldTypes[field.Name] = field.Type + } else if field.Type != fieldType { + return fmt.Errorf("named param %s has incompatible types: %s, %s", field.Name, field.Type, fieldType) + } + } + return nil } diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index 9a23ed3fba..72472c5c50 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -6,7 +6,7 @@ const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} {{$.Q}} {{if .Arg.EmitStruct}} -type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}} +type {{.Arg.Type}} struct { {{- range .Arg.UniqueFields}} {{.Name}} {{.Type}} {{if or ($.EmitJSONTags) ($.EmitDBTags)}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} {{- end}} } diff --git a/internal/compiler/query.go b/internal/compiler/query.go index e4721ed0c0..d2eb1d2fd7 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -15,12 +15,13 @@ type Table struct { } type Column struct { - Name string - DataType string - NotNull bool - IsArray bool - Comment string - Length *int + Name string + DataType string + NotNull bool + IsArray bool + Comment string + Length *int + IsNamedParam bool // XXX: Figure out what PostgreSQL calls `foo.id` Scope string diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index f8c73109fd..3086a703f3 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -31,6 +31,11 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa return defaultName } + isNamedParam := func(n int) bool { + _, ok := names[n] + return ok + } + typeMap := map[string]map[string]map[string]*catalog.Column{} indexTable := func(table catalog.Table) error { tables = append(tables, table.Rel) @@ -88,9 +93,10 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, "offset"), - DataType: "integer", - NotNull: true, + Name: parameterName(ref.ref.Number, "offset"), + DataType: "integer", + NotNull: true, + IsNamedParam: isNamedParam(ref.ref.Number), }, }) @@ -98,9 +104,10 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, "limit"), - DataType: "integer", - NotNull: true, + Name: parameterName(ref.ref.Number, "limit"), + DataType: "integer", + NotNull: true, + IsNamedParam: isNamedParam(ref.ref.Number), }, }) @@ -121,8 +128,9 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, ""), - DataType: dataType, + Name: parameterName(ref.ref.Number, ""), + DataType: dataType, + IsNamedParam: isNamedParam(ref.ref.Number), }, }) continue @@ -178,12 +186,13 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), - DataType: dataType(&c.Type), - NotNull: c.IsNotNull, - IsArray: c.IsArray, - Length: c.Length, - Table: table, + Name: parameterName(ref.ref.Number, key), + DataType: dataType(&c.Type), + NotNull: c.IsNotNull, + IsArray: c.IsArray, + Length: c.Length, + Table: table, + IsNamedParam: isNamedParam(ref.ref.Number), }, }) } @@ -234,11 +243,12 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa a = append(a, Parameter{ Number: number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), - DataType: dataType(&c.Type), - NotNull: c.IsNotNull, - IsArray: c.IsArray, - Table: table, + Name: parameterName(ref.ref.Number, key), + DataType: dataType(&c.Type), + NotNull: c.IsNotNull, + IsArray: c.IsArray, + Table: table, + IsNamedParam: isNamedParam(ref.ref.Number), }, }) } @@ -300,8 +310,9 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, defaultName), - DataType: "any", + Name: parameterName(ref.ref.Number, defaultName), + DataType: "any", + IsNamedParam: isNamedParam(ref.ref.Number), }, }) continue @@ -330,9 +341,10 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, paramName), - DataType: dataType(paramType), - NotNull: true, + Name: parameterName(ref.ref.Number, paramName), + DataType: dataType(paramType), + NotNull: true, + IsNamedParam: isNamedParam(ref.ref.Number), }, }) } @@ -388,12 +400,13 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), - DataType: dataType(&c.Type), - NotNull: c.IsNotNull, - IsArray: c.IsArray, - Table: &ast.TableName{Schema: schema, Name: rel}, - Length: c.Length, + Name: parameterName(ref.ref.Number, key), + DataType: dataType(&c.Type), + NotNull: c.IsNotNull, + IsArray: c.IsArray, + Table: &ast.TableName{Schema: schema, Name: rel}, + Length: c.Length, + IsNamedParam: isNamedParam(ref.ref.Number), }, }) } else { @@ -488,11 +501,12 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa a = append(a, Parameter{ Number: number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), - DataType: dataType(&c.Type), - NotNull: c.IsNotNull, - IsArray: c.IsArray, - Table: table, + Name: parameterName(ref.ref.Number, key), + DataType: dataType(&c.Type), + NotNull: c.IsNotNull, + IsArray: c.IsArray, + Table: table, + IsNamedParam: isNamedParam(ref.ref.Number), }, }) } diff --git a/internal/endtoend/testdata/case_named_params/mysql/go/db.go b/internal/endtoend/testdata/case_named_params/mysql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/case_named_params/mysql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/case_named_params/mysql/go/models.go b/internal/endtoend/testdata/case_named_params/mysql/go/models.go new file mode 100644 index 0000000000..ba8a932cb0 --- /dev/null +++ b/internal/endtoend/testdata/case_named_params/mysql/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type Author struct { + ID int64 + Username sql.NullString + Email sql.NullString + Name string + Bio sql.NullString +} diff --git a/internal/endtoend/testdata/case_named_params/mysql/go/query.sql.go b/internal/endtoend/testdata/case_named_params/mysql/go/query.sql.go new file mode 100644 index 0000000000..96cdf2cc9c --- /dev/null +++ b/internal/endtoend/testdata/case_named_params/mysql/go/query.sql.go @@ -0,0 +1,40 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const listAuthors = `-- name: ListAuthors :one +SELECT id, username, email, name, bio +FROM authors +WHERE email = CASE WHEN ? = '' then NULL else ? END + OR username = CASE WHEN ? = '' then NULL else ? END +LIMIT 1 +` + +type ListAuthorsParams struct { + Email sql.NullString + Username sql.NullString +} + +func (q *Queries) ListAuthors(ctx context.Context, arg ListAuthorsParams) (Author, error) { + row := q.db.QueryRowContext(ctx, listAuthors, + arg.Email, + arg.Email, + arg.Username, + arg.Username, + ) + var i Author + err := row.Scan( + &i.ID, + &i.Username, + &i.Email, + &i.Name, + &i.Bio, + ) + return i, err +} diff --git a/internal/endtoend/testdata/case_named_params/mysql/query.sql b/internal/endtoend/testdata/case_named_params/mysql/query.sql new file mode 100644 index 0000000000..5167ec8769 --- /dev/null +++ b/internal/endtoend/testdata/case_named_params/mysql/query.sql @@ -0,0 +1,18 @@ +-- https://github.com/kyleconroy/sqlc/issues/1195 + +CREATE TABLE authors ( + id BIGINT PRIMARY KEY, + username TEXT NULL, + email TEXT NULL, + name TEXT NOT NULL, + bio TEXT, + UNIQUE KEY idx_username (username), + UNIQUE KEY ids_email (email) +); + +-- name: ListAuthors :one +SELECT * +FROM authors +WHERE email = CASE WHEN sqlc.arg(email) = '' then NULL else sqlc.arg(email) END + OR username = CASE WHEN sqlc.arg(username) = '' then NULL else sqlc.arg(username) END +LIMIT 1; diff --git a/internal/endtoend/testdata/case_named_params/mysql/sqlc.json b/internal/endtoend/testdata/case_named_params/mysql/sqlc.json new file mode 100644 index 0000000000..534b7e24e9 --- /dev/null +++ b/internal/endtoend/testdata/case_named_params/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "mysql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/case_named_params/postgresql/go/db.go b/internal/endtoend/testdata/case_named_params/postgresql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/case_named_params/postgresql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/case_named_params/postgresql/go/models.go b/internal/endtoend/testdata/case_named_params/postgresql/go/models.go new file mode 100644 index 0000000000..ba8a932cb0 --- /dev/null +++ b/internal/endtoend/testdata/case_named_params/postgresql/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type Author struct { + ID int64 + Username sql.NullString + Email sql.NullString + Name string + Bio sql.NullString +} diff --git a/internal/endtoend/testdata/case_named_params/postgresql/go/query.sql.go b/internal/endtoend/testdata/case_named_params/postgresql/go/query.sql.go new file mode 100644 index 0000000000..fb3387df43 --- /dev/null +++ b/internal/endtoend/testdata/case_named_params/postgresql/go/query.sql.go @@ -0,0 +1,34 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthors = `-- name: ListAuthors :one +SELECT id, username, email, name, bio +FROM authors +WHERE email = CASE WHEN $1::text = '' then NULL else $1::text END + OR username = CASE WHEN $2::text = '' then NULL else $2::text END +LIMIT 1 +` + +type ListAuthorsParams struct { + Email string + Username string +} + +func (q *Queries) ListAuthors(ctx context.Context, arg ListAuthorsParams) (Author, error) { + row := q.db.QueryRowContext(ctx, listAuthors, arg.Email, arg.Username) + var i Author + err := row.Scan( + &i.ID, + &i.Username, + &i.Email, + &i.Name, + &i.Bio, + ) + return i, err +} diff --git a/internal/endtoend/testdata/case_named_params/postgresql/query.sql b/internal/endtoend/testdata/case_named_params/postgresql/query.sql new file mode 100644 index 0000000000..ddd974424a --- /dev/null +++ b/internal/endtoend/testdata/case_named_params/postgresql/query.sql @@ -0,0 +1,16 @@ +-- https://github.com/kyleconroy/sqlc/issues/1195 + +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + username TEXT NULL, + email TEXT NULL, + name TEXT NOT NULL, + bio TEXT +); + +-- name: ListAuthors :one +SELECT * +FROM authors +WHERE email = CASE WHEN sqlc.arg(email)::text = '' then NULL else sqlc.arg(email)::text END + OR username = CASE WHEN sqlc.arg(username)::text = '' then NULL else sqlc.arg(username)::text END +LIMIT 1; diff --git a/internal/endtoend/testdata/case_named_params/postgresql/sqlc.json b/internal/endtoend/testdata/case_named_params/postgresql/sqlc.json new file mode 100644 index 0000000000..af57681f66 --- /dev/null +++ b/internal/endtoend/testdata/case_named_params/postgresql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/invalid_named_params/mysql/query.sql b/internal/endtoend/testdata/invalid_named_params/mysql/query.sql new file mode 100644 index 0000000000..1ea4a9e5dc --- /dev/null +++ b/internal/endtoend/testdata/invalid_named_params/mysql/query.sql @@ -0,0 +1,11 @@ +CREATE TABLE authors ( + id BIGINT PRIMARY KEY, + bio TEXT +); + +-- name: ListAuthors :one +SELECT * +FROM authors +WHERE id = sqlc.arg(my_named_param) + OR bio = sqlc.arg(my_named_param) +LIMIT 1; diff --git a/internal/endtoend/testdata/invalid_named_params/mysql/sqlc.json b/internal/endtoend/testdata/invalid_named_params/mysql/sqlc.json new file mode 100644 index 0000000000..534b7e24e9 --- /dev/null +++ b/internal/endtoend/testdata/invalid_named_params/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "mysql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/invalid_named_params/mysql/stderr.txt b/internal/endtoend/testdata/invalid_named_params/mysql/stderr.txt new file mode 100644 index 0000000000..9e45d80fbf --- /dev/null +++ b/internal/endtoend/testdata/invalid_named_params/mysql/stderr.txt @@ -0,0 +1,2 @@ +# package querytest +error generating code: named param MyNamedParam has incompatible types: sql.NullString, int64 diff --git a/internal/endtoend/testdata/on_duplicate_key_update/mysql/db/query.sql.go b/internal/endtoend/testdata/on_duplicate_key_update/mysql/db/query.sql.go index 4f2406ac33..07c6f36395 100644 --- a/internal/endtoend/testdata/on_duplicate_key_update/mysql/db/query.sql.go +++ b/internal/endtoend/testdata/on_duplicate_key_update/mysql/db/query.sql.go @@ -25,3 +25,20 @@ func (q *Queries) UpsertAuthor(ctx context.Context, arg UpsertAuthorParams) erro _, err := q.db.ExecContext(ctx, upsertAuthor, arg.Name, arg.Bio, arg.Bio_2) return err } + +const upsertAuthorNamed = `-- name: UpsertAuthorNamed :exec +INSERT INTO authors (name, bio) +VALUES (?, ?) +ON DUPLICATE KEY + UPDATE bio = ? +` + +type UpsertAuthorNamedParams struct { + Name string + Bio sql.NullString +} + +func (q *Queries) UpsertAuthorNamed(ctx context.Context, arg UpsertAuthorNamedParams) error { + _, err := q.db.ExecContext(ctx, upsertAuthorNamed, arg.Name, arg.Bio, arg.Bio) + return err +} diff --git a/internal/endtoend/testdata/on_duplicate_key_update/mysql/query.sql b/internal/endtoend/testdata/on_duplicate_key_update/mysql/query.sql index 25b98d3b15..de29787d4c 100644 --- a/internal/endtoend/testdata/on_duplicate_key_update/mysql/query.sql +++ b/internal/endtoend/testdata/on_duplicate_key_update/mysql/query.sql @@ -11,3 +11,9 @@ INSERT INTO authors (name, bio) VALUES (?, ?) ON DUPLICATE KEY UPDATE bio = ?; + +-- name: UpsertAuthorNamed :exec +INSERT INTO authors (name, bio) +VALUES (?, sqlc.arg(bio)) +ON DUPLICATE KEY + UPDATE bio = sqlc.arg(bio); diff --git a/internal/endtoend/testdata/on_duplicate_key_update/postgresql/db/query.sql.go b/internal/endtoend/testdata/on_duplicate_key_update/postgresql/db/query.sql.go index d96cde7f11..29fb0b6387 100644 --- a/internal/endtoend/testdata/on_duplicate_key_update/postgresql/db/query.sql.go +++ b/internal/endtoend/testdata/on_duplicate_key_update/postgresql/db/query.sql.go @@ -24,3 +24,20 @@ func (q *Queries) UpsertAuthor(ctx context.Context, arg UpsertAuthorParams) erro _, err := q.db.ExecContext(ctx, upsertAuthor, arg.Name, arg.Bio) return err } + +const upsertAuthorNamed = `-- name: UpsertAuthorNamed :exec +INSERT INTO authors (name, bio) +VALUES ($1, $2) +ON CONFLICT (name) DO UPDATE +SET bio = $2 +` + +type UpsertAuthorNamedParams struct { + Name string + Bio sql.NullString +} + +func (q *Queries) UpsertAuthorNamed(ctx context.Context, arg UpsertAuthorNamedParams) error { + _, err := q.db.ExecContext(ctx, upsertAuthorNamed, arg.Name, arg.Bio) + return err +} diff --git a/internal/endtoend/testdata/on_duplicate_key_update/postgresql/query.sql b/internal/endtoend/testdata/on_duplicate_key_update/postgresql/query.sql index b2a51c5f42..465a8d67f7 100644 --- a/internal/endtoend/testdata/on_duplicate_key_update/postgresql/query.sql +++ b/internal/endtoend/testdata/on_duplicate_key_update/postgresql/query.sql @@ -10,3 +10,9 @@ INSERT INTO authors (name, bio) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET bio = $2; + +-- name: UpsertAuthorNamed :exec +INSERT INTO authors (name, bio) +VALUES ($1, sqlc.arg(bio)) +ON CONFLICT (name) DO UPDATE +SET bio = sqlc.arg(bio); diff --git a/internal/endtoend/testdata/params_duplicate/mysql/go/query.sql.go b/internal/endtoend/testdata/params_duplicate/mysql/go/query.sql.go index 1f48521520..874683dae9 100644 --- a/internal/endtoend/testdata/params_duplicate/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/params_duplicate/mysql/go/query.sql.go @@ -14,12 +14,11 @@ users where (? = id OR ? = 0) ` type SelectUserByIDParams struct { - Column1 interface{} - ID interface{} + ID interface{} } func (q *Queries) SelectUserByID(ctx context.Context, arg SelectUserByIDParams) ([]sql.NullString, error) { - rows, err := q.db.QueryContext(ctx, selectUserByID, arg.Column1, arg.ID) + rows, err := q.db.QueryContext(ctx, selectUserByID, arg.ID, arg.ID) if err != nil { return nil, err } @@ -49,12 +48,11 @@ WHERE first_name = ? ` type SelectUserByNameParams struct { - FirstName sql.NullString - Name sql.NullString + Name sql.NullString } func (q *Queries) SelectUserByName(ctx context.Context, arg SelectUserByNameParams) ([]sql.NullString, error) { - rows, err := q.db.QueryContext(ctx, selectUserByName, arg.FirstName, arg.Name) + rows, err := q.db.QueryContext(ctx, selectUserByName, arg.Name, arg.Name) if err != nil { return nil, err } diff --git a/internal/endtoend/testdata/params_duplicate/mysql/sqlc.json b/internal/endtoend/testdata/params_duplicate/mysql/sqlc.json index a9e7b055a4..1f70415f37 100644 --- a/internal/endtoend/testdata/params_duplicate/mysql/sqlc.json +++ b/internal/endtoend/testdata/params_duplicate/mysql/sqlc.json @@ -1,12 +1,12 @@ { - "version": "1", - "packages": [ - { - "name": "querytest", - "path": "go", - "schema": "schema.sql", - "queries": "query.sql", - "engine": "mysql" - } - ] + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql" + } + ] } diff --git a/internal/endtoend/testdata/params_duplicate/postgresql/go/db.go b/internal/endtoend/testdata/params_duplicate/postgresql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/params_duplicate/postgresql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/params_duplicate/postgresql/go/models.go b/internal/endtoend/testdata/params_duplicate/postgresql/go/models.go new file mode 100644 index 0000000000..ec72cdf8ec --- /dev/null +++ b/internal/endtoend/testdata/params_duplicate/postgresql/go/models.go @@ -0,0 +1,13 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type User struct { + ID int32 + FirstName sql.NullString + LastName sql.NullString +} diff --git a/internal/endtoend/testdata/params_duplicate/postgresql/go/query.sql.go b/internal/endtoend/testdata/params_duplicate/postgresql/go/query.sql.go new file mode 100644 index 0000000000..7373b8100c --- /dev/null +++ b/internal/endtoend/testdata/params_duplicate/postgresql/go/query.sql.go @@ -0,0 +1,95 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const selectUserByID = `-- name: SelectUserByID :many +SELECT first_name from +users where ($1 = id OR $1 = 0) +` + +func (q *Queries) SelectUserByID(ctx context.Context, id interface{}) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, selectUserByID, id) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var first_name sql.NullString + if err := rows.Scan(&first_name); err != nil { + return nil, err + } + items = append(items, first_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUserByName = `-- name: SelectUserByName :many +SELECT first_name +FROM users +WHERE first_name = $1 + OR last_name = $1 +` + +func (q *Queries) SelectUserByName(ctx context.Context, name sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, selectUserByName, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var first_name sql.NullString + if err := rows.Scan(&first_name); err != nil { + return nil, err + } + items = append(items, first_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUserQuestion = `-- name: SelectUserQuestion :many +SELECT first_name from +users where ($1 = id OR $1 = 0) +` + +func (q *Queries) SelectUserQuestion(ctx context.Context, dollar_1 interface{}) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, selectUserQuestion, dollar_1) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var first_name sql.NullString + if err := rows.Scan(&first_name); err != nil { + return nil, err + } + items = append(items, first_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/params_duplicate/postgresql/query.sql b/internal/endtoend/testdata/params_duplicate/postgresql/query.sql new file mode 100644 index 0000000000..eb5f43c7f9 --- /dev/null +++ b/internal/endtoend/testdata/params_duplicate/postgresql/query.sql @@ -0,0 +1,19 @@ +CREATE TABLE users ( + id INT PRIMARY KEY, + first_name varchar(255), + last_name varchar(255) +); + +/* name: SelectUserByID :many */ +SELECT first_name from +users where (sqlc.arg(id) = id OR sqlc.arg(id) = 0); + +/* name: SelectUserByName :many */ +SELECT first_name +FROM users +WHERE first_name = sqlc.arg(name) + OR last_name = sqlc.arg(name); + +/* name: SelectUserQuestion :many */ +SELECT first_name from +users where ($1 = id OR $1 = 0); diff --git a/internal/endtoend/testdata/params_duplicate/postgresql/sqlc.json b/internal/endtoend/testdata/params_duplicate/postgresql/sqlc.json new file mode 100644 index 0000000000..541dd43a06 --- /dev/null +++ b/internal/endtoend/testdata/params_duplicate/postgresql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "query.sql", + "queries": "query.sql", + "engine": "postgresql" + } + ] +} diff --git a/internal/sql/astutils/rewrite.go b/internal/sql/astutils/rewrite.go index ba30f2acfa..209cfb382c 100644 --- a/internal/sql/astutils/rewrite.go +++ b/internal/sql/astutils/rewrite.go @@ -190,6 +190,9 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. case *ast.FuncSpec: a.apply(n, "Name", nil, n.Name) + case *ast.In: + a.applyList(n, "List") + case *ast.List: // Since item is a slice a.applyList(n, "Items") diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index 446e81abba..b9ba52001e 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -50,7 +50,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, hasNamedParameterSupport := engine != config.EngineMySQL - args := map[string]int{} + args := map[string][]int{} argn := 0 var edits []source.Edit node := astutils.Apply(raw, func(cr *astutils.Cursor) bool { @@ -59,9 +59,9 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case named.IsParamFunc(node): fun := node.(*ast.FuncCall) param, isConst := flatten(fun.Args) - if num, ok := args[param]; ok && hasNamedParameterSupport { + if nums, ok := args[param]; ok && hasNamedParameterSupport { cr.Replace(&ast.ParamRef{ - Number: num, + Number: nums[0], Location: fun.Location, }) } else { @@ -69,7 +69,11 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, for numbs[argn] { argn++ } - args[param] = argn + if _, found := args[param]; !found { + args[param] = []int{argn} + } else { + args[param] = append(args[param], argn) + } cr.Replace(&ast.ParamRef{ Number: argn, Location: fun.Location, @@ -85,7 +89,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param]) + replace = fmt.Sprintf("$%d", args[param][0]) } edits = append(edits, source.Edit{ Location: fun.Location - raw.StmtLocation, @@ -98,9 +102,9 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, expr := node.(*ast.A_Expr) cast := expr.Rexpr.(*ast.TypeCast) param, _ := flatten(cast.Arg) - if num, ok := args[param]; ok { + if nums, ok := args[param]; ok { cast.Arg = &ast.ParamRef{ - Number: num, + Number: nums[0], Location: expr.Location, } cr.Replace(cast) @@ -109,7 +113,11 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, for numbs[argn] { argn++ } - args[param] = argn + if _, found := args[param]; !found { + args[param] = []int{argn} + } else { + args[param] = append(args[param], argn) + } cast.Arg = &ast.ParamRef{ Number: argn, Location: expr.Location, @@ -121,7 +129,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param]) + replace = fmt.Sprintf("$%d", args[param][0]) } edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, @@ -133,9 +141,9 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case named.IsParamSign(node): expr := node.(*ast.A_Expr) param, _ := flatten(expr.Rexpr) - if num, ok := args[param]; ok { + if nums, ok := args[param]; ok { cr.Replace(&ast.ParamRef{ - Number: num, + Number: nums[0], Location: expr.Location, }) } else { @@ -143,7 +151,11 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, for numbs[argn] { argn++ } - args[param] = argn + if _, found := args[param]; !found { + args[param] = []int{argn} + } else { + args[param] = append(args[param], argn) + } cr.Replace(&ast.ParamRef{ Number: argn, Location: expr.Location, @@ -154,7 +166,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param]) + replace = fmt.Sprintf("$%d", args[param][0]) } edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, @@ -169,8 +181,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, }, nil) named := map[int]string{} - for k, v := range args { - named[v] = k + for k, vs := range args { + for _, v := range vs { + named[v] = k + } } return node.(*ast.RawStmt), named, edits }