Skip to content

Add sqlc.nembed #2472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package golang

import (
"fmt"
"strconv"
"strings"

"github.com/kyleconroy/sqlc/internal/metadata"
Expand Down Expand Up @@ -154,6 +155,54 @@ func (v QueryValue) HasSqlcSlices() bool {
return false
}

func (v QueryValue) AssignNullableEmbeds() string {
var out []string
if v.Struct != nil {
for _, f := range v.Struct.Fields {
if len(f.EmbedFields) > 0 && !f.Column.NotNull {
out = append(out, v.Name+"."+f.Name+" = &n"+f.Name)
}
}
}
return "\n" + strings.Join(out, "\n")
}

func (v QueryValue) DeclareNullableEmbeds() string {
var out []string
if v.Struct != nil {
for _, f := range v.Struct.Fields {
if len(f.EmbedFields) > 0 && !f.Column.NotNull {
out = append(out, "var n"+f.Name+" "+f.Type[1:])
}
}
}
return "\n" + strings.Join(out, "\n")
}

func (v QueryValue) NullableIndices() []string {
var out []string
fieldIdx := 0
if v.Struct != nil {
for _, f := range v.Struct.Fields {
if len(f.EmbedFields) > 0 {
var nullableIndices string
for range f.EmbedFields {
if !f.Column.NotNull {
nullableIndices += strconv.Itoa(fieldIdx) + ","
}
fieldIdx++
}
if len(nullableIndices) > 0 {
out = append(out, nullableIndices)
}
} else {
fieldIdx++
}
}
}
return out
}

func (v QueryValue) Scan() string {
var out []string
if v.Struct == nil {
Expand All @@ -167,8 +216,14 @@ func (v QueryValue) Scan() string {

// append any embedded fields
if len(f.EmbedFields) > 0 {
prefix := "&" + v.Name + "."
// Regular embeds go straight into the return struct, nembed uses an intermediate
// value to check for NULL
if !f.Column.NotNull {
prefix = "&n"
}
for _, embed := range f.EmbedFields {
out = append(out, "&"+v.Name+"."+f.Name+"."+embed)
out = append(out, prefix+f.Name+"."+embed)
}
continue
}
Expand Down
3 changes: 3 additions & 0 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,9 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
f.Type = goType(req, c.Column)
} else {
f.Type = c.embed.modelType
if !c.NotNull {
f.Type = fmt.Sprintf("*%s", c.embed.modelType)
}
f.EmbedFields = c.embed.fields
}

Expand Down
33 changes: 30 additions & 3 deletions internal/codegen/golang/templates/pgx/queryCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,42 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}}
{{end -}}
{{- if $.EmitMethodsWithDBArgument -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {
row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {
row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- if ne .Arg.Pair .Ret.Pair }}
var {{.Ret.Name}} {{.Ret.Type}}
{{- end}}
err := row.Scan({{.Ret.Scan}})
if err != nil {
return {{.Ret.ReturnName}}, err
}
{{- .Ret.DeclareNullableEmbeds}}
cols := []interface{}{
{{- .Ret.Scan -}}
}
defer rows.Close()
// This effectively duplicates the behaviour of Row.Scan, which we can't use (because it doesn't
// provide Values).
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Er, by which I mean Row doesn't provide Values, as it's just a wrapper around Rows that QueryRow returns. Unless I'm missing something.

if !rows.Next() {
if rows.Err() == nil {
return {{.Ret.ReturnName}}, pgx.ErrNoRows
}
return {{.Ret.ReturnName}}, rows.Err()
}
{{if .Ret.NullableIndices -}}
vals, verr := rows.Values()
if verr != nil {
return {{.Ret.ReturnName}}, verr
}
nullableIndices := [][]int{ {{- range .Ret.NullableIndices}}[]int{ {{- . -}}}, {{- end -}} }
setEmbedsNil(vals, cols, nullableIndices)
{{end -}}
if err := rows.Scan(cols...); err != nil {
return {{.Ret.ReturnName}}, err
}
{{- .Ret.AssignNullableEmbeds}}
return {{.Ret.ReturnName}}, err
}
{{end}}
Expand Down
22 changes: 22 additions & 0 deletions internal/codegen/golang/templates/template.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,30 @@ import (
{{range .}}{{.}}
{{end}}
{{end}}
// Obviously, this is temporary
"github.com/jackc/pgx/v5"
)

// TODO: naming feels off
func setEmbedsNil(dbVals []interface{}, fields []interface{}, nullableIndices [][]int) {
for _, nembed := range nullableIndices {
setNil := true
for _, idx := range nembed {
// Any non-NULL value in the query result will cause a Scan attempt into the
// intermediate struct.
if dbVals[idx] != nil {
setNil = false
break
}
}
if setNil {
for _, idx := range nembed {
fields[idx] = nil
}
}
}
}

{{template "queryCode" . }}
{{end}}

Expand Down
1 change: 1 addition & 0 deletions internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
cols = append(cols, &Column{
Name: embed.Table.Name,
EmbedTable: embed.Table,
NotNull: !embed.Nullable,
})
continue
}
Expand Down
36 changes: 24 additions & 12 deletions internal/sql/rewrite/embeds.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@ import (
"github.com/kyleconroy/sqlc/internal/sql/astutils"
)

// Embed is an instance of `sqlc.embed(param)`
// Embed is an instance of `sqlc.embed(param)` or `sqlc.nembed(param)`.
// The only difference in an embed generated with `nembed` is that `Nullable`
// will always be `true`.
type Embed struct {
Table *ast.TableName
param string
Node *ast.ColumnRef
Table *ast.TableName
param string
Node *ast.ColumnRef
Nullable bool
}

// Orig string to replace
func (e Embed) Orig() string {
return fmt.Sprintf("sqlc.embed(%s)", e.param)
fName := "embed"
if e.Nullable {
fName = "nembed"
}
return fmt.Sprintf("sqlc.%s(%s)", fName, e.param)
}

// EmbedSet is a set of Embed instances
Expand All @@ -32,9 +39,9 @@ func (es EmbedSet) Find(node *ast.ColumnRef) (*Embed, bool) {
return nil, false
}

// Embeds rewrites `sqlc.embed(param)` to a `ast.ColumnRef` of form `param.*`.
// The compiler can make use of the returned `EmbedSet` while expanding the
// `param.*` column refs to produce the correct source edits.
// Embeds rewrites `sqlc.embed(param)` or `sqlc.nembed(param)` to an `ast.ColumnRef`
// of form `param.*`. The compiler can make use of the returned `EmbedSet` while
// expanding the `param.*` column refs to produce the correct source edits.
func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) {
var embeds []*Embed

Expand All @@ -60,10 +67,15 @@ func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) {
},
}

