Skip to content

Add sqlc.slice() new function type (#695) #1312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions docs/howto/select.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ func (q *Queries) GetInfoForAuthor(ctx context.Context, id int) (GetInfoForAutho

## Passing a slice as a parameter to a query

### PostgreSQL

In PostgreSQL,
[ANY](https://www.postgresql.org/docs/current/functions-comparisons.html#id-1.5.8.28.16)
allows you to check if a value exists in an array expression. Queries using ANY
Expand Down Expand Up @@ -262,3 +264,111 @@ func (q *Queries) ListAuthorsByIDs(ctx context.Context, ids []int) ([]Author, er
return items, nil
}
```

### MySQL

MySQL differs from PostgreSQL in that placeholders must be generated based on
the number of elements in the slice you pass in. Though trivial it is still
something of a nuisance. The passed in slice must not be nil or empty or an
error will be returned (ie not a panic). The placeholder insertion location is
marked by the meta-function `sqlc.slice()` (which is similar to `sqlc.arg()`
that you see documented under [Naming parameters](named_parameters.md)).

To rephrase, the `sqlc.slice('param')` behaves identically to `sqlc.arg()` it
terms of how it maps the explicit argument to the function signature, eg:

* `sqlc.slice('ids')` maps to `ids []GoType` in the function signature
* `sqlc.slice(cust_ids)` maps to `custIds []GoType` in the function signature
(like `sqlc.arg()`, the parameter does not have to be quoted)

This feature is not compatible with `emit_prepared_queries` statement found in the
[Configuration file](../reference/config.md).

```sql
CREATE TABLE authors (
id SERIAL PRIMARY KEY,
bio text NOT NULL,
birth_year int NOT NULL
);

-- name: ListAuthorsByIDs :many
SELECT * FROM authors
WHERE id IN (sqlc.slice('ids'));
```

The above SQL will generate the following code:

```go
package db

import (
"context"
"database/sql"
"fmt"
"strings"
)

type Author struct {
ID int
Bio string
BirthYear int
}

type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}

func New(db DBTX) *Queries {
return &Queries{db: db}
}

type Queries struct {
db DBTX
}

func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
}
}

const listAuthorsByIDs = `-- name: ListAuthorsByIDs :many
SELECT id, bio, birth_year FROM authors
WHERE id IN (/*REPLACE:ids*/?)
`

