diff --git a/internal/cmd/shim.go b/internal/cmd/shim.go index 4453d47a3c..3f70a34a89 100644 --- a/internal/cmd/shim.go +++ b/internal/cmd/shim.go @@ -228,17 +228,18 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column { l = *c.Length } out := &plugin.Column{ - Name: c.Name, - OriginalName: c.OriginalName, - Comment: c.Comment, - NotNull: c.NotNull, - Unsigned: c.Unsigned, - IsArray: c.IsArray, - ArrayDims: int32(c.ArrayDims), - Length: int32(l), - IsNamedParam: c.IsNamedParam, - IsFuncCall: c.IsFuncCall, - IsSqlcSlice: c.IsSqlcSlice, + Name: c.Name, + OriginalName: c.OriginalName, + Comment: c.Comment, + NotNull: c.NotNull, + Unsigned: c.Unsigned, + IsArray: c.IsArray, + ArrayDims: int32(c.ArrayDims), + Length: int32(l), + IsNamedParam: c.IsNamedParam, + IsFuncCall: c.IsFuncCall, + IsSqlcSlice: c.IsSqlcSlice, + IsSqlcDynamic: c.IsSqlcDynamic, } if c.Type != nil { diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index 2a8d9ccdfc..87dc40d537 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -28,6 +28,10 @@ func (gf Field) HasSqlcSlice() bool { return gf.Column.IsSqlcSlice } +func (gf Field) HasSqlcDynamic() bool { + return gf.Column.IsSqlcDynamic +} + func TagsToString(tags map[string]string) string { if len(tags) == 0 { return "" diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index e5762374ba..cfbfaf1930 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -11,6 +11,7 @@ import ( "text/template" "github.com/sqlc-dev/sqlc/internal/codegen/sdk" + "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/metadata" "github.com/sqlc-dev/sqlc/internal/plugin" ) @@ -38,6 +39,7 @@ type tmplCtx struct { EmitAllEnumValues bool UsesCopyFrom bool UsesBatch bool + HasSqlcDynamic bool BuildTags string } @@ -130,6 +132,13 @@ func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs [ Enums: enums, Structs: structs, } + var hasDynamic bool + for _, q := range queries { + if q.Arg.HasSqlcDynamic() { + hasDynamic = true + break + } + } tctx := tmplCtx{ EmitInterface: options.EmitInterface, @@ -148,8 +157,8 @@ func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs [ Package: options.Package, Enums: enums, Structs: structs, - SqlcVersion: req.SqlcVersion, BuildTags: options.BuildTags, + SqlcVersion: req.SqlcVersion, } if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != SQLDriverGoSQLDriverMySQL { @@ -180,6 +189,12 @@ func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs [ "emitPreparedQueries": tctx.codegenEmitPreparedQueries, "queryMethod": tctx.codegenQueryMethod, "queryRetval": tctx.codegenQueryRetval, + "dollar": func() bool { + return req.Settings.Engine == string(config.EnginePostgreSQL) + }, + "hasDynamic": func() bool { + return hasDynamic + }, } tmpl := template.Must( diff --git a/internal/codegen/golang/go_type.go b/internal/codegen/golang/go_type.go index 2b5f75bcd4..6edcb4c72f 100644 --- a/internal/codegen/golang/go_type.go +++ b/internal/codegen/golang/go_type.go @@ -33,6 +33,9 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, co func goType(req *plugin.CodeGenRequest, options *opts, col *plugin.Column) string { // Check if the column's type has been overridden + if col.IsSqlcDynamic { + return "DynamicSql" + } for _, oride := range req.Settings.Overrides { if oride.GoType.TypeName == "" { continue diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 381ec0ffe9..2793472160 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -288,10 +288,14 @@ func sortedImports(std map[string]struct{}, pkg map[ImportSpec]struct{}) fileImp func (i *importer) queryImports(filename string) fileImports { var gq []Query anyNonCopyFrom := false + useStrings := false for _, query := range i.Queries { if usesBatch([]Query{query}) { continue } + if query.Arg.HasSqlcDynamic() { + useStrings = true + } if query.SourceName == filename { gq = append(gq, query) if query.Cmd != metadata.CmdCopyFrom { @@ -384,7 +388,7 @@ func (i *importer) queryImports(filename string) fileImports { } sqlpkg := parseDriver(i.Options.SqlPackage) - if sqlcSliceScan() { + if useStrings || sqlcSliceScan() { std["strings"] = struct{}{} } if sliceScan() && !sqlpkg.IsPGX() { diff --git a/internal/codegen/golang/mysql_type.go b/internal/codegen/golang/mysql_type.go index c89d8ff3c7..0406924791 100644 --- a/internal/codegen/golang/mysql_type.go +++ b/internal/codegen/golang/mysql_type.go @@ -12,6 +12,9 @@ func mysqlType(req *plugin.CodeGenRequest, col *plugin.Column) string { columnType := sdk.DataType(col.Type) notNull := col.NotNull || col.IsArray unsigned := col.Unsigned + if col.IsSqlcDynamic { + return "DynamicSql" + } switch columnType { diff --git a/internal/codegen/golang/postgresql_type.go b/internal/codegen/golang/postgresql_type.go index 815befad30..6965b5ca6f 100644 --- a/internal/codegen/golang/postgresql_type.go +++ b/internal/codegen/golang/postgresql_type.go @@ -34,6 +34,9 @@ func parseIdentifierString(name string) (*plugin.Identifier, error) { } func postgresType(req *plugin.CodeGenRequest, options *opts, col *plugin.Column) string { + if col.IsSqlcDynamic { + return "DynamicSql" + } columnType := sdk.DataType(col.Type) notNull := col.NotNull || col.IsArray driver := parseDriver(options.SqlPackage) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index b82178686c..8003a053d8 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -128,14 +128,16 @@ func (v QueryValue) Params() string { } var out []string if v.Struct == nil { - if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() { + if v.Column.IsSqlcDynamic { + } else if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() { out = append(out, "pq.Array("+escape(v.Name)+")") } else { out = append(out, escape(v.Name)) } } else { for _, f := range v.Struct.Fields { - if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() { + if f.HasSqlcDynamic() { + } else if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() { out = append(out, "pq.Array("+escape(v.VariableForField(f))+")") } else { out = append(out, escape(v.VariableForField(f))) @@ -188,6 +190,32 @@ func (v QueryValue) HasSqlcSlices() bool { } return false } +func (v QueryValue) HasSqlcDynamic() bool { + if v.Struct == nil { + if v.Column != nil && v.Column.IsSqlcDynamic { + return true + } + return false + } + for _, v := range v.Struct.Fields { + if v.Column.IsSqlcDynamic { + return true + } + } + return false +} +func (v QueryValue) SqlcDynamic() int { + var count int = 1 + if v.Struct == nil { + return 1 + } + for _, v := range v.Struct.Fields { + if !v.Column.IsSqlcDynamic { + count++ + } + } + return count +} func (v QueryValue) Scan() string { var out []string diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl index 236554d9f2..5c3b5b55ea 100644 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ b/internal/codegen/golang/templates/pgx/dbCode.tmpl @@ -12,6 +12,11 @@ type DBTX interface { {{- end }} } +{{- if hasDynamic }} +type DynamicSql interface { + ToSql(int) (string, []interface{}) +} +{{- end}} {{ if .EmitMethodsWithDBArgument}} func New() *Queries { return &Queries{} diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index 18de5db2ba..1e94007862 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -1,3 +1,25 @@ +{{define "preexec"}} + {{- if .Arg.Struct }} + queryParams := []interface{}{ {{.Arg.Params}} } + {{- $arg := .Arg }} + curNumb := {{ $arg.SqlcDynamic }} + query := {{.ConstantName}} + var replaceText string + var args []interface{} + {{- range .Arg.Struct.Fields }} + {{- if .HasSqlcDynamic }} + replaceText, args = {{$arg.VariableForField .}}.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Column.Name}}*/$1", replaceText) + queryParams = append(queryParams, args...) + {{- end}} + {{- end}} + {{- else}} + replaceText, queryParams := {{.Arg.Column.Name}}.ToSql(1) + query := strings.ReplaceAll({{.ConstantName}}, "/*DYNAMIC:{{.Arg.Column.Name}}*/$1", replaceText) + {{- end}} +{{- end}} + {{define "queryCodePgx"}} {{range .GoQueries}} {{if $.OutputQuery .SourceName}} @@ -28,10 +50,20 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + row := db.QueryRow(ctx, query, queryParams...) + {{- else}} row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + row := q.db.QueryRow(ctx, query, queryParams...) + {{- else}} row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- end}} {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} var {{.Ret.Name}} {{.Ret.Type}} @@ -46,10 +78,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + rows, err := db.Query(ctx, query, queryParams...) + {{- else}} rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + rows, err := q.db.Query(ctx, query, queryParams...) + {{- else}} rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- end}} if err != nil { return nil, err @@ -79,10 +121,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret. {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + _, err := db.Exec(ctx, query, queryParams...) + {{- else}} _, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + _, err := q.db.Exec(ctx, query, queryParams...) + {{- else}} _, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- end}} return err } @@ -93,10 +145,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { {{end -}} {{if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + result, err := db.Exec(ctx, query, queryParams...) + {{- else}} result, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + result, err := q.db.Exec(ctx, query, queryParams...) + {{- else}} result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- end}} if err != nil { return 0, err @@ -110,10 +172,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + return db.Exec(ctx, query, queryParams...) + {{- else}} return db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) { + {{- if .Arg.HasSqlcDynamic }} + {{- template "preexec" .}} + return q.db.Exec(ctx, query, queryParams...) + {{- else}} return q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end}} {{- end}} } {{end}} @@ -122,3 +194,4 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.Co {{end}} {{end}} {{end}} + diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl index 7433d522f6..f7b61e7fe2 100644 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/dbCode.tmpl @@ -6,6 +6,12 @@ type DBTX interface { QueryRowContext(context.Context, string, ...interface{}) *sql.Row } +{{- if hasDynamic }} +type DynamicSql interface { + ToSql({{ if dollar}}int{{ end }}) (string, []interface{}) +} +{{- end}} + {{ if .EmitMethodsWithDBArgument}} func New() *Queries { return &Queries{} diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index cf56000ec6..3faf7ba16c 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -109,10 +109,17 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{end}} {{define "queryCodeStdExec"}} - {{- if .Arg.HasSqlcSlices }} + {{- if or .Arg.HasSqlcSlices .Arg.HasSqlcDynamic }} query := {{.ConstantName}} var queryParams []interface{} + {{- if .Arg.HasSqlcDynamic }} + curNumb := {{ .Arg.SqlcDynamic }} + {{- end }} {{- if .Arg.Struct }} + {{- if .Arg.HasSqlcDynamic }} + var replaceText string + var args []interface{} + {{- end }} {{- $arg := .Arg }} {{- range .Arg.Struct.Fields }} {{- if .HasSqlcSlice }} @@ -121,14 +128,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} queryParams = append(queryParams, v) } query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", strings.Repeat(",?", len({{$arg.VariableForField .}}))[1:], 1) + {{- if .HasSqlcDynamic }} + curNumb += len({{$arg.VariableForField .}}) + {{- end}} } else { query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", "NULL", 1) } + {{- else if .HasSqlcDynamic }} + replaceText, args = {{$arg.VariableForField .}}.ToSql({{ if dollar }}curNumb{{ end }}) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Column.Name}}*/{{- if dollar }}$1{{ else }}?{{ end }}", replaceText) + queryParams = append(queryParams, args...) {{- else }} queryParams = append(queryParams, {{$arg.VariableForField .}}) {{- end }} {{- end }} {{- else }} + {{- if .Arg.HasSqlcDynamic }} + var replaceText string + replaceText, queryParams = {{ .Arg.Column.Name}}.ToSql({{ if dollar }}curNumb{{ end }}) + query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Arg.Column.Name}}*/?", "replaceText") + {{- else }} {{- /* Single argument parameter to this goroutine (they are not packed in a struct), because .Arg.HasSqlcSlices further up above was true, this section is 100% a slice (impossible to get here otherwise). @@ -141,6 +161,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} } else { query = strings.Replace(query, "/*SLICE:{{.Arg.Column.Name}}*/?", "NULL", 1) } + {{- end }} {{- end }} {{- if emitPreparedQueries }} {{ queryRetval . }} {{ queryMethod . }}(ctx, nil, query, queryParams...) diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 739cd07993..63de0dc32e 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -160,7 +160,11 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } - params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds) + params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams) + if err := check(err); err != nil { + return nil, err + } + err = c.resolveCatalogEmbeds(qc, rvs, embeds) if err := check(err); err != nil { return nil, err } diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index ca38199b9d..dcce502442 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -10,7 +10,7 @@ import ( func findParameters(root ast.Node) ([]paramRef, []error) { refs := make([]paramRef, 0) errors := make([]error, 0) - v := paramSearch{seen: make(map[int]struct{}), refs: &refs, errs: &errors} + v := paramSearch{seen: make(map[int]struct{}), refs: &refs, errs: &errors, rvs: &[]*ast.RangeVar{}} astutils.Walk(v, root) if len(*v.errs) > 0 { return refs, *v.errs @@ -21,6 +21,7 @@ func findParameters(root ast.Node) ([]paramRef, []error) { type paramRef struct { parent ast.Node + rvs []*ast.RangeVar rv *ast.RangeVar ref *ast.ParamRef name string // Named parameter support @@ -30,6 +31,7 @@ type paramSearch struct { parent ast.Node rangeVar *ast.RangeVar refs *[]paramRef + rvs *[]*ast.RangeVar seen map[int]struct{} errs *[]error @@ -53,6 +55,7 @@ func (l *limitOffset) Pos() int { } func (p paramSearch) Visit(node ast.Node) astutils.Visitor { + var reset bool switch n := node.(type) { case *ast.A_Expr: @@ -65,6 +68,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { p.parent = n.FuncCall case *ast.DeleteStmt: + reset = true if n.LimitCount != nil { p.limitCount = n.LimitCount } @@ -73,7 +77,13 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { p.parent = node case *ast.InsertStmt: + reset = true + rvs := *p.rvs + if n.Relation != nil { + rvs = append(rvs, n.Relation) + } if s, ok := n.SelectStmt.(*ast.SelectStmt); ok { + rvs = append(rvs, toTables(s.FromClause)...) for i, item := range s.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -87,7 +97,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: rvs}) p.seen[ref.Location] = struct{}{} } for _, item := range s.ValuesLists.Items { @@ -104,13 +114,16 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: rvs}) p.seen[ref.Location] = struct{}{} } } } case *ast.UpdateStmt: + reset = true + rvs := append(*p.rvs, toTables(n.FromClause)...) + rvs = append(rvs, toTables(n.Relations)...) for _, item := range n.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -125,7 +138,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { if !ok { continue } - *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv}) + *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, rvs: rvs}) } p.seen[ref.Location] = struct{}{} } @@ -134,12 +147,16 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { } case *ast.RangeVar: + if n != nil { + *p.rvs = append(*p.rvs, n) + } p.rangeVar = n case *ast.ResTarget: p.parent = node case *ast.SelectStmt: + reset = true if n.LimitCount != nil { p.limitCount = n.LimitCount } @@ -186,7 +203,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { } if set { - *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar}) + *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar, rvs: *p.rvs}) p.seen[n.Location] = struct{}{} } return nil @@ -210,5 +227,20 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { p.Visit(n.Expr) } } + if reset { + rvs := *p.rvs + return paramSearch{seen: p.seen, refs: p.refs, errs: p.errs, rvs: &rvs, parent: p.parent, rangeVar: p.rangeVar, limitCount: p.limitCount, limitOffset: p.limitOffset} + } return p } + +func toTables(tbl *ast.List) []*ast.RangeVar { + tables := make([]*ast.RangeVar, len(tbl.Items)) + for _, t := range tbl.Items { + item, ok := t.(*ast.RangeVar) + if ok && item != nil { + tables = append(tables, item) + } + } + return tables +} diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 53e3043c7d..2d377368de 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -77,11 +77,10 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, } expanded := anlys.Query - // If the query string was edited, make sure the syntax is valid if expanded != rawSQL { if _, err := c.parser.Parse(strings.NewReader(expanded)); err != nil { - return nil, fmt.Errorf("edited query syntax is invalid: %w", err) + return nil, fmt.Errorf("edited query syntax is invalid: %w - %s", err, expanded) } } diff --git a/internal/compiler/query.go b/internal/compiler/query.go index 117cf44813..95aa903970 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -34,7 +34,8 @@ type Column struct { Type *ast.TypeName EmbedTable *ast.TableName - IsSqlcSlice bool // is this sqlc.slice() + IsSqlcSlice bool // is this sqlc.slice() + IsSqlcDynamic bool // is this sqlc.dynamic() skipTableRequiredCheck bool } diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 4624c5a45d..6cf04093bf 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -21,7 +21,7 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { +func (comp *Compiler) resolveCatalogEmbeds(qc *QueryCatalog, rvs []*ast.RangeVar, embeds rewrite.EmbedSet) error { c := comp.catalog aliasMap := map[string]*ast.TableName{} @@ -56,7 +56,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } fqn, err := ParseTableName(rv) if err != nil { - return nil, err + return err } if _, found := aliasMap[fqn.Name]; found { continue @@ -65,13 +65,13 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if err != nil { // If the table name doesn't exist, fisrt check if it's a CTE if _, qcerr := qc.GetTable(fqn); qcerr != nil { - return nil, err + return err } continue } err = indexTable(table) if err != nil { - return nil, err + return err } if rv.Alias != nil { aliasMap[*rv.Alias.Aliasname] = fqn @@ -91,11 +91,84 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } - return nil, fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) + return fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) } + return nil +} +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet) ([]Parameter, error) { + c := comp.catalog + + // resolve a table for an embed var a []Parameter for _, ref := range args { + if ref.ref.IsSqlcDynamic { + defaultP := named.NewInferredParam("offset", true) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) + a = append(a, Parameter{ + Number: ref.ref.Number, + Column: &Column{ + Name: p.Name(), + DataType: "DynamicSql", + NotNull: p.NotNull(), + IsNamedParam: isNamed, + IsSqlcDynamic: true, + }, + }) + continue + } + aliasMap := map[string]*ast.TableName{} + // TODO: Deprecate defaultTable + var defaultTable *ast.TableName + var tables []*ast.TableName + typeMap := map[string]map[string]map[string]*catalog.Column{} + indexTable := func(table catalog.Table) error { + tables = append(tables, table.Rel) + if defaultTable == nil { + defaultTable = table.Rel + } + schema := table.Rel.Schema + if schema == "" { + schema = c.DefaultSchema + } + if _, exists := typeMap[schema]; !exists { + typeMap[schema] = map[string]map[string]*catalog.Column{} + } + typeMap[schema][table.Rel.Name] = map[string]*catalog.Column{} + for _, c := range table.Columns { + cc := c + typeMap[schema][table.Rel.Name][c.Name] = cc + } + return nil + } + for _, rv := range ref.rvs { + if rv == nil || rv.Relname == nil { + continue + } + fqn, err := ParseTableName(rv) + if err != nil { + return nil, err + } + if _, found := aliasMap[fqn.Name]; found { + continue + } + table, err := c.GetTable(fqn) + if err != nil { + // If the table name doesn't exist, fisrt check if it's a CTE + if _, qcerr := qc.GetTable(fqn); qcerr != nil { + return nil, err + } + continue + } + err = indexTable(table) + if err != nil { + return nil, err + } + if rv.Alias != nil { + aliasMap[*rv.Alias.Aliasname] = fqn + } + } + switch n := ref.parent.(type) { case *limitOffset: @@ -150,11 +223,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: dataType, - IsNamedParam: isNamed, - NotNull: p.NotNull(), - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + DataType: dataType, + IsNamedParam: isNamed, + NotNull: p.NotNull(), + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) continue @@ -213,17 +287,18 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Length: c.Length, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Length: c.Length, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } @@ -280,15 +355,16 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: namePrefix + p.Name(), - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: namePrefix + p.Name(), + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } @@ -353,11 +429,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: "any", - IsNamedParam: isNamed, - NotNull: p.NotNull(), - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + DataType: "any", + IsNamedParam: isNamed, + NotNull: p.NotNull(), + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) continue @@ -394,11 +471,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - DataType: dataType(paramType), - NotNull: p.NotNull(), - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + DataType: dataType(paramType), + NotNull: p.NotNull(), + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } @@ -456,17 +534,18 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: &ast.TableName{Schema: schema, Name: rel}, - Length: c.Length, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: &ast.TableName{Schema: schema, Name: rel}, + Length: c.Length, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } else { @@ -567,16 +646,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, a = append(a, Parameter{ Number: number, Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: c.IsNotNull, - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: table, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: c.IsNotNull, + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + IsSqlcDynamic: p.IsSqlcDynamic(), }, }) } diff --git a/internal/endtoend/testdata/dynamic/mysql/go/db.go b/internal/endtoend/testdata/dynamic/mysql/go/db.go new file mode 100644 index 0000000000..43bb13795a --- /dev/null +++ b/internal/endtoend/testdata/dynamic/mysql/go/db.go @@ -0,0 +1,34 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} +type DynamicSql interface { + ToSql() (string, []interface{}) +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/dynamic/mysql/go/models.go b/internal/endtoend/testdata/dynamic/mysql/go/models.go new file mode 100644 index 0000000000..b5f5c9b7ed --- /dev/null +++ b/internal/endtoend/testdata/dynamic/mysql/go/models.go @@ -0,0 +1,23 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "database/sql" +) + +type Order struct { + ID int32 + Price string + UserID int32 +} + +type User struct { + ID int32 + FirstName string + LastName sql.NullString + Age int32 + JobStatus string +} diff --git a/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go new file mode 100644 index 0000000000..5516e45400 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/mysql/go/query.sql.go @@ -0,0 +1,144 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" + "strings" +) + +const selectUsers = `-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > ? +` + +type SelectUsersRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsers(ctx context.Context, age int32) ([]SelectUsersRow, error) { + rows, err := q.db.QueryContext(ctx, selectUsers, age) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersRow + for rows.Next() { + var i SelectUsersRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic = `-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > ? AND /*DYNAMIC:dynamic*/? +` + +type SelectUsersDynamicParams struct { + Age int32 + Dynamic DynamicSql +} + +type SelectUsersDynamicRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamicParams) ([]SelectUsersDynamicRow, error) { + query := selectUsersDynamic + var queryParams []interface{} + curNumb := 2 + var replaceText string + var args []interface{} + queryParams = append(queryParams, arg.Age) + replaceText, args = arg.Dynamic.ToSql() + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicRow + for rows.Next() { + var i SelectUsersDynamicRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE /*DYNAMIC:dynamic*/? AND + age > ? AND + job_status = ? +` + +type SelectUsersDynamic2Params struct { + Dynamic DynamicSql + Age int32 + Status string +} + +type SelectUsersDynamic2Row struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { + query := selectUsersDynamic2 + var queryParams []interface{} + curNumb := 3 + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql() + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/?", replaceText) + queryParams = append(queryParams, args...) + queryParams = append(queryParams, arg.Age) + queryParams = append(queryParams, arg.Status) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamic2Row + for rows.Next() { + var i SelectUsersDynamic2Row + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/dynamic/mysql/query.sql b/internal/endtoend/testdata/dynamic/mysql/query.sql new file mode 100644 index 0000000000..8c3d94dff3 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/mysql/query.sql @@ -0,0 +1,25 @@ +CREATE TABLE users ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + first_name varchar(255) NOT NULL, + last_name varchar(255), + age integer NOT NULL, + job_status varchar(10) NOT NULL +); + +CREATE TABLE orders ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + price DECIMAL(13, 4) NOT NULL, + user_id integer NOT NULL +); + +-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age); +-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age) AND sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE sqlc.dynamic('dynamic') AND + age > sqlc.arg(age) AND + job_status = sqlc.arg(status) ; diff --git a/internal/endtoend/testdata/dynamic/mysql/sqlc.json b/internal/endtoend/testdata/dynamic/mysql/sqlc.json new file mode 100644 index 0000000000..bfbd23e211 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "query.sql", + "queries": "query.sql", + "engine": "mysql" + } + ] +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/go/db.go b/internal/endtoend/testdata/dynamic/pgx/v4/go/db.go new file mode 100644 index 0000000000..abb926d24d --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v4/go/db.go @@ -0,0 +1,35 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} +type DynamicSql interface { + ToSql(int) (string, []interface{}) +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/go/models.go b/internal/endtoend/testdata/dynamic/pgx/v4/go/models.go new file mode 100644 index 0000000000..512742184a --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v4/go/models.go @@ -0,0 +1,25 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "database/sql" + + "github.com/jackc/pgtype" +) + +type Order struct { + ID int32 + Price pgtype.Numeric + UserID int32 +} + +type User struct { + ID int32 + FirstName string + LastName sql.NullString + Age int32 + JobStatus string +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go b/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go new file mode 100644 index 0000000000..65c8566795 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v4/go/query.sql.go @@ -0,0 +1,204 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" + "strings" +) + +const selectUser = `-- name: SelectUser :one +SELECT first_name, last_name FROM users WHERE /*DYNAMIC:dynamic*/$1 +` + +type SelectUserRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUser(ctx context.Context, dynamic DynamicSql) (SelectUserRow, error) { + replaceText, queryParams := dynamic.ToSql(1) + query := strings.ReplaceAll(selectUser, "/*DYNAMIC:dynamic*/$1", replaceText) + row := q.db.QueryRow(ctx, query, queryParams...) + var i SelectUserRow + err := row.Scan(&i.FirstName, &i.LastName) + return i, err +} + +const selectUsers = `-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > $1 +` + +type SelectUsersRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsers(ctx context.Context, age int32) ([]SelectUsersRow, error) { + rows, err := q.db.Query(ctx, selectUsers, age) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersRow + for rows.Next() { + var i SelectUsersRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic = `-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > $1 AND /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamicParams struct { + Age int32 + Dynamic DynamicSql +} + +type SelectUsersDynamicRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamicParams) ([]SelectUsersDynamicRow, error) { + queryParams := []interface{}{arg.Age} + curNumb := 2 + query := selectUsersDynamic + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicRow + for rows.Next() { + var i SelectUsersDynamicRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamic2Params struct { + Age int32 + Status string + Dynamic DynamicSql +} + +type SelectUsersDynamic2Row struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { + queryParams := []interface{}{arg.Age, arg.Status} + curNumb := 3 + query := selectUsersDynamic2 + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamic2Row + for rows.Next() { + var i SelectUsersDynamic2Row + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamicMulti = `-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +ORDER BY /*DYNAMIC:order*/$1 +` + +type SelectUsersDynamicMultiParams struct { + Age int32 + Status string + Dynamic DynamicSql + Order DynamicSql +} + +type SelectUsersDynamicMultiRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDynamicMultiParams) ([]SelectUsersDynamicMultiRow, error) { + queryParams := []interface{}{arg.Age, arg.Status} + curNumb := 3 + query := selectUsersDynamicMulti + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + replaceText, args = arg.Order.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicMultiRow + for rows.Next() { + var i SelectUsersDynamicMultiRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/query.sql b/internal/endtoend/testdata/dynamic/pgx/v4/query.sql new file mode 100644 index 0000000000..281e23a395 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v4/query.sql @@ -0,0 +1,37 @@ +CREATE TABLE users ( + id int PRIMARY KEY, + first_name text NOT NULL, + last_name text, + age int NOT NULL, + job_status text NOT NULL +); + +CREATE TABLE orders ( + id int PRIMARY KEY, + price numeric NOT NULL, + user_id int NOT NULL +); + +-- name: SelectUser :one +SELECT first_name, last_name FROM users WHERE sqlc.dynamic('dynamic'); + +-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age); + +-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age) AND sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic') +ORDER BY sqlc.dynamic('order'); diff --git a/internal/endtoend/testdata/dynamic/pgx/v4/sqlc.json b/internal/endtoend/testdata/dynamic/pgx/v4/sqlc.json new file mode 100644 index 0000000000..9403bd0279 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v4/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/go/db.go b/internal/endtoend/testdata/dynamic/pgx/v5/go/db.go new file mode 100644 index 0000000000..fcbd7c7201 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v5/go/db.go @@ -0,0 +1,35 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} +type DynamicSql interface { + ToSql(int) (string, []interface{}) +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/go/models.go b/internal/endtoend/testdata/dynamic/pgx/v5/go/models.go new file mode 100644 index 0000000000..d1c0243acb --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v5/go/models.go @@ -0,0 +1,23 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +type Order struct { + ID int32 + Price pgtype.Numeric + UserID int32 +} + +type User struct { + ID int32 + FirstName string + LastName pgtype.Text + Age int32 + JobStatus string +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go b/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go new file mode 100644 index 0000000000..89cbd6e8a0 --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v5/go/query.sql.go @@ -0,0 +1,187 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" + "strings" + + "github.com/jackc/pgx/v5/pgtype" +) + +const selectUsers = `-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > $1 +` + +type SelectUsersRow struct { + FirstName string + LastName pgtype.Text +} + +func (q *Queries) SelectUsers(ctx context.Context, age int32) ([]SelectUsersRow, error) { + rows, err := q.db.Query(ctx, selectUsers, age) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersRow + for rows.Next() { + var i SelectUsersRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic = `-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > $1 AND /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamicParams struct { + Age int32 + Dynamic DynamicSql +} + +type SelectUsersDynamicRow struct { + FirstName string + LastName pgtype.Text +} + +func (q *Queries) SelectUsersDynamic(ctx context.Context, arg SelectUsersDynamicParams) ([]SelectUsersDynamicRow, error) { + queryParams := []interface{}{arg.Age} + curNumb := 2 + query := selectUsersDynamic + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicRow + for rows.Next() { + var i SelectUsersDynamicRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamic2Params struct { + Age int32 + Status string + Dynamic DynamicSql +} + +type SelectUsersDynamic2Row struct { + FirstName string + LastName pgtype.Text +} + +func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { + queryParams := []interface{}{arg.Age, arg.Status} + curNumb := 3 + query := selectUsersDynamic2 + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamic2Row + for rows.Next() { + var i SelectUsersDynamic2Row + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamicMulti = `-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +ORDER BY /*DYNAMIC:order*/$1 +` + +type SelectUsersDynamicMultiParams struct { + Age int32 + Status string + Dynamic DynamicSql + Order DynamicSql +} + +type SelectUsersDynamicMultiRow struct { + FirstName string + LastName pgtype.Text +} + +func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDynamicMultiParams) ([]SelectUsersDynamicMultiRow, error) { + queryParams := []interface{}{arg.Age, arg.Status} + curNumb := 3 + query := selectUsersDynamicMulti + var replaceText string + var args []interface{} + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + replaceText, args = arg.Order.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.Query(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicMultiRow + for rows.Next() { + var i SelectUsersDynamicMultiRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/query.sql b/internal/endtoend/testdata/dynamic/pgx/v5/query.sql new file mode 100644 index 0000000000..dca8604f7a --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v5/query.sql @@ -0,0 +1,34 @@ +CREATE TABLE users ( + id int PRIMARY KEY, + first_name text NOT NULL, + last_name text, + age int NOT NULL, + job_status text NOT NULL +); + +CREATE TABLE orders ( + id int PRIMARY KEY, + price numeric NOT NULL, + user_id int NOT NULL +); + +-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age); + +-- name: SelectUsersDynamic :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age) AND sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic') +ORDER BY sqlc.dynamic('order'); diff --git a/internal/endtoend/testdata/dynamic/pgx/v5/sqlc.json b/internal/endtoend/testdata/dynamic/pgx/v5/sqlc.json new file mode 100644 index 0000000000..6645ccbd1b --- /dev/null +++ b/internal/endtoend/testdata/dynamic/pgx/v5/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v5", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/db.go b/internal/endtoend/testdata/dynamic/stdlib/go/db.go new file mode 100644 index 0000000000..c473940b6e --- /dev/null +++ b/internal/endtoend/testdata/dynamic/stdlib/go/db.go @@ -0,0 +1,34 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} +type DynamicSql interface { + ToSql(int) (string, []interface{}) +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/models.go b/internal/endtoend/testdata/dynamic/stdlib/go/models.go new file mode 100644 index 0000000000..b5f5c9b7ed --- /dev/null +++ b/internal/endtoend/testdata/dynamic/stdlib/go/models.go @@ -0,0 +1,23 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "database/sql" +) + +type Order struct { + ID int32 + Price string + UserID int32 +} + +type User struct { + ID int32 + FirstName string + LastName sql.NullString + Age int32 + JobStatus string +} diff --git a/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go b/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go new file mode 100644 index 0000000000..3e97375afb --- /dev/null +++ b/internal/endtoend/testdata/dynamic/stdlib/go/query.sql.go @@ -0,0 +1,309 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" + "strings" +) + +const selectUsers = `-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > $1 +` + +type SelectUsersRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsers(ctx context.Context, age int32) ([]SelectUsersRow, error) { + rows, err := q.db.QueryContext(ctx, selectUsers, age) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersRow + for rows.Next() { + var i SelectUsersRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamic2 = `-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamic2Params struct { + Age int32 + Status string + Dynamic DynamicSql +} + +type SelectUsersDynamic2Row struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamic2(ctx context.Context, arg SelectUsersDynamic2Params) ([]SelectUsersDynamic2Row, error) { + query := selectUsersDynamic2 + var queryParams []interface{} + curNumb := 3 + var replaceText string + var args []interface{} + queryParams = append(queryParams, arg.Age) + queryParams = append(queryParams, arg.Status) + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamic2Row + for rows.Next() { + var i SelectUsersDynamic2Row + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamicA = `-- name: SelectUsersDynamicA :many +SELECT first_name, last_name FROM users WHERE age > $1 AND /*DYNAMIC:dynamic*/$1 +` + +type SelectUsersDynamicAParams struct { + Age int32 + Dynamic DynamicSql +} + +type SelectUsersDynamicARow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamicA(ctx context.Context, arg SelectUsersDynamicAParams) ([]SelectUsersDynamicARow, error) { + query := selectUsersDynamicA + var queryParams []interface{} + curNumb := 2 + var replaceText string + var args []interface{} + queryParams = append(queryParams, arg.Age) + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicARow + for rows.Next() { + var i SelectUsersDynamicARow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamicB = `-- name: SelectUsersDynamicB :many +SELECT first_name, last_name FROM users WHERE /*DYNAMIC:dynamic*/$1 AND age > $1 +` + +type SelectUsersDynamicBParams struct { + Age int32 + Dynamic DynamicSql +} + +type SelectUsersDynamicBRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamicB(ctx context.Context, arg SelectUsersDynamicBParams) ([]SelectUsersDynamicBRow, error) { + query := selectUsersDynamicB + var queryParams []interface{} + curNumb := 2 + var replaceText string + var args []interface{} + queryParams = append(queryParams, arg.Age) + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicBRow + for rows.Next() { + var i SelectUsersDynamicBRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamicMulti = `-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > $1 AND + job_status = $2 AND + /*DYNAMIC:dynamic*/$1 +ORDER BY /*DYNAMIC:order*/$1 +` + +type SelectUsersDynamicMultiParams struct { + Age int32 + Status string + Dynamic DynamicSql + Order DynamicSql +} + +type SelectUsersDynamicMultiRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamicMulti(ctx context.Context, arg SelectUsersDynamicMultiParams) ([]SelectUsersDynamicMultiRow, error) { + query := selectUsersDynamicMulti + var queryParams []interface{} + curNumb := 3 + var replaceText string + var args []interface{} + queryParams = append(queryParams, arg.Age) + queryParams = append(queryParams, arg.Status) + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + replaceText, args = arg.Order.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicMultiRow + for rows.Next() { + var i SelectUsersDynamicMultiRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUsersDynamicMultiB = `-- name: SelectUsersDynamicMultiB :many +SELECT first_name, last_name +FROM users +WHERE /*DYNAMIC:dynamic*/$1 AND + age > $1 AND + job_status = $2 +ORDER BY /*DYNAMIC:order*/$1 +` + +type SelectUsersDynamicMultiBParams struct { + Age int32 + Status string + Dynamic DynamicSql + Order DynamicSql +} + +type SelectUsersDynamicMultiBRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) SelectUsersDynamicMultiB(ctx context.Context, arg SelectUsersDynamicMultiBParams) ([]SelectUsersDynamicMultiBRow, error) { + query := selectUsersDynamicMultiB + var queryParams []interface{} + curNumb := 3 + var replaceText string + var args []interface{} + queryParams = append(queryParams, arg.Age) + queryParams = append(queryParams, arg.Status) + replaceText, args = arg.Dynamic.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:dynamic*/$1", replaceText) + queryParams = append(queryParams, args...) + replaceText, args = arg.Order.ToSql(curNumb) + curNumb += len(args) + query = strings.ReplaceAll(query, "/*DYNAMIC:order*/$1", replaceText) + queryParams = append(queryParams, args...) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectUsersDynamicMultiBRow + for rows.Next() { + var i SelectUsersDynamicMultiBRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/dynamic/stdlib/query.sql b/internal/endtoend/testdata/dynamic/stdlib/query.sql new file mode 100644 index 0000000000..12abf6ab8e --- /dev/null +++ b/internal/endtoend/testdata/dynamic/stdlib/query.sql @@ -0,0 +1,44 @@ +CREATE TABLE users ( + id int PRIMARY KEY, + first_name text NOT NULL, + last_name text, + age int NOT NULL, + job_status text NOT NULL +); + +CREATE TABLE orders ( + id int PRIMARY KEY, + price numeric NOT NULL, + user_id int NOT NULL +); + +-- name: SelectUsers :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age); + +-- name: SelectUsersDynamicA :many +SELECT first_name, last_name FROM users WHERE age > sqlc.arg(age) AND sqlc.dynamic('dynamic'); +-- name: SelectUsersDynamicB :many +SELECT first_name, last_name FROM users WHERE sqlc.dynamic('dynamic') AND age > sqlc.arg(age); + +-- name: SelectUsersDynamic2 :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic'); + +-- name: SelectUsersDynamicMulti :many +SELECT first_name, last_name +FROM users +WHERE age > sqlc.arg(age) AND + job_status = sqlc.arg(status) AND + sqlc.dynamic('dynamic') +ORDER BY sqlc.dynamic('order'); + +-- name: SelectUsersDynamicMultiB :many +SELECT first_name, last_name +FROM users +WHERE sqlc.dynamic('dynamic') AND + age > sqlc.arg(age) AND + job_status = sqlc.arg(status) +ORDER BY sqlc.dynamic('order'); diff --git a/internal/endtoend/testdata/dynamic/stdlib/sqlc.json b/internal/endtoend/testdata/dynamic/stdlib/sqlc.json new file mode 100644 index 0000000000..696ed223db --- /dev/null +++ b/internal/endtoend/testdata/dynamic/stdlib/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "query.sql", + "queries": "query.sql", + "engine": "postgresql" + } + ] +} diff --git a/internal/endtoend/testdata/params_location/mysql/go/query.sql.go b/internal/endtoend/testdata/params_location/mysql/go/query.sql.go index ba60001e0b..e1a24f3303 100644 --- a/internal/endtoend/testdata/params_location/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/params_location/mysql/go/query.sql.go @@ -255,3 +255,40 @@ func (q *Queries) ListUsersWithLimit(ctx context.Context, limit int32) ([]ListUs } return items, nil } + +const searchByName = `-- name: SearchByName :many +SELECT id, first_name, last_name, age, job_status FROM users WHERE (first_name = ? OR last_name = ?) +` + +type SearchByNameParams struct { + Name sql.NullString +} + +func (q *Queries) SearchByName(ctx context.Context, arg SearchByNameParams) ([]User, error) { + rows, err := q.db.QueryContext(ctx, searchByName, arg.Name, arg.Name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.FirstName, + &i.LastName, + &i.Age, + &i.JobStatus, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/params_location/mysql/query.sql b/internal/endtoend/testdata/params_location/mysql/query.sql index cfa804c252..0482979098 100644 --- a/internal/endtoend/testdata/params_location/mysql/query.sql +++ b/internal/endtoend/testdata/params_location/mysql/query.sql @@ -31,3 +31,6 @@ SELECT * FROM users WHERE (job_status = 'APPLIED' OR job_status = 'PENDING') AND id > ? ORDER BY id LIMIT ?; + +/* name: SearchByName :many */ +SELECT * FROM users WHERE (first_name = sqlc.narg(name) OR last_name = sqlc.narg(name)); diff --git a/internal/endtoend/testdata/process_plugin_sqlc_gen_json/exec.json b/internal/endtoend/testdata/process_plugin_sqlc_gen_json/exec.json index 2e996ca79d..572ba3e887 100644 --- a/internal/endtoend/testdata/process_plugin_sqlc_gen_json/exec.json +++ b/internal/endtoend/testdata/process_plugin_sqlc_gen_json/exec.json @@ -1,3 +1,4 @@ { + "process": "sqlc-gen-json", "contexts": ["base"] } diff --git a/internal/endtoend/testdata/subquery_with_where/go/db.go b/internal/endtoend/testdata/subquery_with_where/go/db.go new file mode 100644 index 0000000000..a457fb76b2 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/subquery_with_where/go/models.go b/internal/endtoend/testdata/subquery_with_where/go/models.go new file mode 100644 index 0000000000..56e82be769 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/go/models.go @@ -0,0 +1,19 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "database/sql" +) + +type Bar struct { + A int32 + Alias sql.NullString +} + +type Foo struct { + A int32 + Name sql.NullString +} diff --git a/internal/endtoend/testdata/subquery_with_where/go/query.sql.go b/internal/endtoend/testdata/subquery_with_where/go/query.sql.go new file mode 100644 index 0000000000..bb21b9b5cc --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/go/query.sql.go @@ -0,0 +1,53 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const subquery = `-- name: Subquery :many +SELECT + a, + name, + (SELECT alias FROM bar WHERE bar.a=foo.a AND alias = $1 ORDER BY bar.a DESC limit 1) as alias +FROM FOO WHERE a = $2 +` + +type SubqueryParams struct { + Alias sql.NullString + A int32 +} + +type SubqueryRow struct { + A int32 + Name sql.NullString + Alias sql.NullString +} + +func (q *Queries) Subquery(ctx context.Context, arg SubqueryParams) ([]SubqueryRow, error) { + rows, err := q.db.QueryContext(ctx, subquery, arg.Alias, arg.A) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SubqueryRow + for rows.Next() { + var i SubqueryRow + if err := rows.Scan(&i.A, &i.Name, &i.Alias); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/subquery_with_where/query.sql b/internal/endtoend/testdata/subquery_with_where/query.sql new file mode 100644 index 0000000000..12e6dfaf3f --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/query.sql @@ -0,0 +1,9 @@ +CREATE TABLE foo (a int not null, name text); +CREATE TABLE bar (a int not null, alias text); + +-- name: Subquery :many +SELECT + a, + name, + (SELECT alias FROM bar WHERE bar.a=foo.a AND alias = $1 ORDER BY bar.a DESC limit 1) as alias +FROM FOO WHERE a = $2; diff --git a/internal/endtoend/testdata/subquery_with_where/sqlc.json b/internal/endtoend/testdata/subquery_with_where/sqlc.json new file mode 100644 index 0000000000..c72b6132d5 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/plugin/codegen.pb.go b/internal/plugin/codegen.pb.go index 7b3347c6e1..6dcdf000e5 100644 --- a/internal/plugin/codegen.pb.go +++ b/internal/plugin/codegen.pb.go @@ -7,10 +7,11 @@ package plugin import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( @@ -851,15 +852,16 @@ type Column struct { IsNamedParam bool `protobuf:"varint,7,opt,name=is_named_param,json=isNamedParam,proto3" json:"is_named_param,omitempty"` IsFuncCall bool `protobuf:"varint,8,opt,name=is_func_call,json=isFuncCall,proto3" json:"is_func_call,omitempty"` // XXX: Figure out what PostgreSQL calls `foo.id` - Scope string `protobuf:"bytes,9,opt,name=scope,proto3" json:"scope,omitempty"` - Table *Identifier `protobuf:"bytes,10,opt,name=table,proto3" json:"table,omitempty"` - TableAlias string `protobuf:"bytes,11,opt,name=table_alias,json=tableAlias,proto3" json:"table_alias,omitempty"` - Type *Identifier `protobuf:"bytes,12,opt,name=type,proto3" json:"type,omitempty"` - IsSqlcSlice bool `protobuf:"varint,13,opt,name=is_sqlc_slice,json=isSqlcSlice,proto3" json:"is_sqlc_slice,omitempty"` - EmbedTable *Identifier `protobuf:"bytes,14,opt,name=embed_table,json=embedTable,proto3" json:"embed_table,omitempty"` - OriginalName string `protobuf:"bytes,15,opt,name=original_name,json=originalName,proto3" json:"original_name,omitempty"` - Unsigned bool `protobuf:"varint,16,opt,name=unsigned,proto3" json:"unsigned,omitempty"` - ArrayDims int32 `protobuf:"varint,17,opt,name=array_dims,json=arrayDims,proto3" json:"array_dims,omitempty"` + Scope string `protobuf:"bytes,9,opt,name=scope,proto3" json:"scope,omitempty"` + Table *Identifier `protobuf:"bytes,10,opt,name=table,proto3" json:"table,omitempty"` + TableAlias string `protobuf:"bytes,11,opt,name=table_alias,json=tableAlias,proto3" json:"table_alias,omitempty"` + Type *Identifier `protobuf:"bytes,12,opt,name=type,proto3" json:"type,omitempty"` + IsSqlcSlice bool `protobuf:"varint,13,opt,name=is_sqlc_slice,json=isSqlcSlice,proto3" json:"is_sqlc_slice,omitempty"` + IsSqlcDynamic bool `protobuf:"varint,13,opt,name=is_sqlc_dynamic,json=isSqlcDynamic,proto3" json:"is_sqlc_dynamic,omitempty"` + EmbedTable *Identifier `protobuf:"bytes,14,opt,name=embed_table,json=embedTable,proto3" json:"embed_table,omitempty"` + OriginalName string `protobuf:"bytes,15,opt,name=original_name,json=originalName,proto3" json:"original_name,omitempty"` + Unsigned bool `protobuf:"varint,16,opt,name=unsigned,proto3" json:"unsigned,omitempty"` + ArrayDims int32 `protobuf:"varint,17,opt,name=array_dims,json=arrayDims,proto3" json:"array_dims,omitempty"` } func (x *Column) Reset() { @@ -977,7 +979,6 @@ func (x *Column) GetIsSqlcSlice() bool { } return false } - func (x *Column) GetEmbedTable() *Identifier { if x != nil { return x.EmbedTable diff --git a/internal/sql/ast/param_ref.go b/internal/sql/ast/param_ref.go index 8bd724993d..11e3a79f32 100644 --- a/internal/sql/ast/param_ref.go +++ b/internal/sql/ast/param_ref.go @@ -3,9 +3,10 @@ package ast import "fmt" type ParamRef struct { - Number int - Location int - Dollar bool + Number int + Location int + Dollar bool + IsSqlcDynamic bool } func (n *ParamRef) Pos() int { diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 0943379f03..00d6d2fdfc 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -1818,6 +1818,9 @@ func Walk(f Visitor, node ast.Node) { } case *ast.SelectStmt: + if n.FromClause != nil { + Walk(f, n.FromClause) + } if n.DistinctClause != nil { Walk(f, n.DistinctClause) } @@ -1827,9 +1830,6 @@ func Walk(f Visitor, node ast.Node) { if n.TargetList != nil { Walk(f, n.TargetList) } - if n.FromClause != nil { - Walk(f, n.FromClause) - } if n.WhereClause != nil { Walk(f, n.WhereClause) } @@ -2032,15 +2032,15 @@ func Walk(f Visitor, node ast.Node) { if n.Relations != nil { Walk(f, n.Relations) } + if n.FromClause != nil { + Walk(f, n.FromClause) + } if n.TargetList != nil { Walk(f, n.TargetList) } if n.WhereClause != nil { Walk(f, n.WhereClause) } - if n.FromClause != nil { - Walk(f, n.FromClause) - } if n.LimitCount != nil { Walk(f, n.LimitCount) } diff --git a/internal/sql/named/is.go b/internal/sql/named/is.go index d53c1d9905..5d3d1d1544 100644 --- a/internal/sql/named/is.go +++ b/internal/sql/named/is.go @@ -16,7 +16,8 @@ func IsParamFunc(node ast.Node) bool { return false } - isValid := call.Func.Schema == "sqlc" && (call.Func.Name == "arg" || call.Func.Name == "narg" || call.Func.Name == "slice") + // TODO + isValid := call.Func.Schema == "sqlc" && (call.Func.Name == "arg" || call.Func.Name == "narg" || call.Func.Name == "slice" || call.Func.Name == "dynamic") return isValid } diff --git a/internal/sql/named/param.go b/internal/sql/named/param.go index 42f9b855a3..7667008c8b 100644 --- a/internal/sql/named/param.go +++ b/internal/sql/named/param.go @@ -42,9 +42,10 @@ func (n nullability) String() string { // - named parameter operator @param // - named parameter function calls sqlc.arg(param) type Param struct { - name string - nullability nullability - isSqlcSlice bool + name string + nullability nullability + isSqlcSlice bool + isSqlcDynamic bool } // NewParam builds a new params with unspecified nullability @@ -72,6 +73,11 @@ func NewSqlcSlice(name string) Param { return Param{name: name, nullability: nullUnspecified, isSqlcSlice: true} } +// NewSqlcDynamic is a sqlc.dynamic() parameter. +func NewSqlcDynamic(name string) Param { + return Param{name: name, nullability: notNullable, isSqlcDynamic: true} +} + // Name is the user defined name to use for this parameter func (p Param) Name() string { return p.name @@ -113,6 +119,11 @@ func (p Param) IsSqlcSlice() bool { return p.isSqlcSlice } +// IsSlice returns whether this param is a sqlc.dynamic() param. +func (p Param) IsSqlcDynamic() bool { + return p.isSqlcDynamic +} + // mergeParam creates a new param from 2 partially specified params // If the parameters have different names, the first is preferred func mergeParam(a, b Param) Param { @@ -122,8 +133,9 @@ func mergeParam(a, b Param) Param { } return Param{ - name: name, - nullability: a.nullability | b.nullability, - isSqlcSlice: a.isSqlcSlice || b.isSqlcSlice, + name: name, + nullability: a.nullability | b.nullability, + isSqlcSlice: a.isSqlcSlice || b.isSqlcSlice, + isSqlcDynamic: a.isSqlcDynamic && b.isSqlcDynamic, } } diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index d1ea1a22cc..232b0a6085 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -61,6 +61,8 @@ func paramFromFuncCall(call *ast.FuncCall) (named.Param, string) { param = named.NewUserNullableParam(paramName) case "slice": param = named.NewSqlcSlice(paramName) + case "dynamic": + param = named.NewSqlcDynamic(paramName) default: param = named.NewParam(paramName) } @@ -95,34 +97,37 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case named.IsParamFunc(node): fun := node.(*ast.FuncCall) param, origText := paramFromFuncCall(fun) - argn := allParams.Add(param) - cr.Replace(&ast.ParamRef{ - Number: argn, - Location: fun.Location, - }) - - var replace string - if engine == config.EngineMySQL || engine == config.EngineSQLite || !dollar { - if param.IsSqlcSlice() { - // This sequence is also replicated in internal/codegen/golang.Field - // since it's needed during template generation for replacement - replace = fmt.Sprintf(`/*SLICE:%s*/?`, param.Name()) - } else { - if engine == config.EngineSQLite { - replace = fmt.Sprintf("?%d", argn) + if !param.IsSqlcDynamic() { + argn := allParams.Add(param) + cr.Replace(&ast.ParamRef{ + Number: argn, + Location: fun.Location, + IsSqlcDynamic: param.IsSqlcDynamic(), + }) + + var replace string + if engine == config.EngineMySQL || engine == config.EngineSQLite || !dollar { + if param.IsSqlcSlice() { + // This sequence is also replicated in internal/codegen/golang.Field + // since it's needed during template generation for replacement + replace = fmt.Sprintf(`/*SLICE:%s*/?`, param.Name()) } else { - replace = "?" + if engine == config.EngineSQLite { + replace = fmt.Sprintf("?%d", argn) + } else { + replace = "?" + } } + } else { + replace = fmt.Sprintf("$%d", argn) } - } else { - replace = fmt.Sprintf("$%d", argn) - } - edits = append(edits, source.Edit{ - Location: fun.Location - raw.StmtLocation, - Old: origText, - New: replace, - }) + edits = append(edits, source.Edit{ + Location: fun.Location - raw.StmtLocation, + Old: origText, + New: replace, + }) + } return false case isNamedParamSignCast(node): @@ -187,6 +192,34 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, return true } }, nil) - + node = astutils.Apply(node, func(cr *astutils.Cursor) bool { + node := cr.Node() + if named.IsParamFunc(node) { + fun := node.(*ast.FuncCall) + param, origText := paramFromFuncCall(fun) + if param.IsSqlcDynamic() { + argn := allParams.Add(param) + cr.Replace(&ast.ParamRef{ + Number: argn, + Location: fun.Location, + IsSqlcDynamic: param.IsSqlcDynamic(), + }) + + var replace string + if engine == config.EngineMySQL || engine == config.EngineSQLite || !dollar { + replace = fmt.Sprintf(`/*DYNAMIC:%s*/?`, param.Name()) + } else { + replace = fmt.Sprintf(`/*DYNAMIC:%s*/$1`, param.Name()) + } + edits = append(edits, source.Edit{ + Location: fun.Location - raw.StmtLocation, + Old: origText, + New: replace, + }) + } + return false + } + return true + }, nil) return node.(*ast.RawStmt), allParams, edits } diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go index 1182051d20..ce3a0f1aca 100644 --- a/internal/sql/validate/param_style.go +++ b/internal/sql/validate/param_style.go @@ -29,7 +29,7 @@ func (v *sqlcFuncVisitor) Visit(node ast.Node) astutils.Visitor { // Custom validation for sqlc.arg, sqlc.narg and sqlc.slice // TODO: Replace this once type-checking is implemented if fn.Schema == "sqlc" { - if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed") { + if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed" || fn.Name == "dynamic") { v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name) return nil }