From 34c1a5498a8f34f86d9650416ec39560c7e3b7c0 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sun, 20 Aug 2023 07:06:10 -0700 Subject: [PATCH 01/12] Separate the scope of tables to allow removal of aliases --- internal/compiler/find_params.go | 41 ++++++++-- internal/compiler/parse.go | 6 +- internal/compiler/resolve.go | 81 +++++++++++++++++-- .../testdata/subquery_with_where/go/db.go | 31 +++++++ .../testdata/subquery_with_where/go/models.go | 19 +++++ .../subquery_with_where/go/query.sql.go | 53 ++++++++++++ .../testdata/subquery_with_where/query.sql | 9 +++ .../testdata/subquery_with_where/sqlc.json | 12 +++ internal/sql/astutils/walk.go | 6 +- 9 files changed, 244 insertions(+), 14 deletions(-) create mode 100644 internal/endtoend/testdata/subquery_with_where/go/db.go create mode 100644 internal/endtoend/testdata/subquery_with_where/go/models.go create mode 100644 internal/endtoend/testdata/subquery_with_where/go/query.sql.go create mode 100644 internal/endtoend/testdata/subquery_with_where/query.sql create mode 100644 internal/endtoend/testdata/subquery_with_where/sqlc.json diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index 41ffaf8ad7..656481fb85 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -10,7 +10,7 @@ import ( func findParameters(root ast.Node) ([]paramRef, error) { refs := make([]paramRef, 0) errors := make([]error, 0) - v := paramSearch{seen: make(map[int]struct{}), refs: &refs, errs: &errors} + v := paramSearch{seen: make(map[int]struct{}), refs: &refs, errs: &errors, rvs: &[]*ast.RangeVar{}} astutils.Walk(v, root) if len(*v.errs) > 0 { problems := *v.errs @@ -22,6 +22,7 @@ func findParameters(root ast.Node) ([]paramRef, error) { type paramRef struct { parent ast.Node + rvs []*ast.RangeVar rv *ast.RangeVar ref *ast.ParamRef name string // Named parameter support @@ -31,6 +32,7 @@ type paramSearch struct { parent ast.Node rangeVar *ast.RangeVar refs *[]paramRef + rvs *[]*ast.RangeVar seen map[int]struct{} errs *[]error @@ -58,6 +60,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { return p } + var reset bool switch n := node.(type) { case *ast.A_Expr: @@ -70,6 +73,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { p.parent = n.FuncCall case *ast.DeleteStmt: + reset = true if n.LimitCount != nil { p.limitCount = n.LimitCount } @@ -78,7 +82,12 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { p.parent = node case *ast.InsertStmt: + reset = true + if n.Relation != nil { + *p.rvs = append(*p.rvs, n.Relation) + } if s, ok := n.SelectStmt.(*ast.SelectStmt); ok { + *p.rvs = append(*p.rvs, toTables(s.FromClause)...) for i, item := range s.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -92,7 +101,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: *p.rvs}) p.seen[ref.Location] = struct{}{} } for _, item := range s.ValuesLists.Items { @@ -109,13 +118,16 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: *p.rvs}) p.seen[ref.Location] = struct{}{} } } } case *ast.UpdateStmt: + reset = true + *p.rvs = append(*p.rvs, toTables(n.FromClause)...) + *p.rvs = append(*p.rvs, toTables(n.Relations)...) for _, item := range n.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -130,7 +142,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { if !ok { continue } - *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv}) + *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, rvs: *p.rvs}) } p.seen[ref.Location] = struct{}{} } @@ -139,12 +151,16 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { } case *ast.RangeVar: + if n != nil { + *p.rvs = append(*p.rvs, n) + } p.rangeVar = n case *ast.ResTarget: p.parent = node case *ast.SelectStmt: + reset = true if n.LimitCount != nil { p.limitCount = n.LimitCount } @@ -191,7 +207,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { } if set { - *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar}) + *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar, rvs: *p.rvs}) p.seen[n.Location] = struct{}{} } return nil @@ -215,5 +231,20 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { p.Visit(n.Expr) } } + if reset { + rvs := *p.rvs + return paramSearch{seen: p.seen, refs: p.refs, errs: p.errs, rvs: &rvs, parent: p.parent, rangeVar: p.rangeVar, limitCount: p.limitCount, limitOffset: p.limitOffset} + } return p } + +func toTables(tbl *ast.List) []*ast.RangeVar { + tables := make([]*ast.RangeVar, len(tbl.Items)) + for _, t := range tbl.Items { + item, ok := t.(*ast.RangeVar) + if ok && item != nil { + tables = append(tables, item) + } + } + return tables +} diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 8354bd340a..7f9d3badf5 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -95,7 +95,11 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } - params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds) + err = c.resolveCatalogEmbeds(qc, rvs, embeds) + if err != nil { + return nil, err + } + params, err := c.resolveCatalogRefs(qc, refs, namedParams) if err != nil { return nil, err } diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 0a91b45f25..b5f8acf9dc 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -20,7 +20,7 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { +func (comp *Compiler) resolveCatalogEmbeds(qc *QueryCatalog, rvs []*ast.RangeVar, embeds rewrite.EmbedSet) error { c := comp.catalog aliasMap := map[string]*ast.TableName{} @@ -55,7 +55,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } fqn, err := ParseTableName(rv) if err != nil { - return nil, err + return err } if _, found := aliasMap[fqn.Name]; found { continue @@ -64,13 +64,13 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if err != nil { // If the table name doesn't exist, fisrt check if it's a CTE if _, qcerr := qc.GetTable(fqn); qcerr != nil { - return nil, err + return err } continue } err = indexTable(table) if err != nil { - return nil, err + return err } if rv.Alias != nil { aliasMap[*rv.Alias.Aliasname] = fqn @@ -90,11 +90,71 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } - return nil, fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) + return fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) } + return nil +} +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, params *named.ParamSet) ([]Parameter, error) { + c := comp.catalog + + // resolve a table for an embed var a []Parameter for _, ref := range args { + aliasMap := map[string]*ast.TableName{} + // TODO: Deprecate defaultTable + var defaultTable *ast.TableName + var tables []*ast.TableName + + typeMap := map[string]map[string]map[string]*catalog.Column{} + indexTable := func(table catalog.Table) error { + tables = append(tables, table.Rel) + if defaultTable == nil { + defaultTable = table.Rel + } + schema := table.Rel.Schema + if schema == "" { + schema = c.DefaultSchema + } + if _, exists := typeMap[schema]; !exists { + typeMap[schema] = map[string]map[string]*catalog.Column{} + } + typeMap[schema][table.Rel.Name] = map[string]*catalog.Column{} + for _, c := range table.Columns { + cc := c + typeMap[schema][table.Rel.Name][c.Name] = cc + } + return nil + } + + for _, rv := range ref.rvs { + if rv == nil || rv.Relname == nil { + continue + } + fqn, err := ParseTableName(rv) + if err != nil { + return nil, err + } + if _, found := aliasMap[fqn.Name]; found { + continue + } + table, err := c.GetTable(fqn) + if err != nil { + // If the table name doesn't exist, fisrt check if it's a CTE + if _, qcerr := qc.GetTable(fqn); qcerr != nil { + return nil, err + } + continue + } + err = indexTable(table) + if err != nil { + return nil, err + } + if rv.Alias != nil { + aliasMap[*rv.Alias.Aliasname] = fqn + } + } + switch n := ref.parent.(type) { case *limitOffset: @@ -196,7 +256,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } var found int + seenTable := make(map[string]bool, len(search)) for _, table := range search { + if seenTable[table.Name] { + continue + } + seenTable[table.Name] = true schema := table.Schema if schema == "" { schema = c.DefaultSchema @@ -236,6 +301,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } } if found > 1 { + fmt.Println("ambiguous 3") return nil, &sqlerr.Error{ Code: "42703", Message: fmt.Sprintf("column reference %q is ambiguous", key), @@ -551,7 +617,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } } + seenTables := make(map[string]bool, len(search)) for _, table := range search { + if seenTables[table.Name] { + continue + } + seenTables[table.Name] = true schema := table.Schema if schema == "" { schema = c.DefaultSchema diff --git a/internal/endtoend/testdata/subquery_with_where/go/db.go b/internal/endtoend/testdata/subquery_with_where/go/db.go new file mode 100644 index 0000000000..57406b68e8 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/subquery_with_where/go/models.go b/internal/endtoend/testdata/subquery_with_where/go/models.go new file mode 100644 index 0000000000..3fa48ca789 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/go/models.go @@ -0,0 +1,19 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "database/sql" +) + +type Bar struct { + A int32 + Alias sql.NullString +} + +type Foo struct { + A int32 + Name sql.NullString +} diff --git a/internal/endtoend/testdata/subquery_with_where/go/query.sql.go b/internal/endtoend/testdata/subquery_with_where/go/query.sql.go new file mode 100644 index 0000000000..d6db500c95 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/go/query.sql.go @@ -0,0 +1,53 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const subquery = `-- name: Subquery :many +SELECT + a, + name, + (SELECT alias FROM bar WHERE bar.a=foo.a AND alias = $1 ORDER BY bar.a DESC limit 1) as alias +FROM FOO WHERE a = $2 +` + +type SubqueryParams struct { + Alias sql.NullString + A int32 +} + +type SubqueryRow struct { + A int32 + Name sql.NullString + Alias sql.NullString +} + +func (q *Queries) Subquery(ctx context.Context, arg SubqueryParams) ([]SubqueryRow, error) { + rows, err := q.db.QueryContext(ctx, subquery, arg.Alias, arg.A) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SubqueryRow + for rows.Next() { + var i SubqueryRow + if err := rows.Scan(&i.A, &i.Name, &i.Alias); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/subquery_with_where/query.sql b/internal/endtoend/testdata/subquery_with_where/query.sql new file mode 100644 index 0000000000..12e6dfaf3f --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/query.sql @@ -0,0 +1,9 @@ +CREATE TABLE foo (a int not null, name text); +CREATE TABLE bar (a int not null, alias text); + +-- name: Subquery :many +SELECT + a, + name, + (SELECT alias FROM bar WHERE bar.a=foo.a AND alias = $1 ORDER BY bar.a DESC limit 1) as alias +FROM FOO WHERE a = $2; diff --git a/internal/endtoend/testdata/subquery_with_where/sqlc.json b/internal/endtoend/testdata/subquery_with_where/sqlc.json new file mode 100644 index 0000000000..c72b6132d5 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 9f26617ad3..149403601c 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -1818,6 +1818,9 @@ func Walk(f Visitor, node ast.Node) { } case *ast.SelectStmt: + if n.FromClause != nil { + Walk(f, n.FromClause) + } if n.DistinctClause != nil { Walk(f, n.DistinctClause) } @@ -1827,9 +1830,6 @@ func Walk(f Visitor, node ast.Node) { if n.TargetList != nil { Walk(f, n.TargetList) } - if n.FromClause != nil { - Walk(f, n.FromClause) - } if n.WhereClause != nil { Walk(f, n.WhereClause) } From 6361797d598ba954801d58c6f6e75f710044159f Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sun, 20 Aug 2023 07:32:12 -0700 Subject: [PATCH 02/12] Remove duplicate additions of tables --- internal/compiler/find_params.go | 15 ++++++++------- internal/compiler/resolve.go | 10 ---------- internal/sql/astutils/walk.go | 6 +++--- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index 656481fb85..05fa1af187 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -83,11 +83,12 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { case *ast.InsertStmt: reset = true + rvs := *p.rvs if n.Relation != nil { - *p.rvs = append(*p.rvs, n.Relation) + rvs = append(rvs, n.Relation) } if s, ok := n.SelectStmt.(*ast.SelectStmt); ok { - *p.rvs = append(*p.rvs, toTables(s.FromClause)...) + rvs = append(rvs, toTables(s.FromClause)...) for i, item := range s.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -101,7 +102,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: *p.rvs}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: rvs}) p.seen[ref.Location] = struct{}{} } for _, item := range s.ValuesLists.Items { @@ -118,7 +119,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: *p.rvs}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: rvs}) p.seen[ref.Location] = struct{}{} } } @@ -126,8 +127,8 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { case *ast.UpdateStmt: reset = true - *p.rvs = append(*p.rvs, toTables(n.FromClause)...) - *p.rvs = append(*p.rvs, toTables(n.Relations)...) + rvs := append(*p.rvs, toTables(n.FromClause)...) + rvs = append(rvs, toTables(n.Relations)...) for _, item := range n.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -142,7 +143,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { if !ok { continue } - *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, rvs: *p.rvs}) + *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, rvs: rvs}) } p.seen[ref.Location] = struct{}{} } diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index b5f8acf9dc..7f6a9f8675 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -256,12 +256,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para } var found int - seenTable := make(map[string]bool, len(search)) for _, table := range search { - if seenTable[table.Name] { - continue - } - seenTable[table.Name] = true schema := table.Schema if schema == "" { schema = c.DefaultSchema @@ -617,12 +612,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para } } - seenTables := make(map[string]bool, len(search)) for _, table := range search { - if seenTables[table.Name] { - continue - } - seenTables[table.Name] = true schema := table.Schema if schema == "" { schema = c.DefaultSchema diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 149403601c..f21e916975 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -2032,15 +2032,15 @@ func Walk(f Visitor, node ast.Node) { if n.Relations != nil { Walk(f, n.Relations) } + if n.FromClause != nil { + Walk(f, n.FromClause) + } if n.TargetList != nil { Walk(f, n.TargetList) } if n.WhereClause != nil { Walk(f, n.WhereClause) } - if n.FromClause != nil { - Walk(f, n.FromClause) - } if n.LimitCount != nil { Walk(f, n.LimitCount) } From 5d171a018fe7605811034e5ef9e4b89d174ccb1b Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sun, 20 Aug 2023 10:47:07 -0700 Subject: [PATCH 03/12] remove comment --- internal/compiler/resolve.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 7f6a9f8675..d706aef31b 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -296,7 +296,6 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para } } if found > 1 { - fmt.Println("ambiguous 3") return nil, &sqlerr.Error{ Code: "42703", Message: fmt.Sprintf("column reference %q is ambiguous", key), From 7fe27942fefc72a92dc7bbd36733b6adf803755b Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Fri, 6 Oct 2023 17:16:52 -0700 Subject: [PATCH 04/12] updates --- internal/cmd/shim.go | 23 ++-- internal/codegen/golang/field.go | 4 + internal/codegen/golang/gen.go | 1 + internal/codegen/golang/go_type.go | 3 + internal/codegen/golang/mysql_type.go | 3 + internal/codegen/golang/postgresql_type.go | 3 + internal/codegen/golang/query.go | 26 ++++ .../codegen/golang/templates/pgx/dbCode.tmpl | 5 + .../golang/templates/stdlib/dbCode.tmpl | 6 + .../golang/templates/stdlib/queryCode.tmpl | 19 ++- internal/compiler/query.go | 3 +- internal/compiler/resolve.go | 120 +++++++++--------- .../params_location/mysql/go/query.sql.go | 37 ++++++ .../testdata/params_location/mysql/query.sql | 3 + internal/plugin/codegen.pb.go | 25 ++-- internal/sql/named/is.go | 3 +- internal/sql/named/param.go | 24 +++- internal/sql/rewrite/parameters.go | 4 + internal/sql/validate/func_call.go | 2 +- 19 files changed, 227 insertions(+), 87 deletions(-) diff --git a/internal/cmd/shim.go b/internal/cmd/shim.go index 7265f87511..be5009efb1 100644 --- a/internal/cmd/shim.go +++ b/internal/cmd/shim.go @@ -247,17 +247,18 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column { l = *c.Length } out := &plugin.Column{ - Name: c.Name, - OriginalName: c.OriginalName, - Comment: c.Comment, - NotNull: c.NotNull, - Unsigned: c.Unsigned, - IsArray: c.IsArray, - ArrayDims: int32(c.ArrayDims), - Length: int32(l), - IsNamedParam: c.IsNamedParam, - IsFuncCall: c.IsFuncCall, - IsSqlcSlice: c.IsSqlcSlice, + Name: c.Name, + OriginalName: c.OriginalName, + Comment: c.Comment, + NotNull: c.NotNull, + Unsigned: c.Unsigned, + IsArray: c.IsArray, + ArrayDims: int32(c.ArrayDims), + Length: int32(l), + IsNamedParam: c.IsNamedParam, + IsFuncCall: c.IsFuncCall, + IsSqlcSlice: c.IsSqlcSlice, + IsSqlcDynamic: c.IsSqlcDynamic, } if c.Type != nil { diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index ae7ba63573..cab783f9c0 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -28,6 +28,10 @@ func (gf Field) HasSqlcSlice() bool { return gf.Column.IsSqlcSlice } +func (gf Field) HasSqlcDynamic() bool { + return gf.Column.IsSqlcDynamic +} + func TagsToString(tags map[string]string) string { if len(tags) == 0 { return "" diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 7cd0a8dccd..68737b059d 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -38,6 +38,7 @@ type tmplCtx struct { EmitAllEnumValues bool UsesCopyFrom bool UsesBatch bool + HasSqlcDynamic bool } func (t *tmplCtx) OutputQuery(sourceName string) bool { diff --git a/internal/codegen/golang/go_type.go b/internal/codegen/golang/go_type.go index d6ba1ce69b..2b11e9d119 100644 --- a/internal/codegen/golang/go_type.go +++ b/internal/codegen/golang/go_type.go @@ -33,6 +33,9 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, co func goType(req *plugin.CodeGenRequest, col *plugin.Column) string { // Check if the column's type has been overridden + if col.IsSqlcDynamic { + return "DynamicSql" + } for _, oride := range req.Settings.Overrides { if oride.GoType.TypeName == "" { continue diff --git a/internal/codegen/golang/mysql_type.go b/internal/codegen/golang/mysql_type.go index c89d8ff3c7..0406924791 100644 --- a/internal/codegen/golang/mysql_type.go +++ b/internal/codegen/golang/mysql_type.go @@ -12,6 +12,9 @@ func mysqlType(req *plugin.CodeGenRequest, col *plugin.Column) string { columnType := sdk.DataType(col.Type) notNull := col.NotNull || col.IsArray unsigned := col.Unsigned + if col.IsSqlcDynamic { + return "DynamicSql" + } switch columnType { diff --git a/internal/codegen/golang/postgresql_type.go b/internal/codegen/golang/postgresql_type.go index ae3e1278ac..9ab1ffc93f 100644 --- a/internal/codegen/golang/postgresql_type.go +++ b/internal/codegen/golang/postgresql_type.go @@ -34,6 +34,9 @@ func parseIdentifierString(name string) (*plugin.Identifier, error) { } func postgresType(req *plugin.CodeGenRequest, col *plugin.Column) string { + if col.IsSqlcDynamic { + return "DynamicSql" + } columnType := sdk.DataType(col.Type) notNull := col.NotNull || col.IsArray driver := parseDriver(req.Settings.Go.SqlPackage) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 4fbecaffb3..befea871e5 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -164,6 +164,32 @@ func (v QueryValue) HasSqlcSlices() bool { } return false } +func (v QueryValue) HasSqlcDynamic() bool { + if v.Struct == nil { + if v.Column != nil && v.Column.IsSqlcDynamic { + return true + } + return false + } + for _, v := range v.Struct.Fields { + if v.Column.IsSqlcDynamic { + return true + } + } + return false +} +func (v QueryValue) SqlcDynamic() int { + var count int + if v.Struct == nil { + return 0 + } + for _, v := range v.Struct.Fields { + if !v.Column.IsSqlcDynamic && !v.Column.IsSqlcSlice { + count++ + } + } + return count +} func (v QueryValue) Scan() string { var out []string diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl index 236554d9f2..2c0960847a 100644 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ b/internal/codegen/golang/templates/pgx/dbCode.tmpl @@ -12,6 +12,11 @@ type DBTX interface { {{- end }} } +{{- if .HasSqlcDynamic }} +type DynamicSql interface { + ToSql(int) (string, []interface{}) +} +{{- end}} {{ if .EmitMethodsWithDBArgument}} func New() *Queries { return &Queries{} diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl index 7433d522f6..59c22974b3 100644 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/dbCode.tmpl @@ -6,6 +6,12 @@ type DBTX interface { QueryRowContext(context.Context, string, ...interface{}) *sql.Row } +{{- if .HasSqlcDynamic }} +type DynamicSql interface { + Sql() (string, []interface{}) +} +{{- end}} + {{ if .EmitMethodsWithDBArgument}} func New() *Queries { return &Queries{} diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index 6eca49cd17..b2210081c4 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -109,9 +109,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{end}} {{define "queryCodeStdExec"}} - {{- if .Arg.HasSqlcSlices }} + {{- if or .Arg.HasSqlcSlices .Arg.HasSqlcDynamic }} query := {{.ConstantName}} var queryParams []interface{} + {{- if .Arg.HasSqlcDynamic }} + curNumb := {{ .Arg.SqlcDynamic }} + {{- end }} {{- if .Arg.Struct }} {{- $arg := .Arg }} {{- range .Arg.Struct.Fields }} @@ -121,14 +124,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} queryParams = append(queryParams, v) } query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", strings.Repeat(",?", len({{$arg.VariableForField .}}))[1:], 1) + {{- if .HasSqlcDynamic }} + curNumb += len({{$arg.VariableForField .}}) + {{- end}} } else { query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", "NULL", 1) } + {{- else if .HasSqlcDynamic }} + replaceText, args := {{$arg.VariableForField .}}.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Column.Name}}*/?", "replaceText", 1) + queryParams = append(queryParams, v) {{- else }} queryParams = append(queryParams, {{$arg.VariableForField .}}) {{- end }} {{- end }} {{- else }} + {{- if .Arg.HasSqlcDynamic }} + var replaceText string + replaceText, queryParams = {{$arg.VariableForField .}}.ToSql(curNumb) + query = strings.ReplaceAll(query, "/*DYNAMIC:{{.VariableForField }}*/?", "replaceText", 1) + {{- else }} {{- /* Single argument parameter to this goroutine (they are not packed in a struct), because .Arg.HasSqlcSlices further up above was true, this section is 100% a slice (impossible to get here otherwise). @@ -141,6 +157,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} } else { query = strings.Replace(query, "/*SLICE:{{.Arg.Column.Name}}*/?", "NULL", 1) } + {{- end }} {{- end }} {{- if emitPreparedQueries }} {{ queryRetval . }} {{ queryMethod . }}(ctx, nil, query, queryParams...) diff --git a/internal/compiler/query.go b/internal/compiler/query.go index 117cf44813..95aa903970 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -34,7 +34,8 @@ type Column struct { Type *ast.TypeName EmbedTable *ast.TableName - IsSqlcSlice bool // is this sqlc.slice() + IsSqlcSlice bool // is this sqlc.slice() + IsSqlcDynamic bool // is this sqlc.dynamic() skipTableRequiredCheck bool } diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index d706aef31b..9eac6471d8 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -126,7 +126,6 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para } return nil } - for _, rv := range ref.rvs { if rv == nil || rv.Relname == nil { continue @@ -209,11 +208,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: dataType, - IsNamedParam: isNamed, - NotNull: p.NotNull(), - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + DataType: dataType, + IsNamedParam: isNamed, + NotNull: p.NotNull(), + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) continue @@ -272,17 +272,18 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Length: c.Length, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Length: c.Length, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } @@ -339,15 +340,16 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: namePrefix + p.Name(), - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: namePrefix + p.Name(), + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } @@ -412,11 +414,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: "any", - IsNamedParam: isNamed, - NotNull: p.NotNull(), - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + DataType: "any", + IsNamedParam: isNamed, + NotNull: p.NotNull(), + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) continue @@ -453,11 +456,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: dataType(paramType), - NotNull: p.NotNull(), - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + DataType: dataType(paramType), + NotNull: p.NotNull(), + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } @@ -515,17 +519,18 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: &ast.TableName{Schema: schema, Name: rel}, - Length: c.Length, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: &ast.TableName{Schema: schema, Name: rel}, + Length: c.Length, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } else { @@ -626,16 +631,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: c.IsNotNull, - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: c.IsNotNull, + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } diff --git a/internal/endtoend/testdata/params_location/mysql/go/query.sql.go b/internal/endtoend/testdata/params_location/mysql/go/query.sql.go index ea46901a84..f86769a734 100644 --- a/internal/endtoend/testdata/params_location/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/params_location/mysql/go/query.sql.go @@ -255,3 +255,40 @@ func (q *Queries) ListUsersWithLimit(ctx context.Context, limit int32) ([]ListUs } return items, nil } + +const searchByName = `-- name: SearchByName :many +SELECT id, first_name, last_name, age, job_status FROM users WHERE (first_name = ? OR last_name = ?) +` + +type SearchByNameParams struct { + Name sql.NullString +} + +func (q *Queries) SearchByName(ctx context.Context, arg SearchByNameParams) ([]User, error) { + rows, err := q.db.QueryContext(ctx, searchByName, arg.Name, arg.Name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.FirstName, + &i.LastName, + &i.Age, + &i.JobStatus, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/params_location/mysql/query.sql b/internal/endtoend/testdata/params_location/mysql/query.sql index 824583afa8..0b1bde563f 100644 --- a/internal/endtoend/testdata/params_location/mysql/query.sql +++ b/internal/endtoend/testdata/params_location/mysql/query.sql @@ -45,3 +45,6 @@ SELECT * FROM users WHERE (job_status = 'APPLIED' OR job_status = 'PENDING') AND id > ? ORDER BY id LIMIT ?; + +/* name: SearchByName :many */ +SELECT * FROM users WHERE (first_name = sqlc.narg(name) OR last_name = sqlc.narg(name)); diff --git a/internal/plugin/codegen.pb.go b/internal/plugin/codegen.pb.go index b735e90357..d2af16c6f7 100644 --- a/internal/plugin/codegen.pb.go +++ b/internal/plugin/codegen.pb.go @@ -1162,15 +1162,16 @@ type Column struct { IsNamedParam bool `protobuf:"varint,7,opt,name=is_named_param,json=isNamedParam,proto3" json:"is_named_param,omitempty"` IsFuncCall bool `protobuf:"varint,8,opt,name=is_func_call,json=isFuncCall,proto3" json:"is_func_call,omitempty"` // XXX: Figure out what PostgreSQL calls `foo.id` - Scope string `protobuf:"bytes,9,opt,name=scope,proto3" json:"scope,omitempty"` - Table *Identifier `protobuf:"bytes,10,opt,name=table,proto3" json:"table,omitempty"` - TableAlias string `protobuf:"bytes,11,opt,name=table_alias,json=tableAlias,proto3" json:"table_alias,omitempty"` - Type *Identifier `protobuf:"bytes,12,opt,name=type,proto3" json:"type,omitempty"` - IsSqlcSlice bool `protobuf:"varint,13,opt,name=is_sqlc_slice,json=isSqlcSlice,proto3" json:"is_sqlc_slice,omitempty"` - EmbedTable *Identifier `protobuf:"bytes,14,opt,name=embed_table,json=embedTable,proto3" json:"embed_table,omitempty"` - OriginalName string `protobuf:"bytes,15,opt,name=original_name,json=originalName,proto3" json:"original_name,omitempty"` - Unsigned bool `protobuf:"varint,16,opt,name=unsigned,proto3" json:"unsigned,omitempty"` - ArrayDims int32 `protobuf:"varint,17,opt,name=array_dims,json=arrayDims,proto3" json:"array_dims,omitempty"` + Scope string `protobuf:"bytes,9,opt,name=scope,proto3" json:"scope,omitempty"` + Table *Identifier `protobuf:"bytes,10,opt,name=table,proto3" json:"table,omitempty"` + TableAlias string `protobuf:"bytes,11,opt,name=table_alias,json=tableAlias,proto3" json:"table_alias,omitempty"` + Type *Identifier `protobuf:"bytes,12,opt,name=type,proto3" json:"type,omitempty"` + IsSqlcSlice bool `protobuf:"varint,13,opt,name=is_sqlc_slice,json=isSqlcSlice,proto3" json:"is_sqlc_slice,omitempty"` + IsSqlcDynamic bool `protobuf:"varint,13,opt,name=is_sqlc_dynamic,json=isSqlcDynamic,proto3" json:"is_sqlc_dynamic,omitempty"` + EmbedTable *Identifier `protobuf:"bytes,14,opt,name=embed_table,json=embedTable,proto3" json:"embed_table,omitempty"` + OriginalName string `protobuf:"bytes,15,opt,name=original_name,json=originalName,proto3" json:"original_name,omitempty"` + Unsigned bool `protobuf:"varint,16,opt,name=unsigned,proto3" json:"unsigned,omitempty"` + ArrayDims int32 `protobuf:"varint,17,opt,name=array_dims,json=arrayDims,proto3" json:"array_dims,omitempty"` } func (x *Column) Reset() { @@ -1288,6 +1289,12 @@ func (x *Column) GetIsSqlcSlice() bool { } return false } +func (x *Column) GetIsSqlcDynamic() bool { + if x != nil { + return x.IsSqlcDynamic + } + return false +} func (x *Column) GetEmbedTable() *Identifier { if x != nil { diff --git a/internal/sql/named/is.go b/internal/sql/named/is.go index d53c1d9905..5d3d1d1544 100644 --- a/internal/sql/named/is.go +++ b/internal/sql/named/is.go @@ -16,7 +16,8 @@ func IsParamFunc(node ast.Node) bool { return false } - isValid := call.Func.Schema == "sqlc" && (call.Func.Name == "arg" || call.Func.Name == "narg" || call.Func.Name == "slice") + // TODO + isValid := call.Func.Schema == "sqlc" && (call.Func.Name == "arg" || call.Func.Name == "narg" || call.Func.Name == "slice" || call.Func.Name == "dynamic") return isValid } diff --git a/internal/sql/named/param.go b/internal/sql/named/param.go index 42f9b855a3..7667008c8b 100644 --- a/internal/sql/named/param.go +++ b/internal/sql/named/param.go @@ -42,9 +42,10 @@ func (n nullability) String() string { // - named parameter operator @param // - named parameter function calls sqlc.arg(param) type Param struct { - name string - nullability nullability - isSqlcSlice bool + name string + nullability nullability + isSqlcSlice bool + isSqlcDynamic bool } // NewParam builds a new params with unspecified nullability @@ -72,6 +73,11 @@ func NewSqlcSlice(name string) Param { return Param{name: name, nullability: nullUnspecified, isSqlcSlice: true} } +// NewSqlcDynamic is a sqlc.dynamic() parameter. +func NewSqlcDynamic(name string) Param { + return Param{name: name, nullability: notNullable, isSqlcDynamic: true} +} + // Name is the user defined name to use for this parameter func (p Param) Name() string { return p.name @@ -113,6 +119,11 @@ func (p Param) IsSqlcSlice() bool { return p.isSqlcSlice } +// IsSlice returns whether this param is a sqlc.dynamic() param. +func (p Param) IsSqlcDynamic() bool { + return p.isSqlcDynamic +} + // mergeParam creates a new param from 2 partially specified params // If the parameters have different names, the first is preferred func mergeParam(a, b Param) Param { @@ -122,8 +133,9 @@ func mergeParam(a, b Param) Param { } return Param{ - name: name, - nullability: a.nullability | b.nullability, - isSqlcSlice: a.isSqlcSlice || b.isSqlcSlice, + name: name, + nullability: a.nullability | b.nullability, + isSqlcSlice: a.isSqlcSlice || b.isSqlcSlice, + isSqlcDynamic: a.isSqlcDynamic && b.isSqlcDynamic, } } diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index d1ea1a22cc..ccbd000e23 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -61,6 +61,8 @@ func paramFromFuncCall(call *ast.FuncCall) (named.Param, string) { param = named.NewUserNullableParam(paramName) case "slice": param = named.NewSqlcSlice(paramName) + case "dynamic": + param = named.NewSqlcDynamic(paramName) default: param = named.NewParam(paramName) } @@ -107,6 +109,8 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, // This sequence is also replicated in internal/codegen/golang.Field // since it's needed during template generation for replacement replace = fmt.Sprintf(`/*SLICE:%s*/?`, param.Name()) + } else if param.IsSqlcDynamic() { + replace = fmt.Sprintf(`/*DYNAMIC:%s*/?`, param.Name()) } else { if engine == config.EngineSQLite { replace = fmt.Sprintf("?%d", argn) diff --git a/internal/sql/validate/func_call.go b/internal/sql/validate/func_call.go index 383366c68f..66b47fb992 100644 --- a/internal/sql/validate/func_call.go +++ b/internal/sql/validate/func_call.go @@ -34,7 +34,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { // Custom validation for sqlc.arg, sqlc.narg and sqlc.slice // TODO: Replace this once type-checking is implemented if fn.Schema == "sqlc" { - if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed") { + if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed" || fn.Name == "dynamic") { v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name) return nil } From f755265cfa64ba2f0f4215153b006686d6266f74 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sat, 7 Oct 2023 21:23:23 -0700 Subject: [PATCH 05/12] Something --- internal/codegen/golang/query.go | 6 +- .../golang/templates/stdlib/queryCode.tmpl | 2 +- internal/compiler/resolve.go | 16 +- .../endtoend/testdata/dynamic/mysql/go/db.go | 31 +++ .../testdata/dynamic/mysql/go/models.go | 23 ++ .../testdata/dynamic/mysql/go/query.sql.go | 139 ++++++++++++ .../endtoend/testdata/dynamic/mysql/query.sql | 25 +++ .../endtoend/testdata/dynamic/mysql/sqlc.json | 12 ++ .../endtoend/testdata/dynamic/stdlib/go/db.go | 31 +++ .../testdata/dynamic/stdlib/go/models.go | 23 ++ .../testdata/dynamic/stdlib/go/query.sql.go | 197 ++++++++++++++++++ .../testdata/dynamic/stdlib/query.sql | 34 +++ .../testdata/dynamic/stdlib/sqlc.json | 12 ++ internal/source/code.go | 3 + internal/sql/ast/param_ref.go | 7 +- internal/sql/rewrite/parameters.go | 11 +- 16 files changed, 559 insertions(+), 13 deletions(-) create mode 100644 internal/endtoend/testdata/dynamic/mysql/go/db.go create mode 100644 internal/endtoend/testdata/dynamic/mysql/go/models.go create mode 100644 internal/endtoend/testdata/dynamic/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/dynamic/mysql/query.sql create mode 100644 internal/endtoend/testdata/dynamic/mysql/sqlc.json create mode 100644 internal/endtoend/testdata/dynamic/stdlib/go/db.go create mode 100644 internal/endtoend/testdata/dynamic/stdlib/go/models.go create mode 100644 internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go create mode 100644 internal/endtoend/testdata/dynamic/stdlib/query.sql create mode 100644 internal/endtoend/testdata/dynamic/stdlib/sqlc.json diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index befea871e5..431d3ac8b1 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -179,12 +179,12 @@ func (v QueryValue) HasSqlcDynamic() bool { return false } func (v QueryValue) SqlcDynamic() int { - var count int + var count int = 1 if v.Struct == nil { - return 0 + return 1 } for _, v := range v.Struct.Fields { - if !v.Column.IsSqlcDynamic && !v.Column.IsSqlcSlice { + if !v.Column.IsSqlcDynamic { count++ } } diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index b2210081c4..b12660ce85 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -134,7 +134,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} replaceText, args := {{$arg.VariableForField .}}.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Column.Name}}*/?", "replaceText", 1) - queryParams = append(queryParams, v) + queryParams = append(queryParams, args...) {{- else }} queryParams = append(queryParams, {{$arg.VariableForField .}}) {{- end }} diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 9eac6471d8..3bf291f8e2 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -101,11 +101,25 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para // resolve a table for an embed var a []Parameter for _, ref := range args { + if ref.ref.IsSqlcDynamic { + defaultP := named.NewInferredParam("offset", true) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) + a = append(a, Parameter{ + Number: ref.ref.Number, + Column: &Column{ + Name: p.Name(), + DataType: "DynamicSql", + NotNull: p.NotNull(), + IsNamedParam: isNamed, + IsSqlcDynamic: true, + }, + }) + continue + } aliasMap := map[string]*ast.TableName{} // TODO: Deprecate defaultTable var defaultTable *ast.TableName var tables []*ast.TableName - typeMap := map[string]map[string]map[string]*catalog.Column{} indexTable := func(table catalog.Table) error { tables = append(tables, table.Rel) diff --git a/internal/endtoend/testdata/dynamic/mysql/go/db.go b/internal/endtoend/testdata/dynamic/mysql/go/db.go new file mode 100644 index 0000000000..57406b68e8 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/dynamic/mysql/go/models.go b/internal/endtoend/testdata/dynamic/mysql/go/models.go new file mode 100644 index 0000000000..1f67f50bb3 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/mysql/go/models.go @@ -0,0 +1,23 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "database/sql" +) + +type Order struct { + ID int32 + Price string + UserID int32 +} + +type User struct { + ID int32 + FirstName string + LastName sql.NullString + Age int32 + JobStatus string +} diff --git a/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go new file mode 100644 index 0000000000..597942bd98 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go @@ -0,0 +1,139 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const selectUsers = `-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > ? +` + +type SelectUsersRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsers(ctx context.Context, age int32) ([]SelectUsersRow, error) { + rows, err := q.db.QueryContext(ctx, selectUsers, age) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersRow + for rows.Next() { + var i SelectUsersRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic = `-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > ? AND /*DYNAMIC:dynamic*/? +` + +type SelectUsersDynamicParams struct { + Age int32 + Dynamic DynamicSql +} + +type SelectUsersDynamicRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamicParams) ([]SelectUsersDynamicRow, error) { + query := selectUsersDynamic + var queryParams []interface{} + curNumb := 2 + queryParams = append(queryParams, arg.Age) + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", "replaceText", 1) + queryParams = append(queryParams, v) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicRow + for rows.Next() { + var i SelectUsersDynamicRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > ? AND + job_status = ? AND + /*DYNAMIC:dynamic*/? +` + +type SelectUsersDynamic2Params struct { + Age int32 + Status string + Dynamic DynamicSql +} + +type SelectUsersDynamic2Row struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { + query := selectUsersDynamic2 + var queryParams []interface{} + curNumb := 3 + queryParams = append(queryParams, arg.Age) + queryParams = append(queryParams, arg.Status) + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", "replaceText", 1) + queryParams = append(queryParams, v) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamic2Row + for rows.Next() { + var i SelectUsersDynamic2Row + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/dynamic/mysql/query.sql b/internal/endtoend/testdata/dynamic/mysql/query.sql new file mode 100644 index 0000000000..6a90908f81 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/mysql/query.sql @@ -0,0 +1,25 @@ +CREATE TABLE users ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + first_name varchar(255) NOT NULL, + last_name varchar(255), + age integer NOT NULL, + job_status varchar(10) NOT NULL +); + +CREATE TABLE orders ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + price DECIMAL(13, 4) NOT NULL, + user_id integer NOT NULL +); + +-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age); +-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age) AND sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic'); diff --git a/internal/endtoend/testdata/dynamic/mysql/sqlc.json b/internal/endtoend/testdata/dynamic/mysql/sqlc.json new file mode 100644 index 0000000000..bfbd23e211 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "query.sql", + "queries": "query.sql", + "engine": "mysql" + } + ] +} diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/db.go b/internal/endtoend/testdata/dynamic/stdlib/go/db.go new file mode 100644 index 0000000000..57406b68e8 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/models.go b/internal/endtoend/testdata/dynamic/stdlib/go/models.go new file mode 100644 index 0000000000..1f67f50bb3 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/stdlib/go/models.go @@ -0,0 +1,23 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "database/sql" +) + +type Order struct { + ID int32 + Price string + UserID int32 +} + +type User struct { + ID int32 + FirstName string + LastName sql.NullString + Age int32 + JobStatus string +} diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go b/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go new file mode 100644 index 0000000000..a395125b61 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go @@ -0,0 +1,197 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" + "strings" +) + +const selectUsers = `-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > $1 +` + +type SelectUsersRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsers(ctx context.Context, age int32) ([]SelectUsersRow, error) { + rows, err := q.db.QueryContext(ctx, selectUsers, age) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersRow + for rows.Next() { + var i SelectUsersRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic = `-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > $1 AND sqlc.dynamic('dynamic') +` + +type SelectUsersDynamicParams struct { + Age int32 + Dynamic DynamicSql +} + +type SelectUsersDynamicRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamicParams) ([]SelectUsersDynamicRow, error) { + query := selectUsersDynamic + var queryParams []interface{} + curNumb := 2 + queryParams = append(queryParams, arg.Age) + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", "replaceText", 1) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicRow + for rows.Next() { + var i SelectUsersDynamicRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + sqlc.dynamic('dynamic') +` + +type SelectUsersDynamic2Params struct { + Age int32 + Status string + Dynamic DynamicSql +} + +type SelectUsersDynamic2Row struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { + query := selectUsersDynamic2 + var queryParams []interface{} + curNumb := 3 + queryParams = append(queryParams, arg.Age) + queryParams = append(queryParams, arg.Status) + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", "replaceText", 1) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamic2Row + for rows.Next() { + var i SelectUsersDynamic2Row + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamicMulti = `-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + sqlc.dynamic('dynamic') +ORDER BY sqlc.dynamic('order') +` + +type SelectUsersDynamicMultiParams struct { + Age int32 + Status string + Dynamic DynamicSql + Order DynamicSql +} + +type SelectUsersDynamicMultiRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDynamicMultiParams) ([]SelectUsersDynamicMultiRow, error) { + query := selectUsersDynamicMulti + var queryParams []interface{} + curNumb := 3 + queryParams = append(queryParams, arg.Age) + queryParams = append(queryParams, arg.Status) + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", "replaceText", 1) + queryParams = append(queryParams, args...) + replaceText, args := arg.Order.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/?", "replaceText", 1) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicMultiRow + for rows.Next() { + var i SelectUsersDynamicMultiRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/dynamic/stdlib/query.sql b/internal/endtoend/testdata/dynamic/stdlib/query.sql new file mode 100644 index 0000000000..dca8604f7a --- /dev/null +++ b/internal/endtoend/testdata/dynamic/stdlib/query.sql @@ -0,0 +1,34 @@ +CREATE TABLE users ( + id int PRIMARY KEY, + first_name text NOT NULL, + last_name text, + age int NOT NULL, + job_status text NOT NULL +); + +CREATE TABLE orders ( + id int PRIMARY KEY, + price numeric NOT NULL, + user_id int NOT NULL +); + +-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age); + +-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age) AND sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic') +ORDER BY sqlc.dynamic('order'); diff --git a/internal/endtoend/testdata/dynamic/stdlib/sqlc.json b/internal/endtoend/testdata/dynamic/stdlib/sqlc.json new file mode 100644 index 0000000000..696ed223db --- /dev/null +++ b/internal/endtoend/testdata/dynamic/stdlib/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "query.sql", + "queries": "query.sql", + "engine": "postgresql" + } + ] +} diff --git a/internal/source/code.go b/internal/source/code.go index 9a6ed077d3..551308e280 100644 --- a/internal/source/code.go +++ b/internal/source/code.go @@ -59,6 +59,9 @@ func Mutate(raw string, a []Edit) (string, error) { s := raw for idx, edit := range a { + if strings.Contains(edit.Old, "sqlc.dynamic") { + continue + } start := edit.Location if start > len(s) || start < 0 { return "", fmt.Errorf("edit start location is out of bounds") diff --git a/internal/sql/ast/param_ref.go b/internal/sql/ast/param_ref.go index d0f486cf85..c155c1c2b2 100644 --- a/internal/sql/ast/param_ref.go +++ b/internal/sql/ast/param_ref.go @@ -1,9 +1,10 @@ package ast type ParamRef struct { - Number int - Location int - Dollar bool + Number int + Location int + Dollar bool + IsSqlcDynamic bool } func (n *ParamRef) Pos() int { diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index ccbd000e23..c61cc0e898 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -99,18 +99,19 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, param, origText := paramFromFuncCall(fun) argn := allParams.Add(param) cr.Replace(&ast.ParamRef{ - Number: argn, - Location: fun.Location, + Number: argn, + Location: fun.Location, + IsSqlcDynamic: param.IsSqlcDynamic(), }) var replace string - if engine == config.EngineMySQL || engine == config.EngineSQLite || !dollar { + if param.IsSqlcDynamic() { + replace = fmt.Sprintf(`/*DYNAMIC:%s*/?`, param.Name()) + } else if engine == config.EngineMySQL || engine == config.EngineSQLite || !dollar { if param.IsSqlcSlice() { // This sequence is also replicated in internal/codegen/golang.Field // since it's needed during template generation for replacement replace = fmt.Sprintf(`/*SLICE:%s*/?`, param.Name()) - } else if param.IsSqlcDynamic() { - replace = fmt.Sprintf(`/*DYNAMIC:%s*/?`, param.Name()) } else { if engine == config.EngineSQLite { replace = fmt.Sprintf("?%d", argn) From 411b4e4869c3c03ff69deb8105aed038c6c120c0 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Fri, 13 Oct 2023 10:29:56 -0700 Subject: [PATCH 06/12] Progress I don't want to loose --- internal/codegen/golang/gen.go | 4 + .../golang/templates/stdlib/queryCode.tmpl | 6 +- internal/compiler/analyze.go | 3 +- internal/compiler/parse.go | 3 +- .../endtoend/testdata/dynamic/stdlib/go/db.go | 2 +- .../testdata/dynamic/stdlib/go/models.go | 2 +- .../testdata/dynamic/stdlib/go/query.sql.go | 163 ++++++++++++++---- .../testdata/dynamic/stdlib/query.sql | 12 +- internal/source/code.go | 3 - internal/sql/rewrite/parameters.go | 84 ++++++--- internal/sql/validate/param_style.go | 2 +- 11 files changed, 211 insertions(+), 73 deletions(-) diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 71f8c25aa7..e3b6c14954 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -11,6 +11,7 @@ import ( "text/template" "github.com/sqlc-dev/sqlc/internal/codegen/sdk" + "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/metadata" "github.com/sqlc-dev/sqlc/internal/plugin" ) @@ -181,6 +182,9 @@ func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs [ "emitPreparedQueries": tctx.codegenEmitPreparedQueries, "queryMethod": tctx.codegenQueryMethod, "queryRetval": tctx.codegenQueryRetval, + "dollar": func() bool { + return req.Settings.Engine == string(config.EnginePostgreSQL) + }, } tmpl := template.Must( diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index 8e058084d0..c754d68d1d 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -133,7 +133,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{- else if .HasSqlcDynamic }} replaceText, args := {{$arg.VariableForField .}}.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Column.Name}}*/?", "replaceText", 1) + query = strings.ReplaceAllString(query, "/*DYNAMIC:{{.Column.Name}}*/{{- if dollar }}$1{{ else }}?{{ end }}", replaceText) queryParams = append(queryParams, args...) {{- else }} queryParams = append(queryParams, {{$arg.VariableForField .}}) @@ -142,8 +142,8 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{- else }} {{- if .Arg.HasSqlcDynamic }} var replaceText string - replaceText, queryParams = {{$arg.VariableForField .}}.ToSql(curNumb) - query = strings.ReplaceAll(query, "/*DYNAMIC:{{.VariableForField }}*/?", "replaceText", 1) + replaceText, queryParams = {{ .Arg.Column.Name}}.ToSql(curNumb) + query = strings.ReplaceAllString(query, "/*DYNAMIC:{{.Arg.Column.Name}}*/?", "replaceText") {{- else }} {{- /* Single argument parameter to this goroutine (they are not packed in a struct), because .Arg.HasSqlcSlices further up above was true, diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 739cd07993..8146477abf 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -140,7 +140,6 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(validate.In(c.catalog, raw)); err != nil { return nil, err } - rvs := rangeVars(raw.Stmt) refs, errs := findParameters(raw.Stmt) if len(errs) > 0 { if failfast { @@ -160,7 +159,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } - params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds) + params, err := c.resolveCatalogRefs(qc, refs, namedParams) if err := check(err); err != nil { return nil, err } diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 53e3043c7d..2d377368de 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -77,11 +77,10 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, } expanded := anlys.Query - // If the query string was edited, make sure the syntax is valid if expanded != rawSQL { if _, err := c.parser.Parse(strings.NewReader(expanded)); err != nil { - return nil, fmt.Errorf("edited query syntax is invalid: %w", err) + return nil, fmt.Errorf("edited query syntax is invalid: %w - %s", err, expanded) } } diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/db.go b/internal/endtoend/testdata/dynamic/stdlib/go/db.go index 57406b68e8..a457fb76b2 100644 --- a/internal/endtoend/testdata/dynamic/stdlib/go/db.go +++ b/internal/endtoend/testdata/dynamic/stdlib/go/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.22.0 package querytest diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/models.go b/internal/endtoend/testdata/dynamic/stdlib/go/models.go index 1f67f50bb3..b5f5c9b7ed 100644 --- a/internal/endtoend/testdata/dynamic/stdlib/go/models.go +++ b/internal/endtoend/testdata/dynamic/stdlib/go/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.22.0 package querytest diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go b/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go index a395125b61..7c37cdb9f2 100644 --- a/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.22.0 // source: query.sql package querytest @@ -8,7 +8,6 @@ package querytest import ( "context" "database/sql" - "strings" ) const selectUsers = `-- name: SelectUsers :many @@ -43,37 +42,88 @@ func (q *Queries) SelectUsers(ctx context.Context, age int32) ([]SelectUsersRow, return items, nil } -const selectUsersDynamic = `-- name: SelectUsersDynamic :many -SELECT first_name, last_name FROM users WHERE age > $1 AND sqlc.dynamic('dynamic') +const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamic2Params struct { + Age int32 + Status string + Dynamic DynamicSql +} + +type SelectUsersDynamic2Row struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { + query := selectUsersDynamic2 + var queryParams []interface{} + curNumb := 3 + queryParams = append(queryParams, arg.Age) + queryParams = append(queryParams, arg.Status) + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText, 1) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamic2Row + for rows.Next() { + var i SelectUsersDynamic2Row + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamicA = `-- name: SelectUsersDynamicA :many +SELECT first_name, last_name FROM users WHERE age > $1 AND /*DYNAMIC:dynamic*/$1 ` -type SelectUsersDynamicParams struct { +type SelectUsersDynamicAParams struct { Age int32 Dynamic DynamicSql } -type SelectUsersDynamicRow struct { +type SelectUsersDynamicARow struct { FirstName string LastName sql.NullString } -func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamicParams) ([]SelectUsersDynamicRow, error) { - query := selectUsersDynamic +func (q *Queries) SelectUsersDynamicA(ctx context.Context, arg SelectUsersDynamicAParams) ([]SelectUsersDynamicARow, error) { + query := selectUsersDynamicA var queryParams []interface{} curNumb := 2 queryParams = append(queryParams, arg.Age) replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", "replaceText", 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText, 1) queryParams = append(queryParams, args...) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { return nil, err } defer rows.Close() - var items []SelectUsersDynamicRow + var items []SelectUsersDynamicARow for rows.Next() { - var i SelectUsersDynamicRow + var i SelectUsersDynamicARow if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { return nil, err } @@ -88,43 +138,37 @@ func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamic return items, nil } -const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many -SELECT first_name, last_name -FROM users -WHERE age > $1 AND - job_status = $2 AND - sqlc.dynamic('dynamic') +const selectUsersDynamicB = `-- name: SelectUsersDynamicB :many +SELECT first_name, last_name FROM users WHERE /*DYNAMIC:dynamic*/$1 AND age > $1 ` -type SelectUsersDynamic2Params struct { +type SelectUsersDynamicBParams struct { Age int32 - Status string Dynamic DynamicSql } -type SelectUsersDynamic2Row struct { +type SelectUsersDynamicBRow struct { FirstName string LastName sql.NullString } -func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { - query := selectUsersDynamic2 +func (q *Queries) SelectUsersDynamicB(ctx context.Context, arg SelectUsersDynamicBParams) ([]SelectUsersDynamicBRow, error) { + query := selectUsersDynamicB var queryParams []interface{} - curNumb := 3 + curNumb := 2 queryParams = append(queryParams, arg.Age) - queryParams = append(queryParams, arg.Status) replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", "replaceText", 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText, 1) queryParams = append(queryParams, args...) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { return nil, err } defer rows.Close() - var items []SelectUsersDynamic2Row + var items []SelectUsersDynamicBRow for rows.Next() { - var i SelectUsersDynamic2Row + var i SelectUsersDynamicBRow if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { return nil, err } @@ -144,8 +188,8 @@ SELECT first_name, last_name FROM users WHERE age > $1 AND job_status = $2 AND - sqlc.dynamic('dynamic') -ORDER BY sqlc.dynamic('order') + /*DYNAMIC:dynamic*/$1 +ORDER BY /*DYNAMIC:order*/$1 ` type SelectUsersDynamicMultiParams struct { @@ -168,11 +212,11 @@ func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDy queryParams = append(queryParams, arg.Status) replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", "replaceText", 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText, 1) queryParams = append(queryParams, args...) replaceText, args := arg.Order.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:order*/?", "replaceText", 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText, 1) queryParams = append(queryParams, args...) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { @@ -195,3 +239,60 @@ func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDy } return items, nil } + +const selectUsersDynamicMultiB = `-- name: SelectUsersDynamicMultiB :many +SELECT first_name, last_name +FROM users +WHERE /*DYNAMIC:dynamic*/$1 AND + age > $1 AND + job_status = $2 +ORDER BY /*DYNAMIC:order*/$1 +` + +type SelectUsersDynamicMultiBParams struct { + Age int32 + Status string + Dynamic DynamicSql + Order DynamicSql +} + +type SelectUsersDynamicMultiBRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamicMultiB(ctx context.Context, arg SelectUsersDynamicMultiBParams) ([]SelectUsersDynamicMultiBRow, error) { + query := selectUsersDynamicMultiB + var queryParams []interface{} + curNumb := 3 + queryParams = append(queryParams, arg.Age) + queryParams = append(queryParams, arg.Status) + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText, 1) + queryParams = append(queryParams, args...) + replaceText, args := arg.Order.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText, 1) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicMultiBRow + for rows.Next() { + var i SelectUsersDynamicMultiBRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/dynamic/stdlib/query.sql b/internal/endtoend/testdata/dynamic/stdlib/query.sql index dca8604f7a..12abf6ab8e 100644 --- a/internal/endtoend/testdata/dynamic/stdlib/query.sql +++ b/internal/endtoend/testdata/dynamic/stdlib/query.sql @@ -15,8 +15,10 @@ CREATE TABLE orders ( -- name: SelectUsers :many SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age); --- name: SelectUsersDynamic :many +-- name: SelectUsersDynamicA :many SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age) AND sqlc.dynamic('dynamic'); +-- name: SelectUsersDynamicB :many +SELECT first_name, last_name FROM users WHERE sqlc.dynamic('dynamic') AND age > sqlc.arg(age); -- name: SelectUsersDynamic2 :many SELECT first_name, last_name @@ -32,3 +34,11 @@ WHERE age > sqlc.arg(age) AND job_status = sqlc.arg(status) AND sqlc.dynamic('dynamic') ORDER BY sqlc.dynamic('order'); + +-- name: SelectUsersDynamicMultiB :many +SELECT first_name, last_name +FROM users +WHERE sqlc.dynamic('dynamic') AND + age > sqlc.arg(age) AND + job_status = sqlc.arg(status) +ORDER BY sqlc.dynamic('order'); diff --git a/internal/source/code.go b/internal/source/code.go index a243bf95b1..f34e3e3684 100644 --- a/internal/source/code.go +++ b/internal/source/code.go @@ -60,9 +60,6 @@ func Mutate(raw string, a []Edit) (string, error) { s := raw for idx, edit := range a { - if strings.Contains(edit.Old, "sqlc.dynamic") { - continue - } start := edit.Location if start > len(s) || start < 0 { return "", fmt.Errorf("edit start location is out of bounds") diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index c61cc0e898..232b0a6085 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -97,37 +97,37 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case named.IsParamFunc(node): fun := node.(*ast.FuncCall) param, origText := paramFromFuncCall(fun) - argn := allParams.Add(param) - cr.Replace(&ast.ParamRef{ - Number: argn, - Location: fun.Location, - IsSqlcDynamic: param.IsSqlcDynamic(), - }) - - var replace string - if param.IsSqlcDynamic() { - replace = fmt.Sprintf(`/*DYNAMIC:%s*/?`, param.Name()) - } else if engine == config.EngineMySQL || engine == config.EngineSQLite || !dollar { - if param.IsSqlcSlice() { - // This sequence is also replicated in internal/codegen/golang.Field - // since it's needed during template generation for replacement - replace = fmt.Sprintf(`/*SLICE:%s*/?`, param.Name()) - } else { - if engine == config.EngineSQLite { - replace = fmt.Sprintf("?%d", argn) + if !param.IsSqlcDynamic() { + argn := allParams.Add(param) + cr.Replace(&ast.ParamRef{ + Number: argn, + Location: fun.Location, + IsSqlcDynamic: param.IsSqlcDynamic(), + }) + + var replace string + if engine == config.EngineMySQL || engine == config.EngineSQLite || !dollar { + if param.IsSqlcSlice() { + // This sequence is also replicated in internal/codegen/golang.Field + // since it's needed during template generation for replacement + replace = fmt.Sprintf(`/*SLICE:%s*/?`, param.Name()) } else { - replace = "?" + if engine == config.EngineSQLite { + replace = fmt.Sprintf("?%d", argn) + } else { + replace = "?" + } } + } else { + replace = fmt.Sprintf("$%d", argn) } - } else { - replace = fmt.Sprintf("$%d", argn) - } - edits = append(edits, source.Edit{ - Location: fun.Location - raw.StmtLocation, - Old: origText, - New: replace, - }) + edits = append(edits, source.Edit{ + Location: fun.Location - raw.StmtLocation, + Old: origText, + New: replace, + }) + } return false case isNamedParamSignCast(node): @@ -192,6 +192,34 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, return true } }, nil) - + node = astutils.Apply(node, func(cr *astutils.Cursor) bool { + node := cr.Node() + if named.IsParamFunc(node) { + fun := node.(*ast.FuncCall) + param, origText := paramFromFuncCall(fun) + if param.IsSqlcDynamic() { + argn := allParams.Add(param) + cr.Replace(&ast.ParamRef{ + Number: argn, + Location: fun.Location, + IsSqlcDynamic: param.IsSqlcDynamic(), + }) + + var replace string + if engine == config.EngineMySQL || engine == config.EngineSQLite || !dollar { + replace = fmt.Sprintf(`/*DYNAMIC:%s*/?`, param.Name()) + } else { + replace = fmt.Sprintf(`/*DYNAMIC:%s*/$1`, param.Name()) + } + edits = append(edits, source.Edit{ + Location: fun.Location - raw.StmtLocation, + Old: origText, + New: replace, + }) + } + return false + } + return true + }, nil) return node.(*ast.RawStmt), allParams, edits } diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go index 1182051d20..ce3a0f1aca 100644 --- a/internal/sql/validate/param_style.go +++ b/internal/sql/validate/param_style.go @@ -29,7 +29,7 @@ func (v *sqlcFuncVisitor) Visit(node ast.Node) astutils.Visitor { // Custom validation for sqlc.arg, sqlc.narg and sqlc.slice // TODO: Replace this once type-checking is implemented if fn.Schema == "sqlc" { - if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed") { + if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed" || fn.Name == "dynamic") { v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name) return nil } From 6bdd617ad739c47981ad241b45fd4e7723f42f6c Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Fri, 13 Oct 2023 17:31:11 -0700 Subject: [PATCH 07/12] Need to fix internal/compiler/resolve.go and also do pbx and also the db file --- internal/compiler/analyze.go | 3 +- internal/compiler/resolve.go | 202 ++++++------------ .../endtoend/testdata/dynamic/mysql/go/db.go | 2 +- .../testdata/dynamic/mysql/go/models.go | 2 +- .../testdata/dynamic/mysql/go/query.sql.go | 3 +- .../testdata/subquery_with_where/go/db.go | 2 +- .../testdata/subquery_with_where/go/models.go | 2 +- .../subquery_with_where/go/query.sql.go | 2 +- 8 files changed, 70 insertions(+), 148 deletions(-) diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 8146477abf..739cd07993 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -140,6 +140,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(validate.In(c.catalog, raw)); err != nil { return nil, err } + rvs := rangeVars(raw.Stmt) refs, errs := findParameters(raw.Stmt) if len(errs) > 0 { if failfast { @@ -159,7 +160,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } - params, err := c.resolveCatalogRefs(qc, refs, namedParams) + params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds) if err := check(err); err != nil { return nil, err } diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 33507f42fb..4624c5a45d 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -21,7 +21,7 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogEmbeds(qc *QueryCatalog, rvs []*ast.RangeVar, embeds rewrite.EmbedSet) error { +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { c := comp.catalog aliasMap := map[string]*ast.TableName{} @@ -56,7 +56,7 @@ func (comp *Compiler) resolveCatalogEmbeds(qc *QueryCatalog, rvs []*ast.RangeVar } fqn, err := ParseTableName(rv) if err != nil { - return err + return nil, err } if _, found := aliasMap[fqn.Name]; found { continue @@ -65,13 +65,13 @@ func (comp *Compiler) resolveCatalogEmbeds(qc *QueryCatalog, rvs []*ast.RangeVar if err != nil { // If the table name doesn't exist, fisrt check if it's a CTE if _, qcerr := qc.GetTable(fqn); qcerr != nil { - return err + return nil, err } continue } err = indexTable(table) if err != nil { - return err + return nil, err } if rv.Alias != nil { aliasMap[*rv.Alias.Aliasname] = fqn @@ -91,84 +91,11 @@ func (comp *Compiler) resolveCatalogEmbeds(qc *QueryCatalog, rvs []*ast.RangeVar continue } - return fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) + return nil, fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) } - return nil -} -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, params *named.ParamSet) ([]Parameter, error) { - c := comp.catalog - - // resolve a table for an embed var a []Parameter for _, ref := range args { - if ref.ref.IsSqlcDynamic { - defaultP := named.NewInferredParam("offset", true) - p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) - a = append(a, Parameter{ - Number: ref.ref.Number, - Column: &Column{ - Name: p.Name(), - DataType: "DynamicSql", - NotNull: p.NotNull(), - IsNamedParam: isNamed, - IsSqlcDynamic: true, - }, - }) - continue - } - aliasMap := map[string]*ast.TableName{} - // TODO: Deprecate defaultTable - var defaultTable *ast.TableName - var tables []*ast.TableName - typeMap := map[string]map[string]map[string]*catalog.Column{} - indexTable := func(table catalog.Table) error { - tables = append(tables, table.Rel) - if defaultTable == nil { - defaultTable = table.Rel - } - schema := table.Rel.Schema - if schema == "" { - schema = c.DefaultSchema - } - if _, exists := typeMap[schema]; !exists { - typeMap[schema] = map[string]map[string]*catalog.Column{} - } - typeMap[schema][table.Rel.Name] = map[string]*catalog.Column{} - for _, c := range table.Columns { - cc := c - typeMap[schema][table.Rel.Name][c.Name] = cc - } - return nil - } - for _, rv := range ref.rvs { - if rv == nil || rv.Relname == nil { - continue - } - fqn, err := ParseTableName(rv) - if err != nil { - return nil, err - } - if _, found := aliasMap[fqn.Name]; found { - continue - } - table, err := c.GetTable(fqn) - if err != nil { - // If the table name doesn't exist, fisrt check if it's a CTE - if _, qcerr := qc.GetTable(fqn); qcerr != nil { - return nil, err - } - continue - } - err = indexTable(table) - if err != nil { - return nil, err - } - if rv.Alias != nil { - aliasMap[*rv.Alias.Aliasname] = fqn - } - } - switch n := ref.parent.(type) { case *limitOffset: @@ -223,12 +150,11 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: dataType, - IsNamedParam: isNamed, - NotNull: p.NotNull(), - IsSqlcSlice: p.IsSqlcSlice(), - IsSqlcDynamic: p.IsSqlcDynamic(), + Name: p.Name(), + DataType: dataType, + IsNamedParam: isNamed, + NotNull: p.NotNull(), + IsSqlcSlice: p.IsSqlcSlice(), }, }) continue @@ -287,18 +213,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Length: c.Length, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), - IsSqlcDynamic: p.IsSqlcDynamic(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Length: c.Length, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), }, }) } @@ -355,16 +280,15 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: namePrefix + p.Name(), - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), - IsSqlcDynamic: p.IsSqlcDynamic(), + Name: namePrefix + p.Name(), + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), }, }) } @@ -429,12 +353,11 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: "any", - IsNamedParam: isNamed, - NotNull: p.NotNull(), - IsSqlcSlice: p.IsSqlcSlice(), - IsSqlcDynamic: p.IsSqlcDynamic(), + Name: p.Name(), + DataType: "any", + IsNamedParam: isNamed, + NotNull: p.NotNull(), + IsSqlcSlice: p.IsSqlcSlice(), }, }) continue @@ -471,12 +394,11 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: dataType(paramType), - NotNull: p.NotNull(), - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), - IsSqlcDynamic: p.IsSqlcDynamic(), + Name: p.Name(), + DataType: dataType(paramType), + NotNull: p.NotNull(), + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), }, }) } @@ -534,18 +456,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: &ast.TableName{Schema: schema, Name: rel}, - Length: c.Length, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), - IsSqlcDynamic: p.IsSqlcDynamic(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: &ast.TableName{Schema: schema, Name: rel}, + Length: c.Length, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), }, }) } else { @@ -646,17 +567,16 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para a = append(a, Parameter{ Number: number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: c.IsNotNull, - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), - IsSqlcDynamic: p.IsSqlcDynamic(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: c.IsNotNull, + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), }, }) } diff --git a/internal/endtoend/testdata/dynamic/mysql/go/db.go b/internal/endtoend/testdata/dynamic/mysql/go/db.go index 57406b68e8..a457fb76b2 100644 --- a/internal/endtoend/testdata/dynamic/mysql/go/db.go +++ b/internal/endtoend/testdata/dynamic/mysql/go/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.22.0 package querytest diff --git a/internal/endtoend/testdata/dynamic/mysql/go/models.go b/internal/endtoend/testdata/dynamic/mysql/go/models.go index 1f67f50bb3..b5f5c9b7ed 100644 --- a/internal/endtoend/testdata/dynamic/mysql/go/models.go +++ b/internal/endtoend/testdata/dynamic/mysql/go/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.22.0 package querytest diff --git a/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go index 597942bd98..94fba1cefe 100644 --- a/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.22.0 // source: query.sql package querytest @@ -8,6 +8,7 @@ package querytest import ( "context" "database/sql" + "strings" ) const selectUsers = `-- name: SelectUsers :many diff --git a/internal/endtoend/testdata/subquery_with_where/go/db.go b/internal/endtoend/testdata/subquery_with_where/go/db.go index 57406b68e8..a457fb76b2 100644 --- a/internal/endtoend/testdata/subquery_with_where/go/db.go +++ b/internal/endtoend/testdata/subquery_with_where/go/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.22.0 package querytest diff --git a/internal/endtoend/testdata/subquery_with_where/go/models.go b/internal/endtoend/testdata/subquery_with_where/go/models.go index 3fa48ca789..56e82be769 100644 --- a/internal/endtoend/testdata/subquery_with_where/go/models.go +++ b/internal/endtoend/testdata/subquery_with_where/go/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.22.0 package querytest diff --git a/internal/endtoend/testdata/subquery_with_where/go/query.sql.go b/internal/endtoend/testdata/subquery_with_where/go/query.sql.go index d6db500c95..bb21b9b5cc 100644 --- a/internal/endtoend/testdata/subquery_with_where/go/query.sql.go +++ b/internal/endtoend/testdata/subquery_with_where/go/query.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.22.0 // source: query.sql package querytest From 8161e09e2b0557bb3629b0a9a05666b2b5f10341 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sat, 14 Oct 2023 20:40:57 -0700 Subject: [PATCH 08/12] Adding sql dynamic function --- internal/codegen/golang/gen.go | 12 +- internal/codegen/golang/query.go | 6 +- .../codegen/golang/templates/pgx/dbCode.tmpl | 2 +- .../golang/templates/pgx/queryCode.tmpl | 70 ++++++ .../golang/templates/stdlib/dbCode.tmpl | 4 +- .../golang/templates/stdlib/queryCode.tmpl | 4 +- internal/compiler/analyze.go | 6 +- internal/compiler/resolve.go | 202 ++++++++++++------ .../endtoend/testdata/dynamic/mysql/go/db.go | 3 + .../testdata/dynamic/mysql/go/query.sql.go | 9 +- .../endtoend/testdata/dynamic/pgx/v4/go/db.go | 32 +++ .../testdata/dynamic/pgx/v4/go/models.go | 25 +++ .../testdata/dynamic/pgx/v4/go/query.sql.go | 194 +++++++++++++++++ .../testdata/dynamic/pgx/v4/query.sql | 37 ++++ .../testdata/dynamic/pgx/v4/sqlc.json | 13 ++ .../endtoend/testdata/dynamic/pgx/v5/go/db.go | 32 +++ .../testdata/dynamic/pgx/v5/go/models.go | 23 ++ .../testdata/dynamic/pgx/v5/go/query.sql.go | 177 +++++++++++++++ .../testdata/dynamic/pgx/v5/query.sql | 34 +++ .../testdata/dynamic/pgx/v5/sqlc.json | 13 ++ .../endtoend/testdata/dynamic/stdlib/go/db.go | 3 + .../testdata/dynamic/stdlib/go/query.sql.go | 14 +- .../process_plugin_sqlc_gen_json/exec.json | 1 + internal/plugin/codegen.pb.go | 12 +- 24 files changed, 837 insertions(+), 91 deletions(-) create mode 100644 internal/endtoend/testdata/dynamic/pgx/v4/go/db.go create mode 100644 internal/endtoend/testdata/dynamic/pgx/v4/go/models.go create mode 100644 internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go create mode 100644 internal/endtoend/testdata/dynamic/pgx/v4/query.sql create mode 100644 internal/endtoend/testdata/dynamic/pgx/v4/sqlc.json create mode 100644 internal/endtoend/testdata/dynamic/pgx/v5/go/db.go create mode 100644 internal/endtoend/testdata/dynamic/pgx/v5/go/models.go create mode 100644 internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go create mode 100644 internal/endtoend/testdata/dynamic/pgx/v5/query.sql create mode 100644 internal/endtoend/testdata/dynamic/pgx/v5/sqlc.json diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index e3b6c14954..cfbfaf1930 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -132,6 +132,13 @@ func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs [ Enums: enums, Structs: structs, } + var hasDynamic bool + for _, q := range queries { + if q.Arg.HasSqlcDynamic() { + hasDynamic = true + break + } + } tctx := tmplCtx{ EmitInterface: options.EmitInterface, @@ -150,8 +157,8 @@ func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs [ Package: options.Package, Enums: enums, Structs: structs, - SqlcVersion: req.SqlcVersion, BuildTags: options.BuildTags, + SqlcVersion: req.SqlcVersion, } if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != SQLDriverGoSQLDriverMySQL { @@ -185,6 +192,9 @@ func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs [ "dollar": func() bool { return req.Settings.Engine == string(config.EnginePostgreSQL) }, + "hasDynamic": func() bool { + return hasDynamic + }, } tmpl := template.Must( diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 19438496cf..f509fbe6b3 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -128,14 +128,16 @@ func (v QueryValue) Params() string { } var out []string if v.Struct == nil { - if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() { + if v.Column.IsSqlcDynamic { + } else if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() { out = append(out, "pq.Array("+v.Name+")") } else { out = append(out, v.Name) } } else { for _, f := range v.Struct.Fields { - if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() { + if f.HasSqlcDynamic() { + } else if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() { out = append(out, "pq.Array("+v.VariableForField(f)+")") } else { out = append(out, v.VariableForField(f)) diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl index 2c0960847a..5c3b5b55ea 100644 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ b/internal/codegen/golang/templates/pgx/dbCode.tmpl @@ -12,7 +12,7 @@ type DBTX interface { {{- end }} } -{{- if .HasSqlcDynamic }} +{{- if hasDynamic }} type DynamicSql interface { ToSql(int) (string, []interface{}) } diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index 18de5db2ba..71599f5a3b 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -1,3 +1,22 @@ +{{define "preexec"}} + {{- if .Arg.Struct }} + queryParams := []interface{}{ {{.Arg.Params}} } + {{- $arg := .Arg }} + curNumb := {{ $arg.SqlcDynamic }} + {{- range .Arg.Struct.Fields }} + {{- if .HasSqlcDynamic }} + replaceText, args := {{$arg.VariableForField .}}.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Column.Name}}*/$1", replaceText) + queryParams = append(queryParams, args...) + {{- end}} + {{- end}} + {{- else}} + replaceText, args := {{.Arg.Column.Name}}.ToSql(1) + query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Arg.Column.Name}}*/$1", replaceText) + {{- end}} +{{- end}} + {{define "queryCodePgx"}} {{range .GoQueries}} {{if $.OutputQuery .SourceName}} @@ -28,10 +47,20 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + row := db.QueryRow(ctx, query, queryParams...) + {{- else}} row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + row := q.db.QueryRow(ctx, query, queryParams...) + {{- else}} row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- end}} {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} var {{.Ret.Name}} {{.Ret.Type}} @@ -46,10 +75,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + rows, err := db.Query(ctx, query, queryParams...) + {{- else}} rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + rows, err := q.db.Query(ctx, query, queryParams...) + {{- else}} rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- end}} if err != nil { return nil, err @@ -79,10 +118,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret. {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + _, err := db.Exec(ctx, query, queryParams...) + {{- else}} _, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + _, err := q.db.Exec(ctx, query, queryParams...) + {{- else}} _, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- end}} return err } @@ -93,10 +142,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { {{end -}} {{if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + result, err := db.Exec(ctx, query, queryParams...) + {{- else}} result, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + result, err := q.db.Exec(ctx, query, queryParams...) + {{- else}} result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- end}} if err != nil { return 0, err @@ -110,10 +169,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + return db.Exec(ctx, query, queryParams...) + {{- else}} return db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + return q.db.Exec(ctx, query, queryParams...) + {{- else}} return q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- end}} } {{end}} @@ -122,3 +191,4 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.Co {{end}} {{end}} {{end}} + diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl index 59c22974b3..d8cb756f17 100644 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/dbCode.tmpl @@ -6,9 +6,9 @@ type DBTX interface { QueryRowContext(context.Context, string, ...interface{}) *sql.Row } -{{- if .HasSqlcDynamic }} +{{- if hasDynamic }} type DynamicSql interface { - Sql() (string, []interface{}) + Sql({{ if dollar}}int{{ end }}) (string, []interface{}) } {{- end}} diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index c754d68d1d..893fd595b7 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -133,7 +133,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{- else if .HasSqlcDynamic }} replaceText, args := {{$arg.VariableForField .}}.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAllString(query, "/*DYNAMIC:{{.Column.Name}}*/{{- if dollar }}$1{{ else }}?{{ end }}", replaceText) + query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Column.Name}}*/{{- if dollar }}$1{{ else }}?{{ end }}", replaceText) queryParams = append(queryParams, args...) {{- else }} queryParams = append(queryParams, {{$arg.VariableForField .}}) @@ -143,7 +143,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{- if .Arg.HasSqlcDynamic }} var replaceText string replaceText, queryParams = {{ .Arg.Column.Name}}.ToSql(curNumb) - query = strings.ReplaceAllString(query, "/*DYNAMIC:{{.Arg.Column.Name}}*/?", "replaceText") + query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Arg.Column.Name}}*/?", "replaceText") {{- else }} {{- /* Single argument parameter to this goroutine (they are not packed in a struct), because .Arg.HasSqlcSlices further up above was true, diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 739cd07993..63de0dc32e 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -160,7 +160,11 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } - params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds) + params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams) + if err := check(err); err != nil { + return nil, err + } + err = c.resolveCatalogEmbeds(qc, rvs, embeds) if err := check(err); err != nil { return nil, err } diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 4624c5a45d..6cf04093bf 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -21,7 +21,7 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { +func (comp *Compiler) resolveCatalogEmbeds(qc *QueryCatalog, rvs []*ast.RangeVar, embeds rewrite.EmbedSet) error { c := comp.catalog aliasMap := map[string]*ast.TableName{} @@ -56,7 +56,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } fqn, err := ParseTableName(rv) if err != nil { - return nil, err + return err } if _, found := aliasMap[fqn.Name]; found { continue @@ -65,13 +65,13 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if err != nil { // If the table name doesn't exist, fisrt check if it's a CTE if _, qcerr := qc.GetTable(fqn); qcerr != nil { - return nil, err + return err } continue } err = indexTable(table) if err != nil { - return nil, err + return err } if rv.Alias != nil { aliasMap[*rv.Alias.Aliasname] = fqn @@ -91,11 +91,84 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } - return nil, fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) + return fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) } + return nil +} +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet) ([]Parameter, error) { + c := comp.catalog + + // resolve a table for an embed var a []Parameter for _, ref := range args { + if ref.ref.IsSqlcDynamic { + defaultP := named.NewInferredParam("offset", true) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) + a = append(a, Parameter{ + Number: ref.ref.Number, + Column: &Column{ + Name: p.Name(), + DataType: "DynamicSql", + NotNull: p.NotNull(), + IsNamedParam: isNamed, + IsSqlcDynamic: true, + }, + }) + continue + } + aliasMap := map[string]*ast.TableName{} + // TODO: Deprecate defaultTable + var defaultTable *ast.TableName + var tables []*ast.TableName + typeMap := map[string]map[string]map[string]*catalog.Column{} + indexTable := func(table catalog.Table) error { + tables = append(tables, table.Rel) + if defaultTable == nil { + defaultTable = table.Rel + } + schema := table.Rel.Schema + if schema == "" { + schema = c.DefaultSchema + } + if _, exists := typeMap[schema]; !exists { + typeMap[schema] = map[string]map[string]*catalog.Column{} + } + typeMap[schema][table.Rel.Name] = map[string]*catalog.Column{} + for _, c := range table.Columns { + cc := c + typeMap[schema][table.Rel.Name][c.Name] = cc + } + return nil + } + for _, rv := range ref.rvs { + if rv == nil || rv.Relname == nil { + continue + } + fqn, err := ParseTableName(rv) + if err != nil { + return nil, err + } + if _, found := aliasMap[fqn.Name]; found { + continue + } + table, err := c.GetTable(fqn) + if err != nil { + // If the table name doesn't exist, fisrt check if it's a CTE + if _, qcerr := qc.GetTable(fqn); qcerr != nil { + return nil, err + } + continue + } + err = indexTable(table) + if err != nil { + return nil, err + } + if rv.Alias != nil { + aliasMap[*rv.Alias.Aliasname] = fqn + } + } + switch n := ref.parent.(type) { case *limitOffset: @@ -150,11 +223,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: dataType, - IsNamedParam: isNamed, - NotNull: p.NotNull(), - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + DataType: dataType, + IsNamedParam: isNamed, + NotNull: p.NotNull(), + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) continue @@ -213,17 +287,18 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Length: c.Length, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Length: c.Length, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } @@ -280,15 +355,16 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: namePrefix + p.Name(), - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: namePrefix + p.Name(), + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } @@ -353,11 +429,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: "any", - IsNamedParam: isNamed, - NotNull: p.NotNull(), - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + DataType: "any", + IsNamedParam: isNamed, + NotNull: p.NotNull(), + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) continue @@ -394,11 +471,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: dataType(paramType), - NotNull: p.NotNull(), - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + DataType: dataType(paramType), + NotNull: p.NotNull(), + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } @@ -456,17 +534,18 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: &ast.TableName{Schema: schema, Name: rel}, - Length: c.Length, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: &ast.TableName{Schema: schema, Name: rel}, + Length: c.Length, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } else { @@ -567,16 +646,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: c.IsNotNull, - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: c.IsNotNull, + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } diff --git a/internal/endtoend/testdata/dynamic/mysql/go/db.go b/internal/endtoend/testdata/dynamic/mysql/go/db.go index a457fb76b2..5fbccf5c37 100644 --- a/internal/endtoend/testdata/dynamic/mysql/go/db.go +++ b/internal/endtoend/testdata/dynamic/mysql/go/db.go @@ -15,6 +15,9 @@ type DBTX interface { QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) QueryRowContext(context.Context, string, ...interface{}) *sql.Row } +type DynamicSql interface { + Sql() (string, []interface{}) +} func New(db DBTX) *Queries { return &Queries{db: db} diff --git a/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go index 94fba1cefe..c4570dcfaf 100644 --- a/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go @@ -8,7 +8,6 @@ package querytest import ( "context" "database/sql" - "strings" ) const selectUsers = `-- name: SelectUsers :many @@ -64,8 +63,8 @@ func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamic queryParams = append(queryParams, arg.Age) replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", "replaceText", 1) - queryParams = append(queryParams, v) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", replaceText) + queryParams = append(queryParams, args...) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { return nil, err @@ -115,8 +114,8 @@ func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynami queryParams = append(queryParams, arg.Status) replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", "replaceText", 1) - queryParams = append(queryParams, v) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", replaceText) + queryParams = append(queryParams, args...) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { return nil, err diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/go/db.go b/internal/endtoend/testdata/dynamic/pgx/v4/go/db.go new file mode 100644 index 0000000000..312ca8c296 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v4/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/go/models.go b/internal/endtoend/testdata/dynamic/pgx/v4/go/models.go new file mode 100644 index 0000000000..512742184a --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v4/go/models.go @@ -0,0 +1,25 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "database/sql" + + "github.com/jackc/pgtype" +) + +type Order struct { + ID int32 + Price pgtype.Numeric + UserID int32 +} + +type User struct { + ID int32 + FirstName string + LastName sql.NullString + Age int32 + JobStatus string +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go b/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go new file mode 100644 index 0000000000..e53f64b07c --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go @@ -0,0 +1,194 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const selectUser = `-- name: SelectUser :one +SELECT first_name, last_name FROM users WHERE /*DYNAMIC:dynamic*/$1 +` + +type SelectUserRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUser(ctx context.Context, dynamic DynamicSql) (SelectUserRow, error) { + replaceText, args := dynamic.ToSql(1) + query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + row := q.db.QueryRow(ctx, query, queryParams...) + var i SelectUserRow + err := row.Scan(&i.FirstName, &i.LastName) + return i, err +} + +const selectUsers = `-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > $1 +` + +type SelectUsersRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsers(ctx context.Context, age int32) ([]SelectUsersRow, error) { + rows, err := q.db.Query(ctx, selectUsers, age) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersRow + for rows.Next() { + var i SelectUsersRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic = `-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > $1 AND /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamicParams struct { + Age int32 + Dynamic DynamicSql +} + +type SelectUsersDynamicRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamicParams) ([]SelectUsersDynamicRow, error) { + queryParams := []interface{}{arg.Age} + curNumb := 2 + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicRow + for rows.Next() { + var i SelectUsersDynamicRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamic2Params struct { + Age int32 + Status string + Dynamic DynamicSql +} + +type SelectUsersDynamic2Row struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { + queryParams := []interface{}{arg.Age, arg.Status} + curNumb := 3 + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamic2Row + for rows.Next() { + var i SelectUsersDynamic2Row + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamicMulti = `-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +ORDER BY /*DYNAMIC:order*/$1 +` + +type SelectUsersDynamicMultiParams struct { + Age int32 + Status string + Dynamic DynamicSql + Order DynamicSql +} + +type SelectUsersDynamicMultiRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDynamicMultiParams) ([]SelectUsersDynamicMultiRow, error) { + queryParams := []interface{}{arg.Age, arg.Status} + curNumb := 3 + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + replaceText, args := arg.Order.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAllString(query, "/*DYNAMIC:order*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicMultiRow + for rows.Next() { + var i SelectUsersDynamicMultiRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/query.sql b/internal/endtoend/testdata/dynamic/pgx/v4/query.sql new file mode 100644 index 0000000000..281e23a395 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v4/query.sql @@ -0,0 +1,37 @@ +CREATE TABLE users ( + id int PRIMARY KEY, + first_name text NOT NULL, + last_name text, + age int NOT NULL, + job_status text NOT NULL +); + +CREATE TABLE orders ( + id int PRIMARY KEY, + price numeric NOT NULL, + user_id int NOT NULL +); + +-- name: SelectUser :one +SELECT first_name, last_name FROM users WHERE sqlc.dynamic('dynamic'); + +-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age); + +-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age) AND sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic') +ORDER BY sqlc.dynamic('order'); diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/sqlc.json b/internal/endtoend/testdata/dynamic/pgx/v4/sqlc.json new file mode 100644 index 0000000000..9403bd0279 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v4/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/go/db.go b/internal/endtoend/testdata/dynamic/pgx/v5/go/db.go new file mode 100644 index 0000000000..8a010ccc48 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v5/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/go/models.go b/internal/endtoend/testdata/dynamic/pgx/v5/go/models.go new file mode 100644 index 0000000000..d1c0243acb --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v5/go/models.go @@ -0,0 +1,23 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +type Order struct { + ID int32 + Price pgtype.Numeric + UserID int32 +} + +type User struct { + ID int32 + FirstName string + LastName pgtype.Text + Age int32 + JobStatus string +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go b/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go new file mode 100644 index 0000000000..5e9f4d224a --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go @@ -0,0 +1,177 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const selectUsers = `-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > $1 +` + +type SelectUsersRow struct { + FirstName string + LastName pgtype.Text +} + +func (q *Queries) SelectUsers(ctx context.Context, age int32) ([]SelectUsersRow, error) { + rows, err := q.db.Query(ctx, selectUsers, age) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersRow + for rows.Next() { + var i SelectUsersRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic = `-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > $1 AND /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamicParams struct { + Age int32 + Dynamic DynamicSql +} + +type SelectUsersDynamicRow struct { + FirstName string + LastName pgtype.Text +} + +func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamicParams) ([]SelectUsersDynamicRow, error) { + queryParams := []interface{}{arg.Age} + curNumb := 2 + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicRow + for rows.Next() { + var i SelectUsersDynamicRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamic2Params struct { + Age int32 + Status string + Dynamic DynamicSql +} + +type SelectUsersDynamic2Row struct { + FirstName string + LastName pgtype.Text +} + +func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { + queryParams := []interface{}{arg.Age, arg.Status} + curNumb := 3 + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamic2Row + for rows.Next() { + var i SelectUsersDynamic2Row + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamicMulti = `-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +ORDER BY /*DYNAMIC:order*/$1 +` + +type SelectUsersDynamicMultiParams struct { + Age int32 + Status string + Dynamic DynamicSql + Order DynamicSql +} + +type SelectUsersDynamicMultiRow struct { + FirstName string + LastName pgtype.Text +} + +func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDynamicMultiParams) ([]SelectUsersDynamicMultiRow, error) { + queryParams := []interface{}{arg.Age, arg.Status} + curNumb := 3 + replaceText, args := arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + replaceText, args := arg.Order.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAllString(query, "/*DYNAMIC:order*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicMultiRow + for rows.Next() { + var i SelectUsersDynamicMultiRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/query.sql b/internal/endtoend/testdata/dynamic/pgx/v5/query.sql new file mode 100644 index 0000000000..dca8604f7a --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v5/query.sql @@ -0,0 +1,34 @@ +CREATE TABLE users ( + id int PRIMARY KEY, + first_name text NOT NULL, + last_name text, + age int NOT NULL, + job_status text NOT NULL +); + +CREATE TABLE orders ( + id int PRIMARY KEY, + price numeric NOT NULL, + user_id int NOT NULL +); + +-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age); + +-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age) AND sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic') +ORDER BY sqlc.dynamic('order'); diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/sqlc.json b/internal/endtoend/testdata/dynamic/pgx/v5/sqlc.json new file mode 100644 index 0000000000..6645ccbd1b --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v5/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v5", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/db.go b/internal/endtoend/testdata/dynamic/stdlib/go/db.go index a457fb76b2..30473a9784 100644 --- a/internal/endtoend/testdata/dynamic/stdlib/go/db.go +++ b/internal/endtoend/testdata/dynamic/stdlib/go/db.go @@ -15,6 +15,9 @@ type DBTX interface { QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) QueryRowContext(context.Context, string, ...interface{}) *sql.Row } +type DynamicSql interface { + Sql(int) (string, []interface{}) +} func New(db DBTX) *Queries { return &Queries{db: db} diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go b/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go index 7c37cdb9f2..a4bf4ea284 100644 --- a/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go @@ -69,7 +69,7 @@ func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynami queryParams = append(queryParams, arg.Status) replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText, 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { @@ -114,7 +114,7 @@ func (q *Queries) SelectUsersDynamicA(ctx context.Context, arg SelectUsersDynami queryParams = append(queryParams, arg.Age) replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText, 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { @@ -159,7 +159,7 @@ func (q *Queries) SelectUsersDynamicB(ctx context.Context, arg SelectUsersDynami queryParams = append(queryParams, arg.Age) replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText, 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { @@ -212,11 +212,11 @@ func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDy queryParams = append(queryParams, arg.Status) replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText, 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) replaceText, args := arg.Order.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText, 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { @@ -269,11 +269,11 @@ func (q *Queries) SelectUsersDynamicMultiB(ctx context.Context, arg SelectUsersD queryParams = append(queryParams, arg.Status) replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText, 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) replaceText, args := arg.Order.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText, 1) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { diff --git a/internal/endtoend/testdata/process_plugin_sqlc_gen_json/exec.json b/internal/endtoend/testdata/process_plugin_sqlc_gen_json/exec.json index 2e996ca79d..572ba3e887 100644 --- a/internal/endtoend/testdata/process_plugin_sqlc_gen_json/exec.json +++ b/internal/endtoend/testdata/process_plugin_sqlc_gen_json/exec.json @@ -1,3 +1,4 @@ { + "process": "sqlc-gen-json", "contexts": ["base"] } diff --git a/internal/plugin/codegen.pb.go b/internal/plugin/codegen.pb.go index fe0a5ba358..05b770a70d 100644 --- a/internal/plugin/codegen.pb.go +++ b/internal/plugin/codegen.pb.go @@ -7,10 +7,11 @@ package plugin import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( @@ -954,13 +955,6 @@ func (x *Column) GetIsSqlcSlice() bool { } return false } -func (x *Column) GetIsSqlcDynamic() bool { - if x != nil { - return x.IsSqlcDynamic - } - return false -} - func (x *Column) GetEmbedTable() *Identifier { if x != nil { return x.EmbedTable From 0cc6a67302277850005c10445be5fcd4fc418b45 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sat, 14 Oct 2023 20:45:16 -0700 Subject: [PATCH 09/12] Fix --- internal/endtoend/testdata/dynamic/pgx/v4/go/db.go | 3 +++ .../endtoend/testdata/dynamic/pgx/v4/go/query.sql.go | 10 +++++----- internal/endtoend/testdata/dynamic/pgx/v5/go/db.go | 3 +++ .../endtoend/testdata/dynamic/pgx/v5/go/query.sql.go | 8 ++++---- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/go/db.go b/internal/endtoend/testdata/dynamic/pgx/v4/go/db.go index 312ca8c296..abb926d24d 100644 --- a/internal/endtoend/testdata/dynamic/pgx/v4/go/db.go +++ b/internal/endtoend/testdata/dynamic/pgx/v4/go/db.go @@ -16,6 +16,9 @@ type DBTX interface { Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row } +type DynamicSql interface { + ToSql(int) (string, []interface{}) +} func New(db DBTX) *Queries { return &Queries{db: db} diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go b/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go index e53f64b07c..744cea1588 100644 --- a/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go @@ -21,7 +21,7 @@ type SelectUserRow struct { func (q *Queries) SelectUser(ctx context.Context, dynamic DynamicSql) (SelectUserRow, error) { replaceText, args := dynamic.ToSql(1) - query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) row := q.db.QueryRow(ctx, query, queryParams...) var i SelectUserRow err := row.Scan(&i.FirstName, &i.LastName) @@ -76,7 +76,7 @@ func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamic curNumb := 2 replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.Query(ctx, query, queryParams...) if err != nil { @@ -121,7 +121,7 @@ func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynami curNumb := 3 replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.Query(ctx, query, queryParams...) if err != nil { @@ -168,11 +168,11 @@ func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDy curNumb := 3 replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) replaceText, args := arg.Order.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAllString(query, "/*DYNAMIC:order*/$1", replaceText) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.Query(ctx, query, queryParams...) if err != nil { diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/go/db.go b/internal/endtoend/testdata/dynamic/pgx/v5/go/db.go index 8a010ccc48..fcbd7c7201 100644 --- a/internal/endtoend/testdata/dynamic/pgx/v5/go/db.go +++ b/internal/endtoend/testdata/dynamic/pgx/v5/go/db.go @@ -16,6 +16,9 @@ type DBTX interface { Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row } +type DynamicSql interface { + ToSql(int) (string, []interface{}) +} func New(db DBTX) *Queries { return &Queries{db: db} diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go b/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go index 5e9f4d224a..6cc7a4e44e 100644 --- a/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go @@ -59,7 +59,7 @@ func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamic curNumb := 2 replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.Query(ctx, query, queryParams...) if err != nil { @@ -104,7 +104,7 @@ func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynami curNumb := 3 replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.Query(ctx, query, queryParams...) if err != nil { @@ -151,11 +151,11 @@ func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDy curNumb := 3 replaceText, args := arg.Dynamic.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAllString(query, "/*DYNAMIC:dynamic*/$1", replaceText) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) replaceText, args := arg.Order.ToSql(curNumb) curNumb += len(args) - query = strings.ReplaceAllString(query, "/*DYNAMIC:order*/$1", replaceText) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) queryParams = append(queryParams, args...) rows, err := q.db.Query(ctx, query, queryParams...) if err != nil { From 720ffda839a160f0b4f6c5452d91c1508e43edd7 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sun, 15 Oct 2023 16:23:48 -0700 Subject: [PATCH 10/12] Updates to match function with interface definition --- internal/codegen/golang/templates/stdlib/dbCode.tmpl | 2 +- internal/endtoend/testdata/dynamic/mysql/go/db.go | 2 +- internal/endtoend/testdata/dynamic/stdlib/go/db.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl index d8cb756f17..f7b61e7fe2 100644 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/dbCode.tmpl @@ -8,7 +8,7 @@ type DBTX interface { {{- if hasDynamic }} type DynamicSql interface { - Sql({{ if dollar}}int{{ end }}) (string, []interface{}) + ToSql({{ if dollar}}int{{ end }}) (string, []interface{}) } {{- end}} diff --git a/internal/endtoend/testdata/dynamic/mysql/go/db.go b/internal/endtoend/testdata/dynamic/mysql/go/db.go index 5fbccf5c37..43bb13795a 100644 --- a/internal/endtoend/testdata/dynamic/mysql/go/db.go +++ b/internal/endtoend/testdata/dynamic/mysql/go/db.go @@ -16,7 +16,7 @@ type DBTX interface { QueryRowContext(context.Context, string, ...interface{}) *sql.Row } type DynamicSql interface { - Sql() (string, []interface{}) + ToSql() (string, []interface{}) } func New(db DBTX) *Queries { diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/db.go b/internal/endtoend/testdata/dynamic/stdlib/go/db.go index 30473a9784..c473940b6e 100644 --- a/internal/endtoend/testdata/dynamic/stdlib/go/db.go +++ b/internal/endtoend/testdata/dynamic/stdlib/go/db.go @@ -16,7 +16,7 @@ type DBTX interface { QueryRowContext(context.Context, string, ...interface{}) *sql.Row } type DynamicSql interface { - Sql(int) (string, []interface{}) + ToSql(int) (string, []interface{}) } func New(db DBTX) *Queries { From 3bfe6a4e6a166c7736f5d7a38b07fb13bdd1a04a Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sun, 15 Oct 2023 23:05:41 -0700 Subject: [PATCH 11/12] Template Fixes to get it working --- internal/codegen/golang/imports.go | 6 ++++- .../golang/templates/pgx/queryCode.tmpl | 9 ++++--- .../golang/templates/stdlib/queryCode.tmpl | 8 ++++-- .../testdata/dynamic/mysql/go/query.sql.go | 9 +++++-- .../testdata/dynamic/pgx/v4/go/query.sql.go | 22 +++++++++++----- .../testdata/dynamic/pgx/v5/go/query.sql.go | 18 ++++++++++--- .../testdata/dynamic/stdlib/go/query.sql.go | 25 +++++++++++++------ 7 files changed, 72 insertions(+), 25 deletions(-) diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 381ec0ffe9..2793472160 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -288,10 +288,14 @@ func sortedImports(std map[string]struct{}, pkg map[ImportSpec]struct{}) fileImp func (i *importer) queryImports(filename string) fileImports { var gq []Query anyNonCopyFrom := false + useStrings := false for _, query := range i.Queries { if usesBatch([]Query{query}) { continue } + if query.Arg.HasSqlcDynamic() { + useStrings = true + } if query.SourceName == filename { gq = append(gq, query) if query.Cmd != metadata.CmdCopyFrom { @@ -384,7 +388,7 @@ func (i *importer) queryImports(filename string) fileImports { } sqlpkg := parseDriver(i.Options.SqlPackage) - if sqlcSliceScan() { + if useStrings || sqlcSliceScan() { std["strings"] = struct{}{} } if sliceScan() && !sqlpkg.IsPGX() { diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index 71599f5a3b..1e94007862 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -3,17 +3,20 @@ queryParams := []interface{}{ {{.Arg.Params}} } {{- $arg := .Arg }} curNumb := {{ $arg.SqlcDynamic }} + query := {{.ConstantName}} + var replaceText string + var args []interface{} {{- range .Arg.Struct.Fields }} {{- if .HasSqlcDynamic }} - replaceText, args := {{$arg.VariableForField .}}.ToSql(curNumb) + replaceText, args = {{$arg.VariableForField .}}.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Column.Name}}*/$1", replaceText) queryParams = append(queryParams, args...) {{- end}} {{- end}} {{- else}} - replaceText, args := {{.Arg.Column.Name}}.ToSql(1) - query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Arg.Column.Name}}*/$1", replaceText) + replaceText, queryParams := {{.Arg.Column.Name}}.ToSql(1) + query := strings.ReplaceAll({{.ConstantName}}, "/*DYNAMIC:{{.Arg.Column.Name}}*/$1", replaceText) {{- end}} {{- end}} diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index 893fd595b7..3faf7ba16c 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -116,6 +116,10 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} curNumb := {{ .Arg.SqlcDynamic }} {{- end }} {{- if .Arg.Struct }} + {{- if .Arg.HasSqlcDynamic }} + var replaceText string + var args []interface{} + {{- end }} {{- $arg := .Arg }} {{- range .Arg.Struct.Fields }} {{- if .HasSqlcSlice }} @@ -131,7 +135,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", "NULL", 1) } {{- else if .HasSqlcDynamic }} - replaceText, args := {{$arg.VariableForField .}}.ToSql(curNumb) + replaceText, args = {{$arg.VariableForField .}}.ToSql({{ if dollar }}curNumb{{ end }}) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Column.Name}}*/{{- if dollar }}$1{{ else }}?{{ end }}", replaceText) queryParams = append(queryParams, args...) @@ -142,7 +146,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{- else }} {{- if .Arg.HasSqlcDynamic }} var replaceText string - replaceText, queryParams = {{ .Arg.Column.Name}}.ToSql(curNumb) + replaceText, queryParams = {{ .Arg.Column.Name}}.ToSql({{ if dollar }}curNumb{{ end }}) query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Arg.Column.Name}}*/?", "replaceText") {{- else }} {{- /* Single argument parameter to this goroutine (they are not packed diff --git a/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go index c4570dcfaf..8b27c51f8c 100644 --- a/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go @@ -8,6 +8,7 @@ package querytest import ( "context" "database/sql" + "strings" ) const selectUsers = `-- name: SelectUsers :many @@ -60,8 +61,10 @@ func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamic query := selectUsersDynamic var queryParams []interface{} curNumb := 2 + var replaceText string + var args []interface{} queryParams = append(queryParams, arg.Age) - replaceText, args := arg.Dynamic.ToSql(curNumb) + replaceText, args = arg.Dynamic.ToSql() curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", replaceText) queryParams = append(queryParams, args...) @@ -110,9 +113,11 @@ func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynami query := selectUsersDynamic2 var queryParams []interface{} curNumb := 3 + var replaceText string + var args []interface{} queryParams = append(queryParams, arg.Age) queryParams = append(queryParams, arg.Status) - replaceText, args := arg.Dynamic.ToSql(curNumb) + replaceText, args = arg.Dynamic.ToSql() curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", replaceText) queryParams = append(queryParams, args...) diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go b/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go index 744cea1588..65c8566795 100644 --- a/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go @@ -8,6 +8,7 @@ package querytest import ( "context" "database/sql" + "strings" ) const selectUser = `-- name: SelectUser :one @@ -20,8 +21,8 @@ type SelectUserRow struct { } func (q *Queries) SelectUser(ctx context.Context, dynamic DynamicSql) (SelectUserRow, error) { - replaceText, args := dynamic.ToSql(1) - query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + replaceText, queryParams := dynamic.ToSql(1) + query := strings.ReplaceAll(selectUser, "/*DYNAMIC:dynamic*/$1", replaceText) row := q.db.QueryRow(ctx, query, queryParams...) var i SelectUserRow err := row.Scan(&i.FirstName, &i.LastName) @@ -74,7 +75,10 @@ type SelectUsersDynamicRow struct { func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamicParams) ([]SelectUsersDynamicRow, error) { queryParams := []interface{}{arg.Age} curNumb := 2 - replaceText, args := arg.Dynamic.ToSql(curNumb) + query := selectUsersDynamic + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) @@ -119,7 +123,10 @@ type SelectUsersDynamic2Row struct { func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { queryParams := []interface{}{arg.Age, arg.Status} curNumb := 3 - replaceText, args := arg.Dynamic.ToSql(curNumb) + query := selectUsersDynamic2 + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) @@ -166,11 +173,14 @@ type SelectUsersDynamicMultiRow struct { func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDynamicMultiParams) ([]SelectUsersDynamicMultiRow, error) { queryParams := []interface{}{arg.Age, arg.Status} curNumb := 3 - replaceText, args := arg.Dynamic.ToSql(curNumb) + query := selectUsersDynamicMulti + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) - replaceText, args := arg.Order.ToSql(curNumb) + replaceText, args = arg.Order.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) queryParams = append(queryParams, args...) diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go b/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go index 6cc7a4e44e..89cbd6e8a0 100644 --- a/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go @@ -7,6 +7,7 @@ package querytest import ( "context" + "strings" "github.com/jackc/pgx/v5/pgtype" ) @@ -57,7 +58,10 @@ type SelectUsersDynamicRow struct { func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamicParams) ([]SelectUsersDynamicRow, error) { queryParams := []interface{}{arg.Age} curNumb := 2 - replaceText, args := arg.Dynamic.ToSql(curNumb) + query := selectUsersDynamic + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) @@ -102,7 +106,10 @@ type SelectUsersDynamic2Row struct { func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { queryParams := []interface{}{arg.Age, arg.Status} curNumb := 3 - replaceText, args := arg.Dynamic.ToSql(curNumb) + query := selectUsersDynamic2 + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) @@ -149,11 +156,14 @@ type SelectUsersDynamicMultiRow struct { func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDynamicMultiParams) ([]SelectUsersDynamicMultiRow, error) { queryParams := []interface{}{arg.Age, arg.Status} curNumb := 3 - replaceText, args := arg.Dynamic.ToSql(curNumb) + query := selectUsersDynamicMulti + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) - replaceText, args := arg.Order.ToSql(curNumb) + replaceText, args = arg.Order.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) queryParams = append(queryParams, args...) diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go b/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go index a4bf4ea284..3e97375afb 100644 --- a/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go @@ -8,6 +8,7 @@ package querytest import ( "context" "database/sql" + "strings" ) const selectUsers = `-- name: SelectUsers :many @@ -65,9 +66,11 @@ func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynami query := selectUsersDynamic2 var queryParams []interface{} curNumb := 3 + var replaceText string + var args []interface{} queryParams = append(queryParams, arg.Age) queryParams = append(queryParams, arg.Status) - replaceText, args := arg.Dynamic.ToSql(curNumb) + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) @@ -111,8 +114,10 @@ func (q *Queries) SelectUsersDynamicA(ctx context.Context, arg SelectUsersDynami query := selectUsersDynamicA var queryParams []interface{} curNumb := 2 + var replaceText string + var args []interface{} queryParams = append(queryParams, arg.Age) - replaceText, args := arg.Dynamic.ToSql(curNumb) + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) @@ -156,8 +161,10 @@ func (q *Queries) SelectUsersDynamicB(ctx context.Context, arg SelectUsersDynami query := selectUsersDynamicB var queryParams []interface{} curNumb := 2 + var replaceText string + var args []interface{} queryParams = append(queryParams, arg.Age) - replaceText, args := arg.Dynamic.ToSql(curNumb) + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) @@ -208,13 +215,15 @@ func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDy query := selectUsersDynamicMulti var queryParams []interface{} curNumb := 3 + var replaceText string + var args []interface{} queryParams = append(queryParams, arg.Age) queryParams = append(queryParams, arg.Status) - replaceText, args := arg.Dynamic.ToSql(curNumb) + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) - replaceText, args := arg.Order.ToSql(curNumb) + replaceText, args = arg.Order.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) queryParams = append(queryParams, args...) @@ -265,13 +274,15 @@ func (q *Queries) SelectUsersDynamicMultiB(ctx context.Context, arg SelectUsersD query := selectUsersDynamicMultiB var queryParams []interface{} curNumb := 3 + var replaceText string + var args []interface{} queryParams = append(queryParams, arg.Age) queryParams = append(queryParams, arg.Status) - replaceText, args := arg.Dynamic.ToSql(curNumb) + replaceText, args = arg.Dynamic.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) queryParams = append(queryParams, args...) - replaceText, args := arg.Order.ToSql(curNumb) + replaceText, args = arg.Order.ToSql(curNumb) curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) queryParams = append(queryParams, args...) From b055f28e4990053fcde084a0a244cbad5fc6308f Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Mon, 16 Oct 2023 23:22:46 -0700 Subject: [PATCH 12/12] Changing mysql to show ... --- .../endtoend/testdata/dynamic/mysql/go/query.sql.go | 12 ++++++------ internal/endtoend/testdata/dynamic/mysql/query.sql | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go index 8b27c51f8c..5516e45400 100644 --- a/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go @@ -93,15 +93,15 @@ func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamic const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many SELECT first_name, last_name FROM users -WHERE age > ? AND - job_status = ? AND - /*DYNAMIC:dynamic*/? +WHERE /*DYNAMIC:dynamic*/? AND + age > ? AND + job_status = ? ` type SelectUsersDynamic2Params struct { + Dynamic DynamicSql Age int32 Status string - Dynamic DynamicSql } type SelectUsersDynamic2Row struct { @@ -115,12 +115,12 @@ func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynami curNumb := 3 var replaceText string var args []interface{} - queryParams = append(queryParams, arg.Age) - queryParams = append(queryParams, arg.Status) replaceText, args = arg.Dynamic.ToSql() curNumb += len(args) query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", replaceText) queryParams = append(queryParams, args...) + queryParams = append(queryParams, arg.Age) + queryParams = append(queryParams, arg.Status) rows, err := q.db.QueryContext(ctx, query, queryParams...) if err != nil { return nil, err diff --git a/internal/endtoend/testdata/dynamic/mysql/query.sql b/internal/endtoend/testdata/dynamic/mysql/query.sql index 6a90908f81..8c3d94dff3 100644 --- a/internal/endtoend/testdata/dynamic/mysql/query.sql +++ b/internal/endtoend/testdata/dynamic/mysql/query.sql @@ -20,6 +20,6 @@ SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age) AND sqlc.dynam -- name: SelectUsersDynamic2 :many SELECT first_name, last_name FROM users -WHERE age > sqlc.arg(age) AND - job_status = sqlc.arg(status) AND - sqlc.dynamic('dynamic'); +WHERE sqlc.dynamic('dynamic') AND + age > sqlc.arg(age) AND + job_status = sqlc.arg(status) ;