Skip to content

Commit d1b8782

Browse files
committed
feat: add sqlc.embed to allow model re-use
1 parent f5c1a5e commit d1b8782

File tree

21 files changed

+720
-113
lines changed

21 files changed

+720
-113
lines changed

internal/cmd/shim.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column {
241241
Length: int32(l),
242242
IsNamedParam: c.IsNamedParam,
243243
IsFuncCall: c.IsFuncCall,
244+
Embedded: c.Embedded,
245+
EmbedIndex: c.EmbedIndex,
244246
}
245247

246248
if c.Type != nil {

internal/codegen/golang/query.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,17 @@ func (v QueryValue) Scan() string {
134134
out = append(out, "&"+v.Name)
135135
}
136136
} else {
137+
138+
for _, e := range v.Struct.EmbedFields {
139+
for _, f := range e.SubFields {
140+
if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && v.SQLPackage != SQLPackagePGX {
141+
out = append(out, "pq.Array(&"+v.Name+"."+e.Name+"."+f.Name+")")
142+
} else {
143+
out = append(out, "&"+v.Name+"."+e.Name+"."+f.Name)
144+
}
145+
}
146+
}
147+
137148
for _, f := range v.Struct.Fields {
138149
if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
139150
out = append(out, "pq.Array(&"+v.Name+"."+f.Name+")")
@@ -146,6 +157,7 @@ func (v QueryValue) Scan() string {
146157
return strings.Join(out, ",")
147158
}
148159
out = append(out, "")
160+
149161
return "\n" + strings.Join(out, ",\n")
150162
}
151163

internal/codegen/golang/result.go

Lines changed: 146 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ type goColumn struct {
105105
*plugin.Column
106106
}
107107

108+
type goEmbed struct {
109+
id int
110+
Model *Struct
111+
}
112+
108113
func columnName(c *plugin.Column, pos int) string {
109114
if c.Name != "" {
110115
return c.Name
@@ -177,7 +182,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
177182
Column: p.Column,
178183
})
179184
}
180-
s, err := columnsToStruct(req, gq.MethodName+"Params", cols, false)
185+
s, err := columnsToStruct(req, gq.MethodName+"Params", nil, cols, false)
181186
if err != nil {
182187
return nil, err
183188
}
@@ -202,48 +207,16 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
202207
SQLDriver: sqlpkg,
203208
}
204209
} else if putOutColumns(query) {
205-
var gs *Struct
206-
var emit bool
207210

208-
for _, s := range structs {
209-
if len(s.Fields) != len(query.Columns) {
210-
continue
211-
}
212-
same := true
213-
for i, f := range s.Fields {
214-
c := query.Columns[i]
215-
sameName := f.Name == StructName(columnName(c, i), req.Settings)
216-
sameType := f.Type == goType(req, c)
217-
sameTable := sdk.SameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema)
218-
if !sameName || !sameType || !sameTable {
219-
same = false
220-
}
221-
}
222-
if same {
223-
gs = &s
224-
break
225-
}
211+
qvs, emit, err := buildReturnStruct(req, structs, gq, query.Columns)
212+
if err != nil {
213+
return nil, err
226214
}
227215

228-
if gs == nil {
229-
var columns []goColumn
230-
for i, c := range query.Columns {
231-
columns = append(columns, goColumn{
232-
id: i,
233-
Column: c,
234-
})
235-
}
236-
var err error
237-
gs, err = columnsToStruct(req, gq.MethodName+"Row", columns, true)
238-
if err != nil {
239-
return nil, err
240-
}
241-
emit = true
242-
}
243216
gq.Ret = QueryValue{
244217
Emit: emit,
245218
Name: "i",
246-
Struct: gs,
219+
Struct: qvs,
247220
SQLDriver: sqlpkg,
248221
EmitPointer: req.Settings.Go.EmitResultStructPointers,
249222
}
@@ -267,6 +240,46 @@ func putOutColumns(query *plugin.Query) bool {
267240
return false
268241
}
269242

243+
func buildReturnStruct(req *plugin.CodeGenRequest, models []Struct, gq Query, columns []*plugin.Column) (*Struct, bool, error) {
244+
// group columns by embeds or top level fields
245+
rootCols, embedGroups := groupColumns(columns)
246+
247+
// no embeds, just top level fields
248+
if len(embedGroups) == 0 {
249+
// return early if model already exists
250+
if model := lookupModelForColumns(req, models, rootCols); model != nil {
251+
return model, false, nil
252+
}
253+
}
254+
255+
// reaching here means we need a to construct a new result struct, which
256+
// may or may not contain embeds
257+
var goCols []goColumn
258+
for i, c := range rootCols {
259+
goCols = append(goCols, goColumn{
260+
id: i,
261+
Column: c,
262+
})
263+
}
264+
265+
var goEmbeds []goEmbed
266+
for i, eg := range embedGroups {
267+
if model := lookupModelForColumns(req, models, eg.columns); model != nil {
268+
goEmbeds = append(goEmbeds, goEmbed{
269+
id: i,
270+
Model: model,
271+
})
272+
}
273+
}
274+
275+
gs, err := columnsToStruct(req, gq.MethodName+"Row", goEmbeds, goCols, true)
276+
if err != nil {
277+
return nil, false, err
278+
}
279+
280+
return gs, true, nil
281+
}
282+
270283
// It's possible that this method will generate duplicate JSON tag values
271284
//
272285
// Columns: count, count, count_2
@@ -275,12 +288,13 @@ func putOutColumns(query *plugin.Query) bool {
275288
// JSON tags: count, count_2, count_2
276289
//
277290
// This is unlikely to happen, so don't fix it yet
278-
func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn, useID bool) (*Struct, error) {
291+
func columnsToStruct(req *plugin.CodeGenRequest, name string, embeds []goEmbed, columns []goColumn, useID bool) (*Struct, error) {
279292
gs := Struct{
280293
Name: name,
281294
}
282295
seen := map[string][]int{}
283296
suffixes := map[int]int{}
297+
284298
for i, c := range columns {
285299
colName := columnName(c.Column, i)
286300
tagName := colName
@@ -342,6 +356,42 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
342356
return nil, err
343357
}
344358

359+
seenEmbeds := map[string]int{}
360+
361+
for _, em := range embeds {
362+
363+
fieldName := em.Model.Name
364+
tagName := fieldName
365+
366+
seen, ok := seenEmbeds[fieldName]
367+
if !ok {
368+
seenEmbeds[fieldName] = 0
369+
}
370+
seen++
371+
seenEmbeds[fieldName] = seen
372+
373+
suffix := seen - 1
374+
if suffix > 0 {
375+
tagName = fmt.Sprintf("%s_%d", tagName, suffix)
376+
fieldName = fmt.Sprintf("%s_%d", fieldName, suffix)
377+
}
378+
379+
tags := map[string]string{}
380+
381+
if req.Settings.Go.EmitJsonTags {
382+
tags["json:"] = JSONTagName(tagName, req.Settings)
383+
}
384+
385+
gs.EmbedFields = append(gs.EmbedFields, EmbedField{
386+
Field: Field{
387+
Name: fieldName,
388+
Type: em.Model.Name,
389+
Tags: tags,
390+
},
391+
SubFields: em.Model.Fields,
392+
})
393+
}
394+
345395
return &gs, nil
346396
}
347397

@@ -356,3 +406,61 @@ func checkIncompatibleFieldTypes(fields []Field) error {
356406
}
357407
return nil
358408
}
409+
410+
type embedGroup struct {
411+
index int32
412+
columns []*plugin.Column
413+
}
414+
415+
func groupColumns(cols []*plugin.Column) ([]*plugin.Column, []*embedGroup) {
416+
root := []*plugin.Column{}
417+
embeds := []*embedGroup{}
418+
419+
for _, c := range cols {
420+
421+
if !c.Embedded {
422+
root = append(root, c)
423+
continue
424+
}
425+
426+
var found bool
427+
for _, v := range embeds {
428+
if v.index == c.EmbedIndex {
429+
found = true
430+
v.columns = append(v.columns, c)
431+
}
432+
}
433+
if !found {
434+
embeds = append(embeds, &embedGroup{
435+
index: c.EmbedIndex,
436+
columns: []*plugin.Column{c},
437+
})
438+
}
439+
}
440+
441+
return root, embeds
442+
}
443+
444+
func lookupModelForColumns(req *plugin.CodeGenRequest, models []Struct, columns []*plugin.Column) *Struct {
445+
for _, model := range models {
446+
if len(model.Fields) != len(columns) {
447+
continue
448+
}
449+
450+
same := true
451+
for i, f := range model.Fields {
452+
c := columns[i]
453+
sameName := f.Name == StructName(columnName(c, i), req.Settings)
454+
sameType := f.Type == goType(req, c)
455+
sameTable := sdk.SameTableName(c.Table, &model.Table, req.Catalog.DefaultSchema)
456+
if !sameName || !sameType || !sameTable {
457+
same = false
458+
}
459+
}
460+
if same {
461+
return &model
462+
}
463+
}
464+
465+
return nil
466+
}

internal/codegen/golang/struct.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@ import (
99
)
1010

1111
type Struct struct {
12-
Table plugin.Identifier
13-
Name string
14-
Fields []Field
15-
Comment string
12+
Table plugin.Identifier
13+
Name string
14+
EmbedFields []EmbedField
15+
Fields []Field
16+
Comment string
17+
}
18+
19+
type EmbedField struct {
20+
Field
21+
SubFields []Field
1622
}
1723

1824
func StructName(name string, settings *plugin.Settings) string {

internal/codegen/golang/templates/pgx/queryCode.tmpl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}}
1616
{{end}}
1717

