@@ -53,6 +53,8 @@ type Field struct {
53
53
Name string
54
54
Type pyType
55
55
Comment string
56
+ // EmbedFields contains the embedded fields that require scanning.
57
+ EmbedFields []Field
56
58
}
57
59
58
60
type Struct struct {
@@ -105,14 +107,42 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node {
105
107
call := & pyast.Call {
106
108
Func : v .Annotation (),
107
109
}
108
- for i , f := range v .Struct .Fields {
109
- call .Keywords = append (call .Keywords , & pyast.Keyword {
110
- Arg : f .Name ,
111
- Value : subscriptNode (
110
+ rowIndex := 0 // We need to keep track of the index in the row variable.
111
+ for _ , f := range v .Struct .Fields {
112
+
113
+ var valueNode * pyast.Node
114
+ // Check if we are using sqlc.embed, if so we need to create a new object.
115
+ if len (f .EmbedFields ) > 0 {
116
+ // We keep this separate so we can easily add all arguments.
117
+ embed_call := & pyast.Call {Func : f .Type .Annotation ()}
118
+
119
+ // Now add all field Initializers for the embedded model that index into the original row.
120
+ for i , embedField := range f .EmbedFields {
121
+ embed_call .Keywords = append (embed_call .Keywords , & pyast.Keyword {
122
+ Arg : embedField .Name ,
123
+ Value : subscriptNode (
124
+ rowVar ,
125
+ constantInt (rowIndex + i ),
126
+ ),
127
+ })
128
+ }
129
+
130
+ valueNode = & pyast.Node {
131
+ Node : & pyast.Node_Call {
132
+ Call : embed_call ,
133
+ },
134
+ }
135
+
136
+ rowIndex += len (f .EmbedFields )
137
+ } else {
138
+ valueNode = subscriptNode (
112
139
rowVar ,
113
- constantInt (i ),
114
- ),
115
- })
140
+ constantInt (rowIndex ),
141
+ )
142
+ rowIndex ++
143
+ }
144
+
145
+ call .Keywords = append (call .Keywords , & pyast.Keyword {Arg : f .Name , Value : valueNode })
116
146
}
117
147
return & pyast.Node {
118
148
Node : & pyast.Node_Call {
@@ -336,6 +366,47 @@ func paramName(p *plugin.Parameter) string {
336
366
type pyColumn struct {
337
367
id int32
338
368
* plugin.Column
369
+ embed * pyEmbed
370
+ }
371
+
372
+ type pyEmbed struct {
373
+ modelType string
374
+ modelName string
375
+ fields []Field
376
+ }
377
+
378
+ // Taken from https://github.com/sqlc-dev/sqlc/blob/8c59fbb9938a0bad3d9971fc2c10ea1f83cc1d0b/internal/codegen/golang/result.go#L123-L126
379
+ // look through all the structs and attempt to find a matching one to embed
380
+ // We need the name of the struct and its field names.
381
+ func newGoEmbed (embed * plugin.Identifier , structs []Struct , defaultSchema string ) * pyEmbed {
382
+ if embed == nil {
383
+ return nil
384
+ }
385
+
386
+ for _ , s := range structs {
387
+ embedSchema := defaultSchema
388
+ if embed .Schema != "" {
389
+ embedSchema = embed .Schema
390
+ }
391
+
392
+ // compare the other attributes
393
+ if embed .Catalog != s .Table .Catalog || embed .Name != s .Table .Name || embedSchema != s .Table .Schema {
394
+ continue
395
+ }
396
+
397
+ fields := make ([]Field , len (s .Fields ))
398
+ for i , f := range s .Fields {
399
+ fields [i ] = f
400
+ }
401
+
402
+ return & pyEmbed {
403
+ modelType : s .Name ,
404
+ modelName : s .Name ,
405
+ fields : fields ,
406
+ }
407
+ }
408
+
409
+ return nil
339
410
}
340
411
341
412
func columnsToStruct (req * plugin.CodeGenRequest , name string , columns []pyColumn ) * Struct {
@@ -359,10 +430,22 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []pyColumn
359
430
if suffix > 0 {
360
431
fieldName = fmt .Sprintf ("%s_%d" , fieldName , suffix )
361
432
}
362
- gs .Fields = append (gs .Fields , Field {
433
+
434
+ f := Field {
363
435
Name : fieldName ,
364
436
Type : makePyType (req , c .Column ),
365
- })
437
+ }
438
+
439
+ if c .embed != nil {
440
+ f .Type = pyType {
441
+ InnerType : "models." + modelName (c .embed .modelType , req .Settings ),
442
+ IsArray : false ,
443
+ IsNull : false ,
444
+ }
445
+ f .EmbedFields = c .embed .fields
446
+ }
447
+
448
+ gs .Fields = append (gs .Fields , f )
366
449
seen [colName ]++
367
450
}
368
451
return & gs
@@ -476,6 +559,7 @@ func buildQueries(conf Config, req *plugin.CodeGenRequest, structs []Struct) ([]
476
559
columns = append (columns , pyColumn {
477
560
id : int32 (i ),
478
561
Column : c ,
562
+ embed : newGoEmbed (c .EmbedTable , structs , req .Catalog .DefaultSchema ),
479
563
})
480
564
}
481
565
gs = columnsToStruct (req , query .Name + "Row" , columns )
0 commit comments