Skip to content

Commit 08fd504

Browse files
Jillecameronpm
andcommitted
Add sqlc.slice() to support IN clauses in MySQL (#695)
This feature (currently MySQL-specific) allows passing in a slice to an IN clause. Adding the new function sqlc.slice() as opposed to overloading the parsing of "IN (?)" was chosen to guarantee backwards compatibility. SELECT * FROM tab WHERE col IN (sqlc.slice("go_param_name")) This commit is based on #1312 by Paul Cameron. I just rebased and did some cleanup. Co-authored-by: Paul Cameron <cameronpm@gmail.com>
1 parent d4902a6 commit 08fd504

File tree

37 files changed

+10736
-5004
lines changed

37 files changed

+10736
-5004
lines changed

docs/howto/select.md

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ func (q *Queries) GetInfoForAuthor(ctx context.Context, id int) (GetInfoForAutho
185185

186186
## Passing a slice as a parameter to a query
187187

188+
### PostgreSQL
189+
188190
In PostgreSQL,
189191
[ANY](https://www.postgresql.org/docs/current/functions-comparisons.html#id-1.5.8.28.16)
190192
allows you to check if a value exists in an array expression. Queries using ANY
@@ -262,3 +264,111 @@ func (q *Queries) ListAuthorsByIDs(ctx context.Context, ids []int) ([]Author, er
262264
return items, nil
263265
}
264266
```
267+
268+
### MySQL
269+
270+
MySQL differs from PostgreSQL in that placeholders must be generated based on
271+
the number of elements in the slice you pass in. Though trivial it is still
272+
something of a nuisance. The passed in slice must not be nil or empty or an
273+
error will be returned (ie not a panic). The placeholder insertion location is
274+
marked by the meta-function `sqlc.slice()` (which is similar to `sqlc.arg()`
275+
that you see documented under [Naming parameters](named_parameters.md)).
276+
277+
To rephrase, the `sqlc.slice('param')` behaves identically to `sqlc.arg()` it
278+
terms of how it maps the explicit argument to the function signature, eg:
279+
280+
* `sqlc.slice('ids')` maps to `ids []GoType` in the function signature
281+
* `sqlc.slice(cust_ids)` maps to `custIds []GoType` in the function signature
282+
(like `sqlc.arg()`, the parameter does not have to be quoted)
283+
284+
This feature is not compatible with `emit_prepared_queries` statement found in the
285+
[Configuration file](../reference/config.md).
286+
287+
```sql
288+
CREATE TABLE authors (
289+
id SERIAL PRIMARY KEY,
290+
bio text NOT NULL,
291+
birth_year int NOT NULL
292+
);
293+
294+
-- name: ListAuthorsByIDs :many
295+
SELECT * FROM authors
296+
WHERE id IN (sqlc.slice('ids'));
297+
```
298+
299+
The above SQL will generate the following code:
300+
301+
```go
302+
package db
303+
304+
import (
305+
"context"
306+
"database/sql"
307+
"fmt"
308+
"strings"
309+
)
310+
311+
type Author struct {
312+
ID int
313+
Bio string
314+
BirthYear int
315+
}
316+
317+
type DBTX interface {
318+
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
319+
PrepareContext(context.Context, string) (*sql.Stmt, error)
320+
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
321+
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
322+
}
323+
324+
func New(db DBTX) *Queries {
325+
return &Queries{db: db}
326+
}
327+
328+
type Queries struct {
329+
db DBTX
330+
}
331+
332+
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
333+
return &Queries{
334+
db: tx,
335+
}
336+
}
337+
338+
const listAuthorsByIDs = `-- name: ListAuthorsByIDs :many
339+
SELECT id, bio, birth_year FROM authors
340+
WHERE id IN (/*SLICE:ids*/?)
341+
`
342+
343+
func (q *Queries) ListAuthorsByIDs(ctx context.Context, ids []int64) ([]Author, error) {
344+
sql := listAuthorsByIDs
345+
var queryParams []interface{}
346+
if len(ids) == 0 {
347+
return nil, fmt.Errorf("slice ids must have at least one element")
348+
}
349+
for _, v := range ids {
350+
queryParams = append(queryParams, v)
351+
}
352+
sql = strings.Replace(sql, "/*SLICE:ids*/?", strings.Repeat(",?", len(ids))[1:], 1)
353+
rows, err := q.db.QueryContext(ctx, sql, queryParams...)
354+
if err != nil {
355+
return nil, err
356+
}
357+
defer rows.Close()
358+
var items []Author
359+
for rows.Next() {
360+
var i Author
361+
if err := rows.Scan(&i.ID, &i.Bio, &i.BirthYear); err != nil {
362+
return nil, err
363+
}
364+
items = append(items, i)
365+
}
366+
if err := rows.Close(); err != nil {
367+
return nil, err
368+
}
369+
if err := rows.Err(); err != nil {
370+
return nil, err
371+
}
372+
return items, nil
373+
}
374+
```

internal/cmd/shim.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column {
241241
Length: int32(l),
242242
IsNamedParam: c.IsNamedParam,
243243
IsFuncCall: c.IsFuncCall,
244+
IsSqlcSlice: c.IsSqlcSlice,
244245
}
245246

246247
if c.Type != nil {

internal/codegen/golang/field.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ type Field struct {
1414
Type string
1515
Tags map[string]string
1616
Comment string
17+
Column *plugin.Column
1718
}
1819

1920
func (gf Field) Tag() string {
@@ -28,6 +29,10 @@ func (gf Field) Tag() string {
2829
return strings.Join(tags, " ")
2930
}
3031

32+
func (gf Field) HasSqlcSlice() bool {
33+
return gf.Column.IsSqlcSlice
34+
}
35+
3136
func JSONTagName(name string, settings *plugin.Settings) string {
3237
style := settings.Go.JsonTagsCaseStyle
3338
if style == "" || style == "none" {

internal/codegen/golang/gen.go

Lines changed: 82 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,63 @@ func (t *tmplCtx) OutputQuery(sourceName string) bool {
4343
return t.SourceName == sourceName
4444
}
4545

46+
func (t *tmplCtx) codegenDbarg() string {
47+
if t.EmitMethodsWithDBArgument {
48+
return "db DBTX, "
49+
}
50+
return ""
51+
}
52+
53+
// Called as a global method since subtemplate queryCodeStdExec does not have
54+
// access to the toplevel tmplCtx
55+
func (t *tmplCtx) codegenEmitPreparedQueries() bool {
56+
return t.EmitPreparedQueries
57+
}
58+
59+
func (t *tmplCtx) codegenQueryMethod(q Query) string {
60+
db := "q.db"
61+
if t.EmitMethodsWithDBArgument {
62+
db = "db"
63+
}
64+
65+
switch q.Cmd {
66+
case ":one":
67+
if t.EmitPreparedQueries {
68+
return "q.queryRow"
69+
}
70+
return db + ".QueryRowContext"
71+
72+
case ":many":
73+
if t.EmitPreparedQueries {
74+
return "q.query"
75+
}
76+
return db + ".QueryContext"
77+
78+
default:
79+
if t.EmitPreparedQueries {
80+
return "q.exec"
81+
}
82+
return db + ".ExecContext"
83+
}
84+
}
85+
86+
func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) {
87+
switch q.Cmd {
88+
case ":one":
89+
return "row :=", nil
90+
case ":many":
91+
return "rows, err :=", nil
92+
case ":exec":
93+
return "_, err :=", nil
94+
case ":execrows":
95+
return "result, err :=", nil
96+
case ":execresult":
97+
return "return", nil
98+
default:
99+
return "", fmt.Errorf("unhandled q.Cmd case %q", q.Cmd)
100+
}
101+
}
102+
46103
func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
47104
enums := buildEnums(req)
48105
structs := buildStructs(req)
@@ -61,24 +118,6 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie
61118
Structs: structs,
62119
}
63120

64-
funcMap := template.FuncMap{
65-
"lowerTitle": sdk.LowerTitle,
66-
"comment": sdk.DoubleSlashComment,
67-
"escape": sdk.EscapeBacktick,
68-
"imports": i.Imports,
69-
"hasPrefix": strings.HasPrefix,
70-
}
71-
72-
tmpl := template.Must(
73-
template.New("table").
74-
Funcs(funcMap).
75-
ParseFS(
76-
templates,
77-
"templates/*.tmpl",
78-
"templates/*/*.tmpl",
79-
),
80-
)
81-
82121
golang := req.Settings.Go
83122
tctx := tmplCtx{
84123
EmitInterface: golang.EmitInterface,
@@ -107,6 +146,31 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie
107146
return nil, errors.New(":batch* commands are only supported by pgx")
108147
}
109148

149+
funcMap := template.FuncMap{
150+
"lowerTitle": sdk.LowerTitle,
151+
"comment": sdk.DoubleSlashComment,
152+
"escape": sdk.EscapeBacktick,
153+
"imports": i.Imports,
154+
"hasPrefix": strings.HasPrefix,
155+
156+
// These methods are Go specific, they do not belong in the codegen package
157+
// (as that is language independent)
158+
"dbarg": tctx.codegenDbarg,
159+
"emitPreparedQueries": tctx.codegenEmitPreparedQueries,
160+
"queryMethod": tctx.codegenQueryMethod,
161+
"queryRetval": tctx.codegenQueryRetval,
162+
}
163+
164+
tmpl := template.Must(
165+
template.New("table").
166+
Funcs(funcMap).
167+
ParseFS(
168+
templates,
169+
"templates/*.tmpl",
170+
"templates/*/*.tmpl",
171+
),
172+
)
173+
110174
output := map[string]string{}
111175

112176
execute := func(name, templateName string) error {

internal/codegen/golang/go_type.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func goType(req *plugin.CodeGenRequest, col *plugin.Column) string {
3737
}
3838
}
3939
typ := goInnerType(req, col)
40-
if col.IsArray {
40+
if col.IsArray || col.IsSqlcSlice {
4141
return "[]" + typ
4242
}
4343
return typ

internal/codegen/golang/imports.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,12 +363,24 @@ func (i *importer) queryImports(filename string) fileImports {
363363
return false
364364
}
365365

366+
// Search for sqlc.slice() calls
367+
sqlcSliceScan := func() bool {
368+
for _, q := range gq {
369+
if q.Arg.HasSqlcSlices() {
370+
return true
371+
}
372+
}
373+
return false
374+
}
375+
366376
if anyNonCopyFrom {
367377
std["context"] = struct{}{}
368378
}
369379

370380
sqlpkg := parseDriver(i.Settings.Go.SqlPackage)
371-
if sliceScan() && !sqlpkg.IsPGX() {
381+
if sqlcSliceScan() {
382+
std["strings"] = struct{}{}
383+
} else if sliceScan() && !sqlpkg.IsPGX() {
372384
pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{}
373385
}
374386

internal/codegen/golang/query.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ type QueryValue struct {
1515
Struct *Struct
1616
Typ string
1717
SQLDriver SQLDriver
18+
19+
// Column is kept so late in the generation process around to differentiate
20+
// between mysql slices and pg arrays
21+
Column *plugin.Column
1822
}
1923

2024
func (v QueryValue) EmitStruct() bool {
@@ -93,14 +97,14 @@ func (v QueryValue) Params() string {
9397
}
9498
var out []string
9599
if v.Struct == nil {
96-
if strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() {
100+
if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() {
97101
out = append(out, "pq.Array("+v.Name+")")
98102
} else {
99103
out = append(out, v.Name)
100104
}
101105
} else {
102106
for _, f := range v.Struct.Fields {
103-
if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
107+
if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
104108
out = append(out, "pq.Array("+v.Name+"."+f.Name+")")
105109
} else {
106110
out = append(out, v.Name+"."+f.Name)
@@ -125,6 +129,20 @@ func (v QueryValue) ColumnNames() string {
125129
return "[]string{" + strings.Join(escapedNames, ", ") + "}"
126130
}
127131

132+
// When true, we have to build the arguments to q.db.QueryContext in addition to
133+
// munging the SQL
134+
func (v QueryValue) HasSqlcSlices() bool {
135+
if v.Struct == nil {
136+
return v.Column != nil && v.Column.IsSqlcSlice
137+
}
138+
for _, v := range v.Struct.Fields {
139+
if v.Column.IsSqlcSlice {
140+
return true
141+
}
142+
}
143+
return false
144+
}
145+
128146
func (v QueryValue) Scan() string {
129147
var out []string
130148
if v.Struct == nil {

internal/codegen/golang/result.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
168168
Name: paramName(p),
169169
Typ: goType(req, p.Column),
170170
SQLDriver: sqlpkg,
171+
Column: p.Column,
171172
}
172173
} else if len(query.Params) > 1 {
173174
var cols []goColumn
@@ -311,6 +312,7 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
311312
DBName: colName,
312313
Type: goType(req, c.Column),
313314
Tags: tags,
315+
Column: c.Column,
314316
})
315317
if _, found := seen[baseFieldName]; !found {
316318
seen[baseFieldName] = []int{i}

0 commit comments

Comments
 (0)