Skip to content

Commit 8b87e9b

Browse files
committed
feat: add sqlc.embed to allow model re-use
1 parent c893a0d commit 8b87e9b

File tree

21 files changed

+716
-109
lines changed

21 files changed

+716
-109
lines changed

internal/cmd/shim.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column {
250250
Length: int32(l),
251251
IsNamedParam: c.IsNamedParam,
252252
IsFuncCall: c.IsFuncCall,
253+
Embedded: c.Embedded,
254+
EmbedIndex: c.EmbedIndex,
253255
}
254256

255257
if c.Type != nil {

internal/codegen/golang/query.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,17 @@ func (v QueryValue) Scan() string {
134134
out = append(out, "&"+v.Name)
135135
}
136136
} else {
137+
138+
for _, e := range v.Struct.EmbedFields {
139+
for _, f := range e.SubFields {
140+
if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && v.SQLPackage != SQLPackagePGX {
141+
out = append(out, "pq.Array(&"+v.Name+"."+e.Name+"."+f.Name+")")
142+
} else {
143+
out = append(out, "&"+v.Name+"."+e.Name+"."+f.Name)
144+
}
145+
}
146+
}
147+
137148
for _, f := range v.Struct.Fields {
138149
if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && v.SQLPackage != SQLPackagePGX {
139150
out = append(out, "pq.Array(&"+v.Name+"."+f.Name+")")
@@ -146,6 +157,7 @@ func (v QueryValue) Scan() string {
146157
return strings.Join(out, ",")
147158
}
148159
out = append(out, "")
160+
149161
return "\n" + strings.Join(out, ",\n")
150162
}
151163

internal/codegen/golang/result.go

Lines changed: 146 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ type goColumn struct {
100100
*plugin.Column
101101
}
102102

103+
type goEmbed struct {
104+
id int
105+
Model *Struct
106+
}
107+
103108
func columnName(c *plugin.Column, pos int) string {
104109
if c.Name != "" {
105110
return c.Name
@@ -172,7 +177,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
172177
Column: p.Column,
173178
})
174179
}
175-
s, err := columnsToStruct(req, gq.MethodName+"Params", cols, false)
180+
s, err := columnsToStruct(req, gq.MethodName+"Params", nil, cols, false)
176181
if err != nil {
177182
return nil, err
178183
}
@@ -197,48 +202,16 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
197202
SQLPackage: sqlpkg,
198203
}
199204
} else if len(query.Columns) > 1 {
200-
var gs *Struct
201-
var emit bool
202205

203-
for _, s := range structs {
204-
if len(s.Fields) != len(query.Columns) {
205-
continue
206-
}
207-
same := true
208-
for i, f := range s.Fields {
209-
c := query.Columns[i]
210-
sameName := f.Name == StructName(columnName(c, i), req.Settings)
211-
sameType := f.Type == goType(req, c)
212-
sameTable := sdk.SameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema)
213-
if !sameName || !sameType || !sameTable {
214-
same = false
215-
}
216-
}
217-
if same {
218-
gs = &s
219-
break
220-
}
206+
qvs, emit, err := buildReturnStruct(req, structs, gq, query.Columns)
207+
if err != nil {
208+
return nil, err
221209
}
222210

223-
if gs == nil {
224-
var columns []goColumn
225-
for i, c := range query.Columns {
226-
columns = append(columns, goColumn{
227-
id: i,
228-
Column: c,
229-
})
230-
}
231-
var err error
232-
gs, err = columnsToStruct(req, gq.MethodName+"Row", columns, true)
233-
if err != nil {
234-
return nil, err
235-
}
236-
emit = true
237-
}
238211
gq.Ret = QueryValue{
239212
Emit: emit,
240213
Name: "i",
241-
Struct: gs,
214+
Struct: qvs,
242215
SQLPackage: sqlpkg,
243216
EmitPointer: req.Settings.Go.EmitResultStructPointers,
244217
}
@@ -250,19 +223,60 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
250223
return qs, nil
251224
}
252225

