From 870b4ad30c121cdbcc10009326d143457fc1ab9d Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Tue, 25 Oct 2022 23:30:32 -0700 Subject: [PATCH 1/2] compiler: Move Kotlin parameter logic into codegen --- internal/cmd/generate.go | 3 -- internal/codegen/kotlin/gen.go | 89 +++++++++++++++++++++++++--------- internal/compiler/parse.go | 15 ++---- internal/opts/parser.go | 3 +- 4 files changed, 71 insertions(+), 39 deletions(-) diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 0dca81de06..49334893d2 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -198,9 +198,6 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer lang = "golang" case sql.Gen.Kotlin != nil: - if sql.Engine == config.EnginePostgreSQL { - parseOpts.UsePositionalParameters = true - } lang = "kotlin" name = combo.Kotlin.Package diff --git a/internal/codegen/kotlin/gen.go b/internal/codegen/kotlin/gen.go index cea6e1d908..71242969da 100644 --- a/internal/codegen/kotlin/gen.go +++ b/internal/codegen/kotlin/gen.go @@ -8,6 +8,7 @@ import ( "fmt" "regexp" "sort" + "strconv" "strings" "text/template" @@ -32,25 +33,24 @@ type Enum struct { } type Field struct { + ID int Name string Type ktType Comment string } type Struct struct { - Table plugin.Identifier - Name string - Fields []Field - JDBCParamBindings []Field - Comment string + Table plugin.Identifier + Name string + Fields []Field + Comment string } type QueryValue struct { - Emit bool - Name string - Struct *Struct - Typ ktType - JDBCParamBindCount int + Emit bool + Name string + Struct *Struct + Typ ktType } func (v QueryValue) EmitStruct() bool { @@ -102,7 +102,8 @@ func jdbcSet(t ktType, idx int, name string) string { } type Params struct { - Struct *Struct + Struct *Struct + binding []int } func (v Params) isEmpty() bool { @@ -114,9 +115,19 @@ func (v Params) Args() string { return "" } var out []string - for _, f := range v.Struct.Fields { + fields := v.Struct.Fields + for _, f := range fields { out = append(out, f.Name+": "+f.Type.String()) } + if len(v.binding) > 0 { + lookup := map[int]int{} + for i, v := range v.binding { + lookup[v] = i + } + sort.Slice(out, func(i, j int) bool { + return lookup[fields[i].ID] < lookup[fields[j].ID] + }) + } if len(out) < 3 { return strings.Join(out, ", ") } @@ -128,8 +139,15 @@ func (v Params) Bindings() string { return "" } var out []string - for i, f := range v.Struct.JDBCParamBindings { - out = append(out, jdbcSet(f.Type, i+1, f.Name)) + if len(v.binding) > 0 { + for i, idx := range v.binding { + f := v.Struct.Fields[idx-1] + out = append(out, jdbcSet(f.Type, i+1, f.Name)) + } + } else { + for i, f := range v.Struct.Fields { + out = append(out, jdbcSet(f.Type, i+1, f.Name)) + } } return indent(strings.Join(out, "\n"), 10, 0) } @@ -387,8 +405,7 @@ func ktColumnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColu idSeen := map[int]Field{} nameSeen := map[string]int{} for _, c := range columns { - if binding, ok := idSeen[c.id]; ok { - gs.JDBCParamBindings = append(gs.JDBCParamBindings, binding) + if _, ok := idSeen[c.id]; ok { continue } fieldName := memberName(namer(c.Column, c.id), req.Settings) @@ -396,11 +413,11 @@ func ktColumnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColu fieldName = fmt.Sprintf("%s_%d", fieldName, v+1) } field := Field{ + ID: c.id, Name: fieldName, Type: makeType(req, c.Column), } gs.Fields = append(gs.Fields, field) - gs.JDBCParamBindings = append(gs.JDBCParamBindings, field) nameSeen[c.Name]++ idSeen[c.id] = field } @@ -438,11 +455,31 @@ var postgresPlaceholderRegexp = regexp.MustCompile(`\B\$\d+\b`) // HACK: jdbc doesn't support numbered parameters, so we need to transform them to question marks... // But there's no access to the SQL parser here, so we just do a dumb regexp replace instead. This won't work if // the literal strings contain matching values, but good enough for a prototype. -func jdbcSQL(s, engine string) string { - if engine == "postgresql" { - return postgresPlaceholderRegexp.ReplaceAllString(s, "?") +func jdbcSQL(s, engine string) (string, []string) { + if engine != "postgresql" { + return s, nil } - return s + var args []string + q := postgresPlaceholderRegexp.ReplaceAllStringFunc(s, func(placeholder string) string { + args = append(args, placeholder) + return "?" + }) + return q, args +} + +func parseInts(s []string) ([]int, error) { + if len(s) == 0 { + return nil, nil + } + var refs []int + for _, v := range s { + i, err := strconv.Atoi(strings.TrimPrefix(v, "$")) + if err != nil { + return nil, err + } + refs = append(refs, i) + } + return refs, nil } func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) { @@ -458,6 +495,11 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) return nil, errors.New("Support for CopyFrom in Kotlin is not implemented") } + ql, args := jdbcSQL(query.Text, req.Settings.Engine) + refs, err := parseInts(args) + if err != nil { + return nil, fmt.Errorf("Invalid parameter reference: %w", err) + } gq := Query{ Cmd: query.Cmd, ClassName: strings.Title(query.Name), @@ -465,7 +507,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) FieldName: sdk.LowerTitle(query.Name) + "Stmt", MethodName: sdk.LowerTitle(query.Name), SourceName: query.Filename, - SQL: jdbcSQL(query.Text, req.Settings.Engine), + SQL: ql, Comments: query.Comments, } @@ -478,7 +520,8 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) } params := ktColumnsToStruct(req, gq.ClassName+"Bindings", cols, ktParamName) gq.Arg = Params{ - Struct: params, + Struct: params, + binding: refs, } if len(query.Columns) == 1 { diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 6df3c0e1d3..4eb46d04d3 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -89,18 +89,11 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err != nil { return nil, err } - if o.UsePositionalParameters { - edits, err = rewriteNumberedParameters(refs, raw, rawSQL) - if err != nil { - return nil, err - } + refs = uniqueParamRefs(refs, dollar) + if c.conf.Engine == config.EngineMySQL || !dollar { + sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Location < refs[j].ref.Location }) } else { - refs = uniqueParamRefs(refs, dollar) - if c.conf.Engine == config.EngineMySQL || !dollar { - sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Location < refs[j].ref.Location }) - } else { - sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) - } + sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) } qc, err := buildQueryCatalog(c.catalog, raw.Stmt) if err != nil { diff --git a/internal/opts/parser.go b/internal/opts/parser.go index 7ce464be2c..d6fb399552 100644 --- a/internal/opts/parser.go +++ b/internal/opts/parser.go @@ -1,6 +1,5 @@ package opts type Parser struct { - UsePositionalParameters bool - Debug Debug + Debug Debug } From 5b7f7102b02af0fbe59e3742b5c736cc3ee94a84 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Tue, 25 Oct 2022 23:34:28 -0700 Subject: [PATCH 2/2] compiler: Remove rewriteNumberedParameters func --- internal/compiler/parse.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 4eb46d04d3..0cc76b0728 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -19,18 +19,6 @@ import ( var ErrUnsupportedStatementType = errors.New("parseQuery: unsupported statement type") -func rewriteNumberedParameters(refs []paramRef, raw *ast.RawStmt, sql string) ([]source.Edit, error) { - edits := make([]source.Edit, len(refs)) - for i, ref := range refs { - edits[i] = source.Edit{ - Location: ref.ref.Location - raw.StmtLocation, - Old: fmt.Sprintf("$%d", ref.ref.Number), - New: "?", - } - } - return edits, nil -} - func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, error) { if o.Debug.DumpAST { debug.Dump(stmt)