Skip to content

Commit c7c6a36

Browse files
feat: add sqlc.embed to allow model re-use (#1615)
* add sqlc.embed * only allow alias or table name in sqlc.embed() * add tests * regenerate other tests * Fix codegen.json tests --------- Co-authored-by: Kyle Conroy <kyle@conroy.org>
1 parent a8477b8 commit c7c6a36

File tree

34 files changed

+69117
-4933
lines changed

34 files changed

+69117
-4933
lines changed

internal/cmd/shim.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,14 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column {
264264
}
265265
}
266266

267+
if c.EmbedTable != nil {
268+
out.EmbedTable = &plugin.Identifier{
269+
Catalog: c.EmbedTable.Catalog,
270+
Schema: c.EmbedTable.Schema,
271+
Name: c.EmbedTable.Name,
272+
}
273+
}
274+
267275
return out
268276
}
269277

internal/codegen/golang/field.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ type Field struct {
1515
Tags map[string]string
1616
Comment string
1717
Column *plugin.Column
18+
// EmbedFields contains the embedded fields that reuqire scanning.
19+
EmbedFields []string
1820
}
1921

2022
func (gf Field) Tag() string {

internal/codegen/golang/query.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,15 @@ func (v QueryValue) Scan() string {
154154
}
155155
} else {
156156
for _, f := range v.Struct.Fields {
157+
158+
// append any embedded fields
159+
if len(f.EmbedFields) > 0 {
160+
for _, embed := range f.EmbedFields {
161+
out = append(out, "&"+v.Name+"."+f.Name+"."+embed)
162+
}
163+
continue
164+
}
165+
157166
if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
158167
out = append(out, "pq.Array(&"+v.Name+"."+f.Name+")")
159168
} else {

internal/codegen/golang/result.go

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,46 @@ func buildStructs(req *plugin.CodeGenRequest) []Struct {
103103
type goColumn struct {
104104
id int
105105
*plugin.Column
106+
embed *goEmbed
107+
}
108+
109+
type goEmbed struct {
110+
modelType string
111+
modelName string
112+
fields []string
113+
}
114+
115+
// look through all the structs and attempt to find a matching one to embed
116+
// We need the name of the struct and its field names.
117+
func newGoEmbed(embed *plugin.Identifier, structs []Struct) *goEmbed {
118+
if embed == nil {
119+
return nil
120+
}
121+
122+
for _, s := range structs {
123+
embedSchema := "public"
124+
if embed.Schema != "" {
125+
embedSchema = embed.Schema
126+
}
127+
128+
// compare the other attributes
129+
if embed.Catalog != s.Table.Catalog || embed.Name != s.Table.Name || embedSchema != s.Table.Schema {
130+
continue
131+
}
132+
133+
fields := make([]string, len(s.Fields))
134+
for i, f := range s.Fields {
135+
fields[i] = f.Name
136+
}
137+
138+
return &goEmbed{
139+
modelType: s.Name,
140+
modelName: s.Name,
141+
fields: fields,
142+
}
143+
}
144+
145+
return nil
106146
}
107147

108148
func columnName(c *plugin.Column, pos int) string {
@@ -192,7 +232,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
192232
}
193233
}
194234

195-
if len(query.Columns) == 1 {
235+
if len(query.Columns) == 1 && query.Columns[0].EmbedTable == nil {
196236
c := query.Columns[0]
197237
name := columnName(c, 0)
198238
if c.IsFuncCall {
@@ -234,6 +274,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
234274
columns = append(columns, goColumn{
235275
id: i,
236276
Column: c,
277+
embed: newGoEmbed(c.EmbedTable, structs),
237278
})
238279
}
239280
var err error
@@ -287,6 +328,13 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
287328
for i, c := range columns {
288329
colName := columnName(c.Column, i)
289330
tagName := colName
331+
332+
// overide col/tag with expected model name
333+
if c.embed != nil {
334+
colName = c.embed.modelName
335+
tagName = SetCaseStyle(colName, "snake")
336+
}
337+
290338
fieldName := StructName(colName, req.Settings)
291339
baseFieldName := fieldName
292340
// Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be
@@ -309,13 +357,20 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
309357
if req.Settings.Go.EmitJsonTags {
310358
tags["json"] = JSONTagName(tagName, req.Settings)
311359
}
312-
gs.Fields = append(gs.Fields, Field{
360+
f := Field{
313361
Name: fieldName,
314362
DBName: colName,
315-
Type: goType(req, c.Column),
316363
Tags: tags,
317364
Column: c.Column,
318-
})
365+
}
366+
if c.embed == nil {
367+
f.Type = goType(req, c.Column)
368+
} else {
369+
f.Type = c.embed.modelType
370+
f.EmbedFields = c.embed.fields
371+
}
372+
373+
gs.Fields = append(gs.Fields, f)
319374
if _, found := seen[baseFieldName]; !found {
320375
seen[baseFieldName] = []int{i}
321376
} else {

internal/compiler/expand.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,16 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node)
132132
for _, p := range parts {
133133
old = append(old, c.quoteIdent(p))
134134
}
135+
oldString := strings.Join(old, ".")
136+
137+
// use the sqlc.embed string instead
138+
if embed, ok := qc.embeds.Find(ref); ok {
139+
oldString = embed.Orig()
140+
}
141+
135142
edits = append(edits, source.Edit{
136143
Location: res.Location - raw.StmtLocation,
137-
Old: strings.Join(old, "."),
144+
Old: oldString,
138145
New: strings.Join(cols, ", "),
139146
})
140147
}

internal/compiler/output_columns.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414

1515
// OutputColumns determines which columns a statement will output
1616
func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
17-
qc, err := buildQueryCatalog(c.catalog, stmt)
17+
qc, err := buildQueryCatalog(c.catalog, stmt, nil)
1818
if err != nil {
1919
return nil, err
2020
}
@@ -201,6 +201,16 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
201201

202202
case *ast.ColumnRef:
203203
if hasStarRef(n) {
204+
205+
// add a column with a reference to an embedded table
206+
if embed, ok := qc.embeds.Find(n); ok {
207+
cols = append(cols, &Column{
208+
Name: embed.Table.Name,
209+
EmbedTable: embed.Table,
210+
})
211+
continue
212+
}
213+
204214
// TODO: This code is copied in func expand()
205215
for _, t := range tables {
206216
scope := astutils.Join(n.Fields, ".")
@@ -520,6 +530,7 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
520530
NotNull: c.NotNull,
521531
IsArray: c.IsArray,
522532
Length: c.Length,
533+
EmbedTable: c.EmbedTable,
523534
})
524535
}
525536
}

internal/compiler/parse.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,14 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
8686
} else {
8787
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
8888
}
89-
qc, err := buildQueryCatalog(c.catalog, raw.Stmt)
89+
90+
raw, embeds := rewrite.Embeds(raw)
91+
qc, err := buildQueryCatalog(c.catalog, raw.Stmt, embeds)
9092
if err != nil {
9193
return nil, err
9294
}
9395

94-
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams)
96+
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds)
9597
if err != nil {
9698
return nil, err
9799
}

