diff --git a/internal/codegen/golang/go_type.go b/internal/codegen/golang/go_type.go index 209175936c..30b8d14444 100644 --- a/internal/codegen/golang/go_type.go +++ b/internal/codegen/golang/go_type.go @@ -41,6 +41,9 @@ func goType(req *plugin.CodeGenRequest, col *plugin.Column) string { } sameTable := sdk.Matches(oride, col.Table, req.Catalog.DefaultSchema) if oride.Column != "" && sdk.MatchString(oride.ColumnName, cname) && sameTable { + if col.IsSqlcSlice { + return "[]" + oride.GoType.TypeName + } return oride.GoType.TypeName } } diff --git a/internal/endtoend/testdata/sqlc_slice/mysql/go/models.go b/internal/endtoend/testdata/sqlc_slice/mysql/go/models.go index 904f89e5ec..d7685d6c17 100644 --- a/internal/endtoend/testdata/sqlc_slice/mysql/go/models.go +++ b/internal/endtoend/testdata/sqlc_slice/mysql/go/models.go @@ -6,10 +6,13 @@ package querytest import ( "database/sql" + + "github.com/kyleconroy/sqlc-testdata/mysql" ) type Foo struct { - ID int32 - Name string - Bar sql.NullString + ID int32 + Name string + Bar sql.NullString + Mystr mysql.ID } diff --git a/internal/endtoend/testdata/sqlc_slice/mysql/go/query.sql.go b/internal/endtoend/testdata/sqlc_slice/mysql/go/query.sql.go index 90d9a0efb2..f925c6cea2 100644 --- a/internal/endtoend/testdata/sqlc_slice/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/sqlc_slice/mysql/go/query.sql.go @@ -9,6 +9,8 @@ import ( "context" "database/sql" "strings" + + "github.com/kyleconroy/sqlc-testdata/mysql" ) const funcNullable = `-- name: FuncNullable :many @@ -202,3 +204,41 @@ func (q *Queries) SliceExec(ctx context.Context, arg SliceExecParams) error { _, err := q.db.ExecContext(ctx, query, queryParams...) return err } + +const typedMyStr = `-- name: TypedMyStr :many +SELECT bar FROM foo +WHERE mystr IN (/*SLICE:mystr*/?) +` + +func (q *Queries) TypedMyStr(ctx context.Context, mystr []mysql.ID) ([]sql.NullString, error) { + query := typedMyStr + var queryParams []interface{} + if len(mystr) > 0 { + for _, v := range mystr { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:mystr*/?", strings.Repeat(",?", len(mystr))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:mystr*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var bar sql.NullString + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_slice/mysql/query.sql b/internal/endtoend/testdata/sqlc_slice/mysql/query.sql index e68046b0b4..b50e4c3bd6 100644 --- a/internal/endtoend/testdata/sqlc_slice/mysql/query.sql +++ b/internal/endtoend/testdata/sqlc_slice/mysql/query.sql @@ -1,4 +1,4 @@ -CREATE TABLE foo (id int not null, name text not null, bar text null); +CREATE TABLE foo (id int not null, name text not null, bar text null, mystr text not null); /* name: FuncParamIdent :many */ SELECT name FROM foo @@ -20,4 +20,8 @@ WHERE id IN (sqlc.slice(favourites)); /* name: FuncNullable :many */ SELECT bar FROM foo -WHERE id IN (sqlc.slice('favourites')); \ No newline at end of file +WHERE id IN (sqlc.slice('favourites')); + +/* name: TypedMyStr :many */ +SELECT bar FROM foo +WHERE mystr IN (sqlc.slice(mystr)); diff --git a/internal/endtoend/testdata/sqlc_slice/mysql/sqlc.json b/internal/endtoend/testdata/sqlc_slice/mysql/sqlc.json index 0657f4db83..e37bdfff08 100644 --- a/internal/endtoend/testdata/sqlc_slice/mysql/sqlc.json +++ b/internal/endtoend/testdata/sqlc_slice/mysql/sqlc.json @@ -8,5 +8,11 @@ "schema": "query.sql", "queries": "query.sql" } + ], + "overrides": [ + { + "column": "foo.mystr", + "go_type": "github.com/kyleconroy/sqlc-testdata/mysql.ID" + } ] }