diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index aeb1c106a2..a581cfa3fc 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -2,6 +2,7 @@ package golang import ( "fmt" + "strconv" "strings" "github.com/kyleconroy/sqlc/internal/metadata" @@ -154,6 +155,54 @@ func (v QueryValue) HasSqlcSlices() bool { return false } +func (v QueryValue) AssignNullableEmbeds() string { + var out []string + if v.Struct != nil { + for _, f := range v.Struct.Fields { + if len(f.EmbedFields) > 0 && !f.Column.NotNull { + out = append(out, v.Name+"."+f.Name+" = &n"+f.Name) + } + } + } + return "\n" + strings.Join(out, "\n") +} + +func (v QueryValue) DeclareNullableEmbeds() string { + var out []string + if v.Struct != nil { + for _, f := range v.Struct.Fields { + if len(f.EmbedFields) > 0 && !f.Column.NotNull { + out = append(out, "var n"+f.Name+" "+f.Type[1:]) + } + } + } + return "\n" + strings.Join(out, "\n") +} + +func (v QueryValue) NullableIndices() []string { + var out []string + fieldIdx := 0 + if v.Struct != nil { + for _, f := range v.Struct.Fields { + if len(f.EmbedFields) > 0 { + var nullableIndices string + for range f.EmbedFields { + if !f.Column.NotNull { + nullableIndices += strconv.Itoa(fieldIdx) + "," + } + fieldIdx++ + } + if len(nullableIndices) > 0 { + out = append(out, nullableIndices) + } + } else { + fieldIdx++ + } + } + } + return out +} + func (v QueryValue) Scan() string { var out []string if v.Struct == nil { @@ -167,8 +216,14 @@ func (v QueryValue) Scan() string { // append any embedded fields if len(f.EmbedFields) > 0 { + prefix := "&" + v.Name + "." + // Regular embeds go straight into the return struct, nembed uses an intermediate + // value to check for NULL + if !f.Column.NotNull { + prefix = "&n" + } for _, embed := range f.EmbedFields { - out = append(out, "&"+v.Name+"."+f.Name+"."+embed) + out = append(out, prefix+f.Name+"."+embed) } continue } diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index f5ecd124a1..91e498972b 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -382,6 +382,9 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn f.Type = goType(req, c.Column) } else { f.Type = c.embed.modelType + if !c.NotNull { + f.Type = fmt.Sprintf("*%s", c.embed.modelType) + } f.EmbedFields = c.embed.fields } diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index 1736fa11f7..9d87224c90 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -28,15 +28,42 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { - row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) + rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { - row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) + rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) {{- end}} {{- if ne .Arg.Pair .Ret.Pair }} var {{.Ret.Name}} {{.Ret.Type}} {{- end}} - err := row.Scan({{.Ret.Scan}}) + if err != nil { + return {{.Ret.ReturnName}}, err + } + {{- .Ret.DeclareNullableEmbeds}} + cols := []interface{}{ + {{- .Ret.Scan -}} + } + defer rows.Close() + // This effectively duplicates the behaviour of Row.Scan, which we can't use (because it doesn't + // provide Values). + if !rows.Next() { + if rows.Err() == nil { + return {{.Ret.ReturnName}}, pgx.ErrNoRows + } + return {{.Ret.ReturnName}}, rows.Err() + } + {{if .Ret.NullableIndices -}} + vals, verr := rows.Values() + if verr != nil { + return {{.Ret.ReturnName}}, verr + } + nullableIndices := [][]int{ {{- range .Ret.NullableIndices}}[]int{ {{- . -}}}, {{- end -}} } + setEmbedsNil(vals, cols, nullableIndices) + {{end -}} + if err := rows.Scan(cols...); err != nil { + return {{.Ret.ReturnName}}, err + } + {{- .Ret.AssignNullableEmbeds}} return {{.Ret.ReturnName}}, err } {{end}} diff --git a/internal/codegen/golang/templates/template.tmpl b/internal/codegen/golang/templates/template.tmpl index 519c693bc4..fd27c8f48e 100644 --- a/internal/codegen/golang/templates/template.tmpl +++ b/internal/codegen/golang/templates/template.tmpl @@ -153,8 +153,30 @@ import ( {{range .}}{{.}} {{end}} {{end}} + // Obviously, this is temporary + "github.com/jackc/pgx/v5" ) +// TODO: naming feels off +func setEmbedsNil(dbVals []interface{}, fields []interface{}, nullableIndices [][]int) { + for _, nembed := range nullableIndices { + setNil := true + for _, idx := range nembed { + // Any non-NULL value in the query result will cause a Scan attempt into the + // intermediate struct. + if dbVals[idx] != nil { + setNil = false + break + } + } + if setNil { + for _, idx := range nembed { + fields[idx] = nil + } + } + } +} + {{template "queryCode" . }} {{end}} diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 9b14fb83c2..6fb92f13ee 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -263,6 +263,7 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er cols = append(cols, &Column{ Name: embed.Table.Name, EmbedTable: embed.Table, + NotNull: !embed.Nullable, }) continue } diff --git a/internal/sql/rewrite/embeds.go b/internal/sql/rewrite/embeds.go index 1b132ec920..c9a79de941 100644 --- a/internal/sql/rewrite/embeds.go +++ b/internal/sql/rewrite/embeds.go @@ -7,16 +7,23 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/astutils" ) -// Embed is an instance of `sqlc.embed(param)` +// Embed is an instance of `sqlc.embed(param)` or `sqlc.nembed(param)`. +// The only difference in an embed generated with `nembed` is that `Nullable` +// will always be `true`. type Embed struct { - Table *ast.TableName - param string - Node *ast.ColumnRef + Table *ast.TableName + param string + Node *ast.ColumnRef + Nullable bool } // Orig string to replace func (e Embed) Orig() string { - return fmt.Sprintf("sqlc.embed(%s)", e.param) + fName := "embed" + if e.Nullable { + fName = "nembed" + } + return fmt.Sprintf("sqlc.%s(%s)", fName, e.param) } // EmbedSet is a set of Embed instances @@ -32,9 +39,9 @@ func (es EmbedSet) Find(node *ast.ColumnRef) (*Embed, bool) { return nil, false } -// Embeds rewrites `sqlc.embed(param)` to a `ast.ColumnRef` of form `param.*`. -// The compiler can make use of the returned `EmbedSet` while expanding the -// `param.*` column refs to produce the correct source edits. +// Embeds rewrites `sqlc.embed(param)` or `sqlc.nembed(param)` to an `ast.ColumnRef` +// of form `param.*`. The compiler can make use of the returned `EmbedSet` while +// expanding the `param.*` column refs to produce the correct source edits. func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) { var embeds []*Embed @@ -60,10 +67,15 @@ func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) { }, } + nullable := false + if fun.Func.Name == "nembed" { + nullable = true + } embeds = append(embeds, &Embed{ - Table: &ast.TableName{Name: param}, - param: param, - Node: node, + Table: &ast.TableName{Name: param}, + param: param, + Node: node, + Nullable: nullable, }) cr.Replace(node) @@ -86,6 +98,6 @@ func isEmbed(node ast.Node) bool { return false } - isValid := call.Func.Schema == "sqlc" && call.Func.Name == "embed" + isValid := call.Func.Schema == "sqlc" && (call.Func.Name == "embed" || call.Func.Name == "nembed") return isValid } diff --git a/internal/sql/validate/func_call.go b/internal/sql/validate/func_call.go index bbda232c63..c3f16514d0 100644 --- a/internal/sql/validate/func_call.go +++ b/internal/sql/validate/func_call.go @@ -31,10 +31,10 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { return v } - // Custom validation for sqlc.arg, sqlc.narg and sqlc.slice + // Custom validation for `sqlc.` functions. // TODO: Replace this once type-checking is implemented if fn.Schema == "sqlc" { - if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed") { + if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed" || fn.Name == "nembed") { v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name) return nil } @@ -57,8 +57,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { return nil } - // If we have sqlc.arg or sqlc.narg, there is no need to resolve the function call. - // It won't resolve anyway, sinc it is not a real function. + // Don't attempt to resolve `sqlc.` functions. return nil }