func (q *Queries) ListAuthorsByIDs(ctx context.Context, ids []int64) ([]Author, error) {
sql := listAuthorsByIDs
var queryParams []interface{}
if len(ids) == 0 {
return nil, fmt.Errorf("slice ids must have at least one element")
}
for _, v := range ids {
queryParams = append(queryParams, v)
}
sql = strings.Replace(sql, "/*REPLACE:ids*/?", strings.Repeat(",?", len(ids))[1:], 1)
rows, err := q.db.QueryContext(ctx, sql, queryParams...)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Author
for rows.Next() {
var i Author
if err := rows.Scan(&i.ID, &i.Bio, &i.BirthYear); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
```
6 changes: 6 additions & 0 deletions internal/codegen/golang/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"sort"
"strings"

"github.com/kyleconroy/sqlc/internal/compiler"
"github.com/kyleconroy/sqlc/internal/config"
)

Expand All @@ -13,6 +14,7 @@ type Field struct {
Type string
Tags map[string]string
Comment string
Column *compiler.Column
}

func (gf Field) Tag() string {
Expand All @@ -27,6 +29,10 @@ func (gf Field) Tag() string {
return strings.Join(tags, " ")
}

func (gf Field) HasSlice() bool {
return gf.Column.IsSlice
}

func JSONTagName(name string, settings config.CombinedSettings) string {
style := settings.Go.JSONTagsCaseStyle
if style == "" || style == "none" {
Expand Down
98 changes: 81 additions & 17 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,63 @@ func (t *tmplCtx) OutputQuery(sourceName string) bool {
return t.SourceName == sourceName
}

func (t *tmplCtx) codegenDbarg() string {
if t.EmitMethodsWithDBArgument {
return "db DBTX, "
}
return ""
}

// Called as a global method since subtemplate queryCodeStdExec does not have
// access to the toplevel tmplCtx
func (t *tmplCtx) codegenEmitPreparedQueries() bool {
return t.EmitPreparedQueries
}

func (t *tmplCtx) codegenQueryMethod(q Query) string {
db := "q.db"
if t.EmitMethodsWithDBArgument {
db = "db"
}

switch q.Cmd {
case ":one":
if t.EmitPreparedQueries {
return "q.queryRow"
}
return db + ".QueryRowContext"

case ":many":
if t.EmitPreparedQueries {
return "q.query"
}
return db + ".QueryContext"

default:
if t.EmitPreparedQueries {
return "q.exec"
}
return db + ".ExecContext"
}
}

func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) {
switch q.Cmd {
case ":one":
return "row :=", nil
case ":many":
return "rows, err :=", nil
case ":exec":
return "_, err :=", nil
case ":execrows":
return "result, err :=", nil
case ":execresult":
return "return", nil
default:
return "", fmt.Errorf("unhandled q.Cmd case %q", q.Cmd)
}
}

func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string]string, error) {
enums := buildEnums(r, settings)
structs := buildStructs(r, settings)
Expand All @@ -61,23 +118,6 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,
Structs: structs,
}

funcMap := template.FuncMap{
"lowerTitle": codegen.LowerTitle,
"comment": codegen.DoubleSlashComment,
"escape": codegen.EscapeBacktick,
"imports": i.Imports,
}

tmpl := template.Must(
template.New("table").
Funcs(funcMap).
ParseFS(
templates,
"templates/*.tmpl",
"templates/*/*.tmpl",
),
)

golang := settings.Go
tctx := tmplCtx{
Settings: settings.Global,
Expand All @@ -95,6 +135,30 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,
Structs: structs,
}

funcMap := template.FuncMap{
"lowerTitle": codegen.LowerTitle,
"comment": codegen.DoubleSlashComment,
"escape": codegen.EscapeBacktick,
"imports": i.Imports,

// These methods are Go specific, they do not belong in the codegen package
// (as that is language independent)
"dbarg": tctx.codegenDbarg,
"emitPreparedQueries": tctx.codegenEmitPreparedQueries,
"queryMethod": tctx.codegenQueryMethod,
"queryRetval": tctx.codegenQueryRetval,
}

tmpl := template.Must(
template.New("table").
Funcs(funcMap).
ParseFS(
templates,
"templates/*.tmpl",
"templates/*/*.tmpl",
),
)

output := map[string]string{}

execute := func(name, templateName string) error {
Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func goType(r *compiler.Result, col *compiler.Column, settings config.CombinedSe
}
}
typ := goInnerType(r, col, settings)
if col.IsArray {
if col.IsArray || col.IsSlice {
return "[]" + typ
}
return typ
Expand Down
14 changes: 13 additions & 1 deletion internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,22 @@ func (i *importer) queryImports(filename string) fileImports {
return false
}

mysqlSliceScan := func() bool {
for _, q := range gq {
if q.Arg.HasSlices() {
return true
}
}
return false
}

std["context"] = struct{}{}

sqlpkg := SQLPackageFromString(i.Settings.Go.SQLPackage)
if sliceScan() && sqlpkg != SQLPackagePGX {
if mysqlSliceScan() {
std["fmt"] = struct{}{}
std["strings"] = struct{}{}
} else if sliceScan() && sqlpkg != SQLPackagePGX {
pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{}
}

Expand Down
6 changes: 6 additions & 0 deletions internal/codegen/golang/mysql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ func mysqlType(r *compiler.Result, col *compiler.Column, settings config.Combine
}
return "sql.NullFloat64"

case "float":
if notNull {
return "float32"
}
return "sql.NullFloat64"

case "decimal", "dec", "fixed":
if notNull {
return "string"
Expand Down
23 changes: 21 additions & 2 deletions internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package golang
import (
"strings"

"github.com/kyleconroy/sqlc/internal/compiler"
"github.com/kyleconroy/sqlc/internal/metadata"
)

Expand All @@ -13,6 +14,10 @@ type QueryValue struct {
Struct *Struct
Typ string
SQLPackage SQLPackage

// Column is kept so late in the generation process around to differentiate
// between mysql slices and pg arrays
Column *compiler.Column
}

func (v QueryValue) EmitStruct() bool {
Expand Down Expand Up @@ -84,14 +89,14 @@ func (v QueryValue) Params() string {
}
var out []string
if v.Struct == nil {
if strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && v.SQLPackage != SQLPackagePGX {
if !v.Column.IsSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && v.SQLPackage != SQLPackagePGX {
out = append(out, "pq.Array("+v.Name+")")
} else {
out = append(out, v.Name)
}
} else {
for _, f := range v.Struct.Fields {
if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && v.SQLPackage != SQLPackagePGX {
if !f.HasSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && v.SQLPackage != SQLPackagePGX {
out = append(out, "pq.Array("+v.Name+"."+f.Name+")")
} else {
out = append(out, v.Name+"."+f.Name)
Expand All @@ -105,6 +110,20 @@ func (v QueryValue) Params() string {
return "\n" + strings.Join(out, ",\n")
}

// When true, we have to build the arguments to q.db.QueryContext in addition to
// munging the SQL
func (v QueryValue) HasSlices() bool {
if v.Struct == nil {
return v.Column != nil && v.Column.IsSlice
}
for _, v := range v.Struct.Fields {
if v.Column.IsSlice {
return true
}
}
return false
}

func (v QueryValue) Scan() string {
var out []string
if v.Struct == nil {
Expand Down
8 changes: 5 additions & 3 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
Name: paramName(p),
Typ: goType(r, p.Column, settings),
SQLPackage: sqlpkg,
Column: p.Column,
}
} else if len(query.Params) > 1 {
var cols []goColumn
Expand Down Expand Up @@ -291,9 +292,10 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
tags["json:"] = JSONTagName(tagName, settings)
}
gs.Fields = append(gs.Fields, Field{
Name: fieldName,
Type: goType(r, c.Column, settings),
Tags: tags,
Name: fieldName,
Type: goType(r, c.Column, settings),
Tags: tags,
Column: c.Column,
})
if _, found := seen[baseFieldName]; !found {
seen[baseFieldName] = []int{i}
Expand Down
Loading