226+
func buildReturnStruct(req *plugin.CodeGenRequest, models []Struct, gq Query, columns []*plugin.Column) (*Struct, bool, error) {
227+
// group columns by embeds or top level fields
228+
rootCols, embedGroups := groupColumns(columns)
229+
230+
// no embeds, just top level fields
231+
if len(embedGroups) == 0 {
232+
// return early if model already exists
233+
if model := lookupModelForColumns(req, models, rootCols); model != nil {
234+
return model, false, nil
235+
}
236+
}
237+
238+
// reaching here means we need a to construct a new result struct, which
239+
// may or may not contain embeds
240+
var goCols []goColumn
241+
for i, c := range rootCols {
242+
goCols = append(goCols, goColumn{
243+
id: i,
244+
Column: c,
245+
})
246+
}
247+
248+
var goEmbeds []goEmbed
249+
for i, eg := range embedGroups {
250+
if model := lookupModelForColumns(req, models, eg.columns); model != nil {
251+
goEmbeds = append(goEmbeds, goEmbed{
252+
id: i,
253+
Model: model,
254+
})
255+
}
256+
}
257+
258+
gs, err := columnsToStruct(req, gq.MethodName+"Row", goEmbeds, goCols, true)
259+
if err != nil {
260+
return nil, false, err
261+
}
262+
263+
return gs, true, nil
264+
}
265+
253266
// It's possible that this method will generate duplicate JSON tag values
254267
//
255268
// Columns: count, count, count_2
256269
// Fields: Count, Count_2, Count2
257270
// JSON tags: count, count_2, count_2
258271
//
259272
// This is unlikely to happen, so don't fix it yet
260-
func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn, useID bool) (*Struct, error) {
273+
func columnsToStruct(req *plugin.CodeGenRequest, name string, embeds []goEmbed, columns []goColumn, useID bool) (*Struct, error) {
261274
gs := Struct{
262275
Name: name,
263276
}
264277
seen := map[string][]int{}
265278
suffixes := map[int]int{}
279+
266280
for i, c := range columns {
267281
colName := columnName(c.Column, i)
268282
tagName := colName
@@ -324,6 +338,42 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
324338
return nil, err
325339
}
326340

341+
seenEmbeds := map[string]int{}
342+
343+
for _, em := range embeds {
344+
345+
fieldName := em.Model.Name
346+
tagName := fieldName
347+
348+
seen, ok := seenEmbeds[fieldName]
349+
if !ok {
350+
seenEmbeds[fieldName] = 0
351+
}
352+
seen++
353+
seenEmbeds[fieldName] = seen
354+
355+
suffix := seen - 1
356+
if suffix > 0 {
357+
tagName = fmt.Sprintf("%s_%d", tagName, suffix)
358+
fieldName = fmt.Sprintf("%s_%d", fieldName, suffix)
359+
}
360+
361+
tags := map[string]string{}
362+
363+
if req.Settings.Go.EmitJsonTags {
364+
tags["json:"] = JSONTagName(tagName, req.Settings)
365+
}
366+
367+
gs.EmbedFields = append(gs.EmbedFields, EmbedField{
368+
Field: Field{
369+
Name: fieldName,
370+
Type: em.Model.Name,
371+
Tags: tags,
372+
},
373+
SubFields: em.Model.Fields,
374+
})
375+
}
376+
327377
return &gs, nil
328378
}
329379

@@ -338,3 +388,61 @@ func checkIncompatibleFieldTypes(fields []Field) error {
338388
}
339389
return nil
340390
}
391+
392+
type embedGroup struct {
393+
index int32
394+
columns []*plugin.Column
395+
}
396+
397+
func groupColumns(cols []*plugin.Column) ([]*plugin.Column, []*embedGroup) {
398+
root := []*plugin.Column{}
399+
embeds := []*embedGroup{}
400+
401+
for _, c := range cols {
402+
403+
if !c.Embedded {
404+
root = append(root, c)
405+
continue
406+
}
407+
408+
var found bool
409+
for _, v := range embeds {
410+
if v.index == c.EmbedIndex {
411+
found = true
412+
v.columns = append(v.columns, c)
413+
}
414+
}
415+
if !found {
416+
embeds = append(embeds, &embedGroup{
417+
index: c.EmbedIndex,
418+
columns: []*plugin.Column{c},
419+
})
420+
}
421+
}
422+
423+
return root, embeds
424+
}
425+
426+
func lookupModelForColumns(req *plugin.CodeGenRequest, models []Struct, columns []*plugin.Column) *Struct {
427+
for _, model := range models {
428+
if len(model.Fields) != len(columns) {
429+
continue
430+
}
431+
432+
same := true
433+
for i, f := range model.Fields {
434+
c := columns[i]
435+
sameName := f.Name == StructName(columnName(c, i), req.Settings)
436+
sameType := f.Type == goType(req, c)
437+
sameTable := sdk.SameTableName(c.Table, &model.Table, req.Catalog.DefaultSchema)
438+
if !sameName || !sameType || !sameTable {
439+
same = false
440+
}
441+
}
442+
if same {
443+
return &model
444+
}
445+
}
446+
447+
return nil
448+
}

