From 927b5831ddaf1b650dee556240ec560e2940a2d4 Mon Sep 17 00:00:00 2001 From: Rich Churcher Date: Tue, 18 Jul 2023 14:00:24 +1200 Subject: [PATCH 1/2] Add sqlc.nembed --- internal/codegen/golang/query.go | 56 ++++++++++++++++++- internal/codegen/golang/result.go | 3 + .../golang/templates/pgx/queryCode.tmpl | 41 +++++++++++++- .../codegen/golang/templates/template.tmpl | 2 + internal/compiler/output_columns.go | 1 + internal/sql/rewrite/embeds.go | 36 ++++++++---- internal/sql/validate/func_call.go | 7 +-- 7 files changed, 126 insertions(+), 20 deletions(-) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index aeb1c106a2..42b8a15266 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -154,6 +154,54 @@ func (v QueryValue) HasSqlcSlices() bool { return false } +func (v QueryValue) AssignNullableEmbeds() string { + var out []string + if v.Struct != nil { + for _, f := range v.Struct.Fields { + if len(f.EmbedFields) > 0 && !f.Column.NotNull { + out = append(out, v.Name+"."+f.Name+" = &n"+f.Name) + } + } + } + return "\n" + strings.Join(out, "\n") +} + +func (v QueryValue) DeclareNullableEmbeds() string { + var out []string + if v.Struct != nil { + for _, f := range v.Struct.Fields { + if len(f.EmbedFields) > 0 && !f.Column.NotNull { + out = append(out, "var n"+f.Name+" "+f.Type[1:]) + } + } + } + return "\n" + strings.Join(out, "\n") +} + +func (v QueryValue) NullableIndices() [][]int { + var out [][]int + fieldIdx := 0 + if v.Struct != nil { + for _, f := range v.Struct.Fields { + if len(f.EmbedFields) > 0 { + var nullableIndices []int + for range f.EmbedFields { + if !f.Column.NotNull { + nullableIndices = append(nullableIndices, fieldIdx) + } + fieldIdx++ + } + if len(nullableIndices) > 0 { + out = append(out, nullableIndices) + } + } else { + fieldIdx++ + } + } + } + return out +} + func (v QueryValue) Scan() string { var out []string if v.Struct == nil { @@ -167,8 +215,14 @@ func (v QueryValue) Scan() string { // append any embedded fields if len(f.EmbedFields) > 0 { + prefix := "&" + v.Name + "." + // Regular embeds go straight into the return struct, nembed uses an intermediate + // value to check for NULL + if !f.Column.NotNull { + prefix = "&n" + } for _, embed := range f.EmbedFields { - out = append(out, "&"+v.Name+"."+f.Name+"."+embed) + out = append(out, prefix+f.Name+"."+embed) } continue } diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index f5ecd124a1..91e498972b 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -382,6 +382,9 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn f.Type = goType(req, c.Column) } else { f.Type = c.embed.modelType + if !c.NotNull { + f.Type = fmt.Sprintf("*%s", c.embed.modelType) + } f.EmbedFields = c.embed.fields } diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index 1736fa11f7..dca4d66894 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -28,15 +28,50 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { - row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) + rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { - row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) + rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) {{- end}} {{- if ne .Arg.Pair .Ret.Pair }} var {{.Ret.Name}} {{.Ret.Type}} {{- end}} - err := row.Scan({{.Ret.Scan}}) + if err != nil { + return {{.Ret.ReturnName}}, err + } + {{- .Ret.DeclareNullableEmbeds}} + var cols []interface{} + cols = append(cols, {{.Ret.Scan}}) + defer rows.Close() + // This effectively duplicates the behaviour of Row.Scan, which we can't use (because it doesn't + // provide Values). + if !rows.Next() { + if rows.Err() == nil { + return {{.Ret.ReturnName}}, pgx.ErrNoRows + } + return {{.Ret.ReturnName}}, rows.Err() + } + {{if .Ret.NullableIndices -}} + vals, verr := rows.Values() + if verr != nil { + return {{.Ret.ReturnName}}, verr + } + {{- range $nembed := .Ret.NullableIndices}} + if + {{- range $nembed}} + vals[{{.}}] == nil && + {{- end -}} + true { + {{range $nembed}} + cols[{{.}}] = nil + {{- end}} + } + {{- end -}} + {{- end}} + if err := rows.Scan(cols...); err != nil { + return {{.Ret.ReturnName}}, err + } + {{- .Ret.AssignNullableEmbeds}} return {{.Ret.ReturnName}}, err } {{end}} diff --git a/internal/codegen/golang/templates/template.tmpl b/internal/codegen/golang/templates/template.tmpl index 519c693bc4..453d70cbe6 100644 --- a/internal/codegen/golang/templates/template.tmpl +++ b/internal/codegen/golang/templates/template.tmpl @@ -153,6 +153,8 @@ import ( {{range .}}{{.}} {{end}} {{end}} + // Obviously, this is temporary + "github.com/jackc/pgx/v5" ) {{template "queryCode" . }} diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 9b14fb83c2..6fb92f13ee 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -263,6 +263,7 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er cols = append(cols, &Column{ Name: embed.Table.Name, EmbedTable: embed.Table, + NotNull: !embed.Nullable, }) continue } diff --git a/internal/sql/rewrite/embeds.go b/internal/sql/rewrite/embeds.go index 1b132ec920..c9a79de941 100644 --- a/internal/sql/rewrite/embeds.go +++ b/internal/sql/rewrite/embeds.go @@ -7,16 +7,23 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/astutils" ) -// Embed is an instance of `sqlc.embed(param)` +// Embed is an instance of `sqlc.embed(param)` or `sqlc.nembed(param)`. +// The only difference in an embed generated with `nembed` is that `Nullable` +// will always be `true`. type Embed struct { - Table *ast.TableName - param string - Node *ast.ColumnRef + Table *ast.TableName + param string + Node *ast.ColumnRef + Nullable bool } // Orig string to replace func (e Embed) Orig() string { - return fmt.Sprintf("sqlc.embed(%s)", e.param) + fName := "embed" + if e.Nullable { + fName = "nembed" + } + return fmt.Sprintf("sqlc.%s(%s)", fName, e.param) } // EmbedSet is a set of Embed instances @@ -32,9 +39,9 @@ func (es EmbedSet) Find(node *ast.ColumnRef) (*Embed, bool) { return nil, false } -// Embeds rewrites `sqlc.embed(param)` to a `ast.ColumnRef` of form `param.*`. -// The compiler can make use of the returned `EmbedSet` while expanding the -// `param.*` column refs to produce the correct source edits. +// Embeds rewrites `sqlc.embed(param)` or `sqlc.nembed(param)` to an `ast.ColumnRef` +// of form `param.*`. The compiler can make use of the returned `EmbedSet` while +// expanding the `param.*` column refs to produce the correct source edits. func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) { var embeds []*Embed @@ -60,10 +67,15 @@ func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) { }, } + nullable := false + if fun.Func.Name == "nembed" { + nullable = true + } embeds = append(embeds, &Embed{ - Table: &ast.TableName{Name: param}, - param: param, - Node: node, + Table: &ast.TableName{Name: param}, + param: param, + Node: node, + Nullable: nullable, }) cr.Replace(node) @@ -86,6 +98,6 @@ func isEmbed(node ast.Node) bool { return false } - isValid := call.Func.Schema == "sqlc" && call.Func.Name == "embed" + isValid := call.Func.Schema == "sqlc" && (call.Func.Name == "embed" || call.Func.Name == "nembed") return isValid } diff --git a/internal/sql/validate/func_call.go b/internal/sql/validate/func_call.go index bbda232c63..c3f16514d0 100644 --- a/internal/sql/validate/func_call.go +++ b/internal/sql/validate/func_call.go @@ -31,10 +31,10 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { return v } - // Custom validation for sqlc.arg, sqlc.narg and sqlc.slice + // Custom validation for `sqlc.` functions. // TODO: Replace this once type-checking is implemented if fn.Schema == "sqlc" { - if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed") { + if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed" || fn.Name == "nembed") { v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name) return nil } @@ -57,8 +57,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { return nil } - // If we have sqlc.arg or sqlc.narg, there is no need to resolve the function call. - // It won't resolve anyway, sinc it is not a real function. + // Don't attempt to resolve `sqlc.` functions. return nil } From 51a8679dd400048970de925e9becda771fa9e269 Mon Sep 17 00:00:00 2001 From: Rich Churcher Date: Wed, 19 Jul 2023 13:35:42 +1200 Subject: [PATCH 2/2] Marginally saner approach --- internal/codegen/golang/query.go | 9 +++++---- .../golang/templates/pgx/queryCode.tmpl | 20 ++++++------------- .../codegen/golang/templates/template.tmpl | 20 +++++++++++++++++++ 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 42b8a15266..a581cfa3fc 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -2,6 +2,7 @@ package golang import ( "fmt" + "strconv" "strings" "github.com/kyleconroy/sqlc/internal/metadata" @@ -178,16 +179,16 @@ func (v QueryValue) DeclareNullableEmbeds() string { return "\n" + strings.Join(out, "\n") } -func (v QueryValue) NullableIndices() [][]int { - var out [][]int +func (v QueryValue) NullableIndices() []string { + var out []string fieldIdx := 0 if v.Struct != nil { for _, f := range v.Struct.Fields { if len(f.EmbedFields) > 0 { - var nullableIndices []int + var nullableIndices string for range f.EmbedFields { if !f.Column.NotNull { - nullableIndices = append(nullableIndices, fieldIdx) + nullableIndices += strconv.Itoa(fieldIdx) + "," } fieldIdx++ } diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index dca4d66894..9d87224c90 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -40,8 +40,9 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De return {{.Ret.ReturnName}}, err } {{- .Ret.DeclareNullableEmbeds}} - var cols []interface{} - cols = append(cols, {{.Ret.Scan}}) + cols := []interface{}{ + {{- .Ret.Scan -}} + } defer rows.Close() // This effectively duplicates the behaviour of Row.Scan, which we can't use (because it doesn't // provide Values). @@ -56,18 +57,9 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De if verr != nil { return {{.Ret.ReturnName}}, verr } - {{- range $nembed := .Ret.NullableIndices}} - if - {{- range $nembed}} - vals[{{.}}] == nil && - {{- end -}} - true { - {{range $nembed}} - cols[{{.}}] = nil - {{- end}} - } - {{- end -}} - {{- end}} + nullableIndices := [][]int{ {{- range .Ret.NullableIndices}}[]int{ {{- . -}}}, {{- end -}} } + setEmbedsNil(vals, cols, nullableIndices) + {{end -}} if err := rows.Scan(cols...); err != nil { return {{.Ret.ReturnName}}, err } diff --git a/internal/codegen/golang/templates/template.tmpl b/internal/codegen/golang/templates/template.tmpl index 453d70cbe6..fd27c8f48e 100644 --- a/internal/codegen/golang/templates/template.tmpl +++ b/internal/codegen/golang/templates/template.tmpl @@ -157,6 +157,26 @@ import ( "github.com/jackc/pgx/v5" ) +// TODO: naming feels off +func setEmbedsNil(dbVals []interface{}, fields []interface{}, nullableIndices [][]int) { + for _, nembed := range nullableIndices { + setNil := true + for _, idx := range nembed { + // Any non-NULL value in the query result will cause a Scan attempt into the + // intermediate struct. + if dbVals[idx] != nil { + setNil = false + break + } + } + if setNil { + for _, idx := range nembed { + fields[idx] = nil + } + } + } +} + {{template "queryCode" . }} {{end}}