8
8
"fmt"
9
9
"regexp"
10
10
"sort"
11
+ "strconv"
11
12
"strings"
12
13
"text/template"
13
14
@@ -32,25 +33,24 @@ type Enum struct {
32
33
}
33
34
34
35
type Field struct {
36
+ ID int
35
37
Name string
36
38
Type ktType
37
39
Comment string
38
40
}
39
41
40
42
type Struct struct {
41
- Table plugin.Identifier
42
- Name string
43
- Fields []Field
44
- JDBCParamBindings []Field
45
- Comment string
43
+ Table plugin.Identifier
44
+ Name string
45
+ Fields []Field
46
+ Comment string
46
47
}
47
48
48
49
type QueryValue struct {
49
- Emit bool
50
- Name string
51
- Struct * Struct
52
- Typ ktType
53
- JDBCParamBindCount int
50
+ Emit bool
51
+ Name string
52
+ Struct * Struct
53
+ Typ ktType
54
54
}
55
55
56
56
func (v QueryValue ) EmitStruct () bool {
@@ -102,7 +102,8 @@ func jdbcSet(t ktType, idx int, name string) string {
102
102
}
103
103
104
104
type Params struct {
105
- Struct * Struct
105
+ Struct * Struct
106
+ binding []int
106
107
}
107
108
108
109
func (v Params ) isEmpty () bool {
@@ -114,9 +115,19 @@ func (v Params) Args() string {
114
115
return ""
115
116
}
116
117
var out []string
117
- for _ , f := range v .Struct .Fields {
118
+ fields := v .Struct .Fields
119
+ for _ , f := range fields {
118
120
out = append (out , f .Name + ": " + f .Type .String ())
119
121
}
122
+ if len (v .binding ) > 0 {
123
+ lookup := map [int ]int {}
124
+ for i , v := range v .binding {
125
+ lookup [v ] = i
126
+ }
127
+ sort .Slice (out , func (i , j int ) bool {
128
+ return lookup [fields [i ].ID ] < lookup [fields [j ].ID ]
129
+ })
130
+ }
120
131
if len (out ) < 3 {
121
132
return strings .Join (out , ", " )
122
133
}
@@ -128,8 +139,15 @@ func (v Params) Bindings() string {
128
139
return ""
129
140
}
130
141
var out []string
131
- for i , f := range v .Struct .JDBCParamBindings {
132
- out = append (out , jdbcSet (f .Type , i + 1 , f .Name ))
142
+ if len (v .binding ) > 0 {
143
+ for i , idx := range v .binding {
144
+ f := v .Struct .Fields [idx - 1 ]
145
+ out = append (out , jdbcSet (f .Type , i + 1 , f .Name ))
146
+ }
147
+ } else {
148
+ for i , f := range v .Struct .Fields {
149
+ out = append (out , jdbcSet (f .Type , i + 1 , f .Name ))
150
+ }
133
151
}
134
152
return indent (strings .Join (out , "\n " ), 10 , 0 )
135
153
}
@@ -387,20 +405,19 @@ func ktColumnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColu
387
405
idSeen := map [int ]Field {}
388
406
nameSeen := map [string ]int {}
389
407
for _ , c := range columns {
390
- if binding , ok := idSeen [c .id ]; ok {
391
- gs .JDBCParamBindings = append (gs .JDBCParamBindings , binding )
408
+ if _ , ok := idSeen [c .id ]; ok {
392
409
continue
393
410
}
394
411
fieldName := memberName (namer (c .Column , c .id ), req .Settings )
395
412
if v := nameSeen [c .Name ]; v > 0 {
396
413
fieldName = fmt .Sprintf ("%s_%d" , fieldName , v + 1 )
397
414
}
398
415
field := Field {
416
+ ID : c .id ,
399
417
Name : fieldName ,
400
418
Type : makeType (req , c .Column ),
401
419
}
402
420
gs .Fields = append (gs .Fields , field )
403
- gs .JDBCParamBindings = append (gs .JDBCParamBindings , field )
404
421
nameSeen [c .Name ]++
405
422
idSeen [c .id ] = field
406
423
}
@@ -438,11 +455,31 @@ var postgresPlaceholderRegexp = regexp.MustCompile(`\B\$\d+\b`)
438
455
// HACK: jdbc doesn't support numbered parameters, so we need to transform them to question marks...
439
456
// But there's no access to the SQL parser here, so we just do a dumb regexp replace instead. This won't work if
440
457
// the literal strings contain matching values, but good enough for a prototype.
441
- func jdbcSQL (s , engine string ) string {
442
- if engine = = "postgresql" {
443
- return postgresPlaceholderRegexp . ReplaceAllString ( s , "?" )
458
+ func jdbcSQL (s , engine string ) ( string , [] string ) {
459
+ if engine ! = "postgresql" {
460
+ return s , nil
444
461
}
445
- return s
462
+ var args []string
463
+ q := postgresPlaceholderRegexp .ReplaceAllStringFunc (s , func (placeholder string ) string {
464
+ args = append (args , placeholder )
465
+ return "?"
466
+ })
467
+ return q , args
468
+ }
469
+
470
+ func parseInts (s []string ) ([]int , error ) {
471
+ if len (s ) == 0 {
472
+ return nil , nil
473
+ }
474
+ var refs []int
475
+ for _ , v := range s {
476
+ i , err := strconv .Atoi (strings .TrimPrefix (v , "$" ))
477
+ if err != nil {
478
+ return nil , err
479
+ }
480
+ refs = append (refs , i )
481
+ }
482
+ return refs , nil
446
483
}
447
484
448
485
func buildQueries (req * plugin.CodeGenRequest , structs []Struct ) ([]Query , error ) {
@@ -458,14 +495,19 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
458
495
return nil , errors .New ("Support for CopyFrom in Kotlin is not implemented" )
459
496
}
460
497
498
+ ql , args := jdbcSQL (query .Text , req .Settings .Engine )
499
+ refs , err := parseInts (args )
500
+ if err != nil {
501
+ return nil , fmt .Errorf ("Invalid parameter reference: %w" , err )
502
+ }
461
503
gq := Query {
462
504
Cmd : query .Cmd ,
463
505
ClassName : strings .Title (query .Name ),
464
506
ConstantName : sdk .LowerTitle (query .Name ),
465
507
FieldName : sdk .LowerTitle (query .Name ) + "Stmt" ,
466
508
MethodName : sdk .LowerTitle (query .Name ),
467
509
SourceName : query .Filename ,
468
- SQL : jdbcSQL ( query . Text , req . Settings . Engine ) ,
510
+ SQL : ql ,
469
511
Comments : query .Comments ,
470
512
}
471
513
@@ -478,7 +520,8 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
478
520
}
479
521
params := ktColumnsToStruct (req , gq .ClassName + "Bindings" , cols , ktParamName )
480
522
gq .Arg = Params {
481
- Struct : params ,
523
+ Struct : params ,
524
+ binding : refs ,
482
525
}
483
526
484
527
if len (query .Columns ) == 1 {
0 commit comments