internal/codegen/golang/struct.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@ import (
99
)
1010

1111
type Struct struct {
12-
Table plugin.Identifier
13-
Name string
14-
Fields []Field
15-
Comment string
12+
Table plugin.Identifier
13+
Name string
14+
EmbedFields []EmbedField
15+
Fields []Field
16+
Comment string
17+
}
18+
19+
type EmbedField struct {
20+
Field
21+
SubFields []Field
1622
}
1723

1824
func StructName(name string, settings *plugin.Settings) string {

internal/codegen/golang/templates/pgx/queryCode.tmpl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}}
1616
{{end}}
1717

1818
{{if .Ret.EmitStruct}}
19-
type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}}
19+
type {{.Ret.Type}} struct { {{- range .Ret.Struct.EmbedFields}}
20+
{{.Name}} {{.Type}} {{if $.EmitJSONTags}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
21+
{{- end}}
22+
{{- range .Ret.Struct.Fields}}
2023
{{.Name}} {{.Type}} {{if or ($.EmitJSONTags) ($.EmitDBTags)}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
2124
{{- end}}
2225
}

internal/compiler/expand.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,16 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node)
126126
for _, p := range parts {
127127
old = append(old, c.quoteIdent(p))
128128
}
129+
oldString := strings.Join(old, ".")
130+
131+
// use the sqlc.embed string instead
132+
if embed, ok := qc.embeds.Find(ref); ok {
133+
oldString = embed.Orig()
134+
}
135+
129136
edits = append(edits, source.Edit{
130137
Location: res.Location - raw.StmtLocation,
131-
Old: strings.Join(old, "."),
138+
Old: oldString,
132139
New: strings.Join(cols, ", "),
133140
})
134141
}

internal/compiler/output_columns.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414

1515
// OutputColumns determines which columns a statement will output
1616
func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
17-
qc, err := buildQueryCatalog(c.catalog, stmt)
17+
qc, err := buildQueryCatalog(c.catalog, stmt, nil)
1818
if err != nil {
1919
return nil, err
2020
}
@@ -172,6 +172,9 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
172172

173173
case *ast.ColumnRef:
174174
if hasStarRef(n) {
175+
176+
embed, isEmbedded := qc.embeds.Find(n)
177+
175178
// TODO: This code is copied in func expand()
176179
for _, t := range tables {
177180
scope := astutils.Join(n.Fields, ".")
@@ -183,7 +186,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
183186
if res.Name != nil {
184187
cname = *res.Name
185188
}
186-
cols = append(cols, &Column{
189+
col := &Column{
187190
Name: cname,
188191
Type: c.Type,
189192
Scope: scope,
@@ -193,7 +196,14 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
193196
NotNull: c.NotNull,
194197
IsArray: c.IsArray,
195198
Length: c.Length,
196-
})
199+
Embedded: isEmbedded,
200+
}
201+
202+
if isEmbedded {
203+
col.EmbedIndex = embed.Index
204+
}
205+
206+
cols = append(cols, col)
197207
}
198208
}
199209
continue

internal/compiler/parse.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
103103
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
104104
}
105105
}
106-
qc, err := buildQueryCatalog(c.catalog, raw.Stmt)
106+
107+
raw, embeds := rewrite.Embeds(c.conf.Engine, raw)
108+
109+
qc, err := buildQueryCatalog(c.catalog, raw.Stmt, embeds)
107110
if err != nil {
108111
return nil, err
109112
}

internal/compiler/query.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ type Column struct {
3030
TableAlias string
3131
Type *ast.TypeName
3232

33+
Embedded bool
34+
EmbedIndex int32
35+
3336
skipTableRequiredCheck bool
3437
}
3538

0 commit comments

Comments
 (0)