internal/compiler/query.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type Column struct {
2929
Table *ast.TableName
3030
TableAlias string
3131
Type *ast.TypeName
32+
EmbedTable *ast.TableName
3233

3334
IsSqlcSlice bool // is this sqlc.slice()
3435

internal/compiler/query_catalog.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@ import (
55

66
"github.com/kyleconroy/sqlc/internal/sql/ast"
77
"github.com/kyleconroy/sqlc/internal/sql/catalog"
8+
"github.com/kyleconroy/sqlc/internal/sql/rewrite"
89
)
910

1011
type QueryCatalog struct {
1112
catalog *catalog.Catalog
1213
ctes map[string]*Table
14+
embeds rewrite.EmbedSet
1315
}
1416

15-
func buildQueryCatalog(c *catalog.Catalog, node ast.Node) (*QueryCatalog, error) {
17+
func buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSet) (*QueryCatalog, error) {
1618
var with *ast.WithClause
1719
switch n := node.(type) {
1820
case *ast.DeleteStmt:
@@ -26,7 +28,7 @@ func buildQueryCatalog(c *catalog.Catalog, node ast.Node) (*QueryCatalog, error)
2628
default:
2729
with = nil
2830
}
29-
qc := &QueryCatalog{catalog: c, ctes: map[string]*Table{}}
31+
qc := &QueryCatalog{catalog: c, ctes: map[string]*Table{}, embeds: embeds}
3032
if with != nil {
3133
for _, item := range with.Ctes.Items {
3234
if cte, ok := item.(*ast.CommonTableExpr); ok {

internal/compiler/resolve.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/kyleconroy/sqlc/internal/sql/astutils"
99
"github.com/kyleconroy/sqlc/internal/sql/catalog"
1010
"github.com/kyleconroy/sqlc/internal/sql/named"
11+
"github.com/kyleconroy/sqlc/internal/sql/rewrite"
1112
"github.com/kyleconroy/sqlc/internal/sql/sqlerr"
1213
)
1314

@@ -19,7 +20,7 @@ func dataType(n *ast.TypeName) string {
1920
}
2021
}
2122

22-
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet) ([]Parameter, error) {
23+
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
2324
c := comp.catalog
2425

2526
aliasMap := map[string]*ast.TableName{}
@@ -76,6 +77,22 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
7677
}
7778
}
7879

80+
// resolve a table for an embed
81+
for _, embed := range embeds {
82+
table, err := c.GetTable(embed.Table)
83+
if err == nil {
84+
embed.Table = table.Rel
85+
continue
86+
}
87+
88+
if alias, ok := aliasMap[embed.Table.Name]; ok {
89+
embed.Table = alias
90+
continue
91+
}
92+
93+
return nil, fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err)
94+
}
95+
7996
var a []Parameter
8097
for _, ref := range args {
8198
switch n := ref.parent.(type) {

0 commit comments

Comments
 (0)