Skip to content

Commit 5408aa9

Browse files
authored
compiler: Move Kotlin parameter logic into codegen (#1910)
* compiler: Move Kotlin parameter logic into codegen * compiler: Remove rewriteNumberedParameters func
1 parent 89c6aa1 commit 5408aa9

File tree

4 files changed

+71
-51
lines changed

4 files changed

+71
-51
lines changed

internal/cmd/generate.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,6 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
198198
lang = "golang"
199199

200200
case sql.Gen.Kotlin != nil:
201-
if sql.Engine == config.EnginePostgreSQL {
202-
parseOpts.UsePositionalParameters = true
203-
}
204201
lang = "kotlin"
205202
name = combo.Kotlin.Package
206203

internal/codegen/kotlin/gen.go

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"regexp"
1010
"sort"
11+
"strconv"
1112
"strings"
1213
"text/template"
1314

@@ -32,25 +33,24 @@ type Enum struct {
3233
}
3334

3435
type Field struct {
36+
ID int
3537
Name string
3638
Type ktType
3739
Comment string
3840
}
3941

4042
type Struct struct {
41-
Table plugin.Identifier
42-
Name string
43-
Fields []Field
44-
JDBCParamBindings []Field
45-
Comment string
43+
Table plugin.Identifier
44+
Name string
45+
Fields []Field
46+
Comment string
4647
}
4748

4849
type QueryValue struct {
49-
Emit bool
50-
Name string
51-
Struct *Struct
52-
Typ ktType
53-
JDBCParamBindCount int
50+
Emit bool
51+
Name string
52+
Struct *Struct
53+
Typ ktType
5454
}
5555

5656
func (v QueryValue) EmitStruct() bool {
@@ -102,7 +102,8 @@ func jdbcSet(t ktType, idx int, name string) string {
102102
}
103103

104104
type Params struct {
105-
Struct *Struct
105+
Struct *Struct
106+
binding []int
106107
}
107108

108109
func (v Params) isEmpty() bool {
@@ -114,9 +115,19 @@ func (v Params) Args() string {
114115
return ""
115116
}
116117
var out []string
117-
for _, f := range v.Struct.Fields {
118+
fields := v.Struct.Fields
119+
for _, f := range fields {
118120
out = append(out, f.Name+": "+f.Type.String())
119121
}
122+
if len(v.binding) > 0 {
123+
lookup := map[int]int{}
124+
for i, v := range v.binding {
125+
lookup[v] = i
126+
}
127+
sort.Slice(out, func(i, j int) bool {
128+
return lookup[fields[i].ID] < lookup[fields[j].ID]
129+
})
130+
}
120131
if len(out) < 3 {
121132
return strings.Join(out, ", ")
122133
}
@@ -128,8 +139,15 @@ func (v Params) Bindings() string {
128139
return ""
129140
}
130141
var out []string
131-
for i, f := range v.Struct.JDBCParamBindings {
132-
out = append(out, jdbcSet(f.Type, i+1, f.Name))
142+
if len(v.binding) > 0 {
143+
for i, idx := range v.binding {
144+
f := v.Struct.Fields[idx-1]
145+
out = append(out, jdbcSet(f.Type, i+1, f.Name))
146+
}
147+
} else {
148+
for i, f := range v.Struct.Fields {
149+
out = append(out, jdbcSet(f.Type, i+1, f.Name))
150+
}
133151
}
134152
return indent(strings.Join(out, "\n"), 10, 0)
135153
}
@@ -387,20 +405,19 @@ func ktColumnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColu
387405
idSeen := map[int]Field{}
388406
nameSeen := map[string]int{}
389407
for _, c := range columns {
390-
if binding, ok := idSeen[c.id]; ok {
391-
gs.JDBCParamBindings = append(gs.JDBCParamBindings, binding)
408+
if _, ok := idSeen[c.id]; ok {
392409
continue
393410
}
394411
fieldName := memberName(namer(c.Column, c.id), req.Settings)
395412
if v := nameSeen[c.Name]; v > 0 {
396413
fieldName = fmt.Sprintf("%s_%d", fieldName, v+1)
397414
}
398415
field := Field{
416+
ID: c.id,
399417
Name: fieldName,
400418
Type: makeType(req, c.Column),
401419
}
402420
gs.Fields = append(gs.Fields, field)
403-
gs.JDBCParamBindings = append(gs.JDBCParamBindings, field)
404421
nameSeen[c.Name]++
405422
idSeen[c.id] = field
406423
}
@@ -438,11 +455,31 @@ var postgresPlaceholderRegexp = regexp.MustCompile(`\B\$\d+\b`)
438455
// HACK: jdbc doesn't support numbered parameters, so we need to transform them to question marks...
439456
// But there's no access to the SQL parser here, so we just do a dumb regexp replace instead. This won't work if
440457
// the literal strings contain matching values, but good enough for a prototype.
441-
func jdbcSQL(s, engine string) string {
442-
if engine == "postgresql" {
443-
return postgresPlaceholderRegexp.ReplaceAllString(s, "?")
458+
func jdbcSQL(s, engine string) (string, []string) {
459+
if engine != "postgresql" {
460+
return s, nil
444461
}
445-
return s
462+
var args []string
463+
q := postgresPlaceholderRegexp.ReplaceAllStringFunc(s, func(placeholder string) string {
464+
args = append(args, placeholder)
465+
return "?"
466+
})
467+
return q, args
468+
}
469+
470+
func parseInts(s []string) ([]int, error) {
471+
if len(s) == 0 {
472+
return nil, nil
473+
}
474+
var refs []int
475+
for _, v := range s {
476+
i, err := strconv.Atoi(strings.TrimPrefix(v, "$"))
477+
if err != nil {
478+
return nil, err
479+
}
480+
refs = append(refs, i)
481+
}
482+
return refs, nil
446483
}
447484

448485
func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) {
@@ -458,14 +495,19 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
458495
return nil, errors.New("Support for CopyFrom in Kotlin is not implemented")
459496
}
460497

498+
ql, args := jdbcSQL(query.Text, req.Settings.Engine)
499+
refs, err := parseInts(args)
500+
if err != nil {
501+
return nil, fmt.Errorf("Invalid parameter reference: %w", err)
502+
}
461503
gq := Query{
462504
Cmd: query.Cmd,
463505
ClassName: strings.Title(query.Name),
464506
ConstantName: sdk.LowerTitle(query.Name),
465507
FieldName: sdk.LowerTitle(query.Name) + "Stmt",
466508
MethodName: sdk.LowerTitle(query.Name),
467509
SourceName: query.Filename,
468-
SQL: jdbcSQL(query.Text, req.Settings.Engine),
510+
SQL: ql,
469511
Comments: query.Comments,
470512
}
471513

@@ -478,7 +520,8 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
478520
}
479521
params := ktColumnsToStruct(req, gq.ClassName+"Bindings", cols, ktParamName)
480522
gq.Arg = Params{
481-
Struct: params,
523+
Struct: params,
524+
binding: refs,
482525
}
483526

484527
if len(query.Columns) == 1 {

internal/compiler/parse.go

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,6 @@ import (
1919

2020
var ErrUnsupportedStatementType = errors.New("parseQuery: unsupported statement type")
2121

22-
func rewriteNumberedParameters(refs []paramRef, raw *ast.RawStmt, sql string) ([]source.Edit, error) {
23-
edits := make([]source.Edit, len(refs))
24-
for i, ref := range refs {
25-
edits[i] = source.Edit{
26-
Location: ref.ref.Location - raw.StmtLocation,
27-
Old: fmt.Sprintf("$%d", ref.ref.Number),
28-
New: "?",
29-
}
30-
}
31-
return edits, nil
32-
}
33-
3422
func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, error) {
3523
if o.Debug.DumpAST {
3624
debug.Dump(stmt)
@@ -89,18 +77,11 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
8977
if err != nil {
9078
return nil, err
9179
}
92-
if o.UsePositionalParameters {
93-
edits, err = rewriteNumberedParameters(refs, raw, rawSQL)
94-
if err != nil {
95-
return nil, err
96-
}
80+
refs = uniqueParamRefs(refs, dollar)
81+
if c.conf.Engine == config.EngineMySQL || !dollar {
82+
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Location < refs[j].ref.Location })
9783
} else {
98-
refs = uniqueParamRefs(refs, dollar)
99-
if c.conf.Engine == config.EngineMySQL || !dollar {
100-
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Location < refs[j].ref.Location })
101-
} else {
102-
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
103-
}
84+
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
10485
}
10586
qc, err := buildQueryCatalog(c.catalog, raw.Stmt)
10687
if err != nil {

internal/opts/parser.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package opts
22

33
type Parser struct {
4-
UsePositionalParameters bool
5-
Debug Debug
4+
Debug Debug
65
}

0 commit comments

Comments
 (0)