1818
{{if .Ret.EmitStruct}}
19-
type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}}
19+
type {{.Ret.Type}} struct { {{- range .Ret.Struct.EmbedFields}}
20+
{{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
21+
{{- end}}
22+
{{- range .Ret.Struct.Fields}}
2023
{{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
2124
{{- end}}
2225
}

internal/compiler/expand.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,16 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node)
132132
for _, p := range parts {
133133
old = append(old, c.quoteIdent(p))
134134
}
135+
oldString := strings.Join(old, ".")
136+
137+
// use the sqlc.embed string instead
138+
if embed, ok := qc.embeds.Find(ref); ok {
139+
oldString = embed.Orig()
140+
}
141+
135142
edits = append(edits, source.Edit{
136143
Location: res.Location - raw.StmtLocation,
137-
Old: strings.Join(old, "."),
144+
Old: oldString,
138145
New: strings.Join(cols, ", "),
139146
})
140147
}

internal/compiler/output_columns.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414

1515
// OutputColumns determines which columns a statement will output
1616
func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
17-
qc, err := buildQueryCatalog(c.catalog, stmt)
17+
qc, err := buildQueryCatalog(c.catalog, stmt, nil)
1818
if err != nil {
1919
return nil, err
2020
}
@@ -178,6 +178,9 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
178178

179179
case *ast.ColumnRef:
180180
if hasStarRef(n) {
181+
182+
embed, isEmbedded := qc.embeds.Find(n)
183+
181184
// TODO: This code is copied in func expand()
182185
for _, t := range tables {
183186
scope := astutils.Join(n.Fields, ".")
@@ -189,7 +192,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
189192
if res.Name != nil {
190193
cname = *res.Name
191194
}
192-
cols = append(cols, &Column{
195+
col := &Column{
193196
Name: cname,
194197
Type: c.Type,
195198
Scope: scope,
@@ -199,7 +202,14 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
199202
NotNull: c.NotNull,
200203
IsArray: c.IsArray,
201204
Length: c.Length,
202-
})
205+
Embedded: isEmbedded,
206+
}
207+
208+
if isEmbedded {
209+
col.EmbedIndex = embed.Index
210+
}
211+
212+
cols = append(cols, col)
203213
}
204214
}
205215
continue

internal/compiler/parse.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
8383
} else {
8484
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
8585
}
86-
qc, err := buildQueryCatalog(c.catalog, raw.Stmt)
86+
87+
raw, embeds := rewrite.Embeds(c.conf.Engine, raw)
88+
89+
qc, err := buildQueryCatalog(c.catalog, raw.Stmt, embeds)
8790
if err != nil {
8891
return nil, err
8992
}

internal/compiler/query.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ type Column struct {
3030
TableAlias string
3131
Type *ast.TypeName
3232

33+
Embedded bool
34+
EmbedIndex int32
35+
3336
skipTableRequiredCheck bool
3437
}
3538

0 commit comments

Comments
 (0)