nullable := false
if fun.Func.Name == "nembed" {
nullable = true
}
embeds = append(embeds, &Embed{
Table: &ast.TableName{Name: param},
param: param,
Node: node,
Table: &ast.TableName{Name: param},
param: param,
Node: node,
Nullable: nullable,
})

cr.Replace(node)
Expand All @@ -86,6 +98,6 @@ func isEmbed(node ast.Node) bool {
return false
}

isValid := call.Func.Schema == "sqlc" && call.Func.Name == "embed"
isValid := call.Func.Schema == "sqlc" && (call.Func.Name == "embed" || call.Func.Name == "nembed")
return isValid
}
7 changes: 3 additions & 4 deletions internal/sql/validate/func_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor {
return v
}

// Custom validation for sqlc.arg, sqlc.narg and sqlc.slice
// Custom validation for `sqlc.` functions.
// TODO: Replace this once type-checking is implemented
if fn.Schema == "sqlc" {
if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed") {
if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed" || fn.Name == "nembed") {
v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name)
return nil
}
Expand All @@ -57,8 +57,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor {
return nil
}

// If we have sqlc.arg or sqlc.narg, there is no need to resolve the function call.
// It won't resolve anyway, sinc it is not a real function.
// Don't attempt to resolve `sqlc.` functions.
return nil
}

Expand Down