Skip to content

Commit a3a37aa

Browse files
committed
first pass at sqlc.nembed
1 parent 2b88ce8 commit a3a37aa

File tree

4 files changed

+76
-9
lines changed

4 files changed

+76
-9
lines changed

internal/compiler/parse.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/sqlc-dev/sqlc/internal/source"
1414
"github.com/sqlc-dev/sqlc/internal/sql/ast"
1515
"github.com/sqlc-dev/sqlc/internal/sql/astutils"
16+
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
1617
"github.com/sqlc-dev/sqlc/internal/sql/rewrite"
1718
"github.com/sqlc-dev/sqlc/internal/sql/validate"
1819
)
@@ -99,6 +100,11 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
99100
if err != nil {
100101
return nil, err
101102
}
103+
104+
if err := buildNullTables(embeds, c.catalog); err != nil {
105+
return nil, err
106+
}
107+
102108
cols, err := c.outputColumns(qc, raw.Stmt)
103109
if err != nil {
104110
return nil, err
@@ -181,3 +187,50 @@ func uniqueParamRefs(in []paramRef, dollar bool) []paramRef {
181187
}
182188
return o
183189
}
190+
191+
// buildNullTables adds additional tables to the catalog for nullable embeds
192+
func buildNullTables(embeds rewrite.EmbedSet, c *catalog.Catalog) error {
193+
for _, emb := range embeds {
194+
if !emb.Nullable {
195+
continue
196+
}
197+
198+
schema, table, err := c.GetSchemaTable(emb.Table)
199+
if err != nil {
200+
return err
201+
}
202+
203+
emb.Table = &ast.TableName{
204+
Catalog: table.Rel.Catalog,
205+
Schema: table.Rel.Schema,
206+
Name: "null_" + table.Rel.Name,
207+
}
208+
209+
// skip if null table already exists
210+
if table, _ := c.GetTable(emb.Table); table.Rel != nil {
211+
continue
212+
}
213+
214+
nullTable := &catalog.Table{
215+
Rel: emb.Table,
216+
Columns: []*catalog.Column{},
217+
}
218+
219+
for _, c := range table.Columns {
220+
nullTable.Columns = append(nullTable.Columns, &catalog.Column{
221+
Name: c.Name,
222+
Type: c.Type,
223+
IsNotNull: false,
224+
IsUnsigned: c.IsUnsigned,
225+
IsArray: c.IsArray,
226+
ArrayDims: c.ArrayDims,
227+
Comment: c.Comment,
228+
Length: c.Length,
229+
})
230+
}
231+
232+
schema.Tables = append(schema.Tables, nullTable)
233+
}
234+
235+
return nil
236+
}

internal/sql/catalog/public.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,12 @@ func (c *Catalog) GetTable(rel *ast.TableName) (Table, error) {
131131
return *table, err
132132
}
133133
}
134+
135+
func (c *Catalog) GetSchemaTable(rel *ast.TableName) (*Schema, *Table, error) {
136+
schema, table, err := c.getTable(rel)
137+
if table == nil {
138+
return nil, nil, err
139+
} else {
140+
return schema, table, err
141+
}
142+
}

internal/sql/rewrite/embeds.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@ import (
77
"github.com/sqlc-dev/sqlc/internal/sql/astutils"
88
)
99

10-
// Embed is an instance of `sqlc.embed(param)`
10+
// Embed is an instance of `sqlc.embed(param)` or `sqlc.nembed(param)`
1111
type Embed struct {
12-
Table *ast.TableName
13-
param string
14-
Node *ast.ColumnRef
12+
Table *ast.TableName
13+
param string
14+
Node *ast.ColumnRef
15+
Nullable bool
1516
}
1617

1718
// Orig string to replace
1819
func (e Embed) Orig() string {
20+
if e.Nullable {
21+
return fmt.Sprintf("sqlc.nembed(%s)", e.param)
22+
}
1923
return fmt.Sprintf("sqlc.embed(%s)", e.param)
2024
}
2125

@@ -61,9 +65,10 @@ func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) {
6165
}
6266

6367
embeds = append(embeds, &Embed{
64-
Table: &ast.TableName{Name: param},
65-
param: param,
66-
Node: node,
68+
Table: &ast.TableName{Name: param},
69+
param: param,
70+
Node: node,
71+
Nullable: fun.Func.Name == "nembed",
6772
})
6873

6974
cr.Replace(node)
@@ -86,6 +91,6 @@ func isEmbed(node ast.Node) bool {
8691
return false
8792
}
8893

89-
isValid := call.Func.Schema == "sqlc" && call.Func.Name == "embed"
94+
isValid := call.Func.Schema == "sqlc" && (call.Func.Name == "embed" || call.Func.Name == "nembed")
9095
return isValid
9196
}

internal/sql/validate/func_call.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor {
3434
// Custom validation for sqlc.arg, sqlc.narg and sqlc.slice
3535
// TODO: Replace this once type-checking is implemented
3636
if fn.Schema == "sqlc" {
37-
if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed") {
37+
if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed" || fn.Name == "nembed") {
3838
v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name)
3939
return nil
4040
}

0 commit comments

Comments
 (0)