Skip to content

Commit f239082

Browse files
committed
feat: support sqlc.embed
1 parent bfa71a9 commit f239082

File tree

1 file changed

+93
-9
lines changed

1 file changed

+93
-9
lines changed

internal/gen.go

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ type Field struct {
5353
Name string
5454
Type pyType
5555
Comment string
56+
// EmbedFields contains the embedded fields that require scanning.
57+
EmbedFields []Field
5658
}
5759

5860
type Struct struct {
@@ -105,14 +107,42 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node {
105107
call := &pyast.Call{
106108
Func: v.Annotation(),
107109
}
108-
for i, f := range v.Struct.Fields {
109-
call.Keywords = append(call.Keywords, &pyast.Keyword{
110-
Arg: f.Name,
111-
Value: subscriptNode(
110+
rowIndex := 0 // We need to keep track of the index in the row variable.
111+
for _, f := range v.Struct.Fields {
112+
113+
var valueNode *pyast.Node
114+
// Check if we are using sqlc.embed, if so we need to create a new object.
115+
if len(f.EmbedFields) > 0 {
116+
// We keep this separate so we can easily add all arguments.
117+
embed_call := &pyast.Call{Func: f.Type.Annotation()}
118+
119+
// Now add all field Initializers for the embedded model that index into the original row.
120+
for i, embedField := range f.EmbedFields {
121+
embed_call.Keywords = append(embed_call.Keywords, &pyast.Keyword{
122+
Arg: embedField.Name,
123+
Value: subscriptNode(
124+
rowVar,
125+
constantInt(rowIndex+i),
126+
),
127+
})
128+
}
129+
130+
valueNode = &pyast.Node{
131+
Node: &pyast.Node_Call{
132+
Call: embed_call,
133+
},
134+
}
135+
136+
rowIndex += len(f.EmbedFields)
137+
} else {
138+
valueNode = subscriptNode(
112139
rowVar,
113-
constantInt(i),
114-
),
115-
})
140+
constantInt(rowIndex),
141+
)
142+
rowIndex++
143+
}
144+
145+
call.Keywords = append(call.Keywords, &pyast.Keyword{Arg: f.Name, Value: valueNode})
116146
}
117147
return &pyast.Node{
118148
Node: &pyast.Node_Call{
@@ -336,6 +366,47 @@ func paramName(p *plugin.Parameter) string {
336366
type pyColumn struct {
337367
id int32
338368
*plugin.Column
369+
embed *pyEmbed
370+
}
371+
372+
type pyEmbed struct {
373+
modelType string
374+
modelName string
375+
fields []Field
376+
}
377+
378+
// Taken from https://github.com/sqlc-dev/sqlc/blob/8c59fbb9938a0bad3d9971fc2c10ea1f83cc1d0b/internal/codegen/golang/result.go#L123-L126
379+
// look through all the structs and attempt to find a matching one to embed
380+
// We need the name of the struct and its field names.
381+
func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string) *pyEmbed {
382+
if embed == nil {
383+
return nil
384+
}
385+
386+
for _, s := range structs {
387+
embedSchema := defaultSchema
388+
if embed.Schema != "" {
389+
embedSchema = embed.Schema
390+
}
391+
392+
// compare the other attributes
393+
if embed.Catalog != s.Table.Catalog || embed.Name != s.Table.Name || embedSchema != s.Table.Schema {
394+
continue
395+
}
396+
397+
fields := make([]Field, len(s.Fields))
398+
for i, f := range s.Fields {
399+
fields[i] = f
400+
}
401+
402+
return &pyEmbed{
403+
modelType: s.Name,
404+
modelName: s.Name,
405+
fields: fields,
406+
}
407+
}
408+
409+
return nil
339410
}
340411

341412
func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []pyColumn) *Struct {
@@ -359,10 +430,22 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []pyColumn
359430
if suffix > 0 {
360431
fieldName = fmt.Sprintf("%s_%d", fieldName, suffix)
361432
}
362-
gs.Fields = append(gs.Fields, Field{
433+
434+
f := Field{
363435
Name: fieldName,
364436
Type: makePyType(req, c.Column),
365-
})
437+
}
438+
439+
if c.embed != nil {
440+
f.Type = pyType{
441+
InnerType: "models." + modelName(c.embed.modelType, req.Settings),
442+
IsArray: false,
443+
IsNull: false,
444+
}
445+
f.EmbedFields = c.embed.fields
446+
}
447+
448+
gs.Fields = append(gs.Fields, f)
366449
seen[colName]++
367450
}
368451
return &gs
@@ -476,6 +559,7 @@ func buildQueries(conf Config, req *plugin.CodeGenRequest, structs []Struct) ([]
476559
columns = append(columns, pyColumn{
477560
id: int32(i),
478561
Column: c,
562+
embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema),
479563
})
480564
}
481565
gs = columnsToStruct(req, query.Name+"Row", columns)

0 commit comments

Comments
 (0)