diff --git a/internal/endtoend/testdata/func_args/endtoend.json b/internal/endtoend/testdata/func_args/endtoend.json new file mode 100644 index 0000000000..dbc86a9a00 --- /dev/null +++ b/internal/endtoend/testdata/func_args/endtoend.json @@ -0,0 +1,4 @@ +{ + "experimental_parser_only": true +} + \ No newline at end of file diff --git a/internal/endtoend/testdata/func_args/go/query.sql.go b/internal/endtoend/testdata/func_args/go/query.sql.go index f86ffe4fc5..aabb4c6bab 100644 --- a/internal/endtoend/testdata/func_args/go/query.sql.go +++ b/internal/endtoend/testdata/func_args/go/query.sql.go @@ -55,3 +55,14 @@ func (q *Queries) Plus(ctx context.Context, arg PlusParams) (int32, error) { err := row.Scan(&plus) return plus, err } + +const tableArgs = `-- name: TableArgs :one +SELECT table_args(x => $1) +` + +func (q *Queries) TableArgs(ctx context.Context, x int32) (int32, error) { + row := q.db.QueryRowContext(ctx, tableArgs, x) + var table_args int32 + err := row.Scan(&table_args) + return table_args, err +} diff --git a/internal/endtoend/testdata/func_args/query.sql b/internal/endtoend/testdata/func_args/query.sql index 51afe879ae..bbc440081a 100644 --- a/internal/endtoend/testdata/func_args/query.sql +++ b/internal/endtoend/testdata/func_args/query.sql @@ -4,6 +4,8 @@ CREATE FUNCTION plus(a integer, b integer) RETURNS integer AS $$ END; $$ LANGUAGE plpgsql; +CREATE FUNCTION table_args(x INT) RETURNS TABLE (y INT) AS 'SELECT x' LANGUAGE sql; + -- name: Plus :one SELECT plus(b => $2, a => $1); @@ -16,4 +18,5 @@ SELECT make_interval(days => $1::int); -- name: MakeIntervalMonths :one SELECT make_interval(months => sqlc.arg('months')::int); - +-- name: TableArgs :one +SELECT table_args(x => $1); diff --git a/internal/postgresql/parse.go b/internal/postgresql/parse.go index 2211df712a..912e5023f2 100644 --- a/internal/postgresql/parse.go +++ b/internal/postgresql/parse.go @@ -42,6 +42,23 @@ func parseFuncName(node nodes.Node) (*ast.FuncName, error) { }, nil } +func parseFuncParamMode(m nodes.FunctionParameterMode) (ast.FuncParamMode, error) { + switch m { + case 'i': + return ast.FuncParamIn, nil + case 'o': + return ast.FuncParamOut, nil + case 'b': + return ast.FuncParamInOut, nil + case 'v': + return ast.FuncParamVariadic, nil + case 't': + return ast.FuncParamTable, nil + default: + return -1, fmt.Errorf("parse func param: invalid mode %v", m) + } +} + func parseTypeName(node nodes.Node) (*ast.TypeName, error) { rel, err := parseRelation(node) if err != nil { @@ -434,9 +451,14 @@ func translate(node nodes.Node) (ast.Node, error) { if err != nil { return nil, err } + mode, err := parseFuncParamMode(arg.Mode) + if err != nil { + return nil, err + } fp := &ast.FuncParam{ Name: arg.Name, Type: tn, + Mode: mode, } if arg.Defexpr != nil { fp.DefExpr = &ast.TODO{} diff --git a/internal/sql/ast/func_param.go b/internal/sql/ast/func_param.go index 4e16c61937..faee6ede37 100644 --- a/internal/sql/ast/func_param.go +++ b/internal/sql/ast/func_param.go @@ -1,9 +1,20 @@ package ast +type FuncParamMode int + +const ( + FuncParamIn FuncParamMode = iota + FuncParamOut + FuncParamInOut + FuncParamVariadic + FuncParamTable +) + type FuncParam struct { Name *string Type *TypeName DefExpr Node // Will always be &ast.TODO + Mode FuncParamMode } func (n *FuncParam) Pos() int { diff --git a/internal/sql/catalog/catalog.go b/internal/sql/catalog/catalog.go index 31cc0ddcc3..683a90cb19 100644 --- a/internal/sql/catalog/catalog.go +++ b/internal/sql/catalog/catalog.go @@ -106,11 +106,13 @@ func (s *Schema) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int if s.Funcs[i].Name != rel.Name { continue } - if len(s.Funcs[i].Args) != len(tns) { + + args := s.Funcs[i].InArgs() + if len(args) != len(tns) { continue } found := true - for j := range s.Funcs[i].Args { + for j := range args { if !sameType(s.Funcs[i].Args[j].Type, tns[j]) { found = false break @@ -215,10 +217,24 @@ type Function struct { Desc string } +func (f *Function) InArgs() []*Argument { + var args []*Argument + for _, a := range f.Args { + switch a.Mode { + case ast.FuncParamTable, ast.FuncParamOut: + continue + default: + args = append(args, a) + } + } + return args +} + type Argument struct { Name string Type *ast.TypeName HasDefault bool + Mode ast.FuncParamMode } func New(def string) *Catalog { diff --git a/internal/sql/catalog/func.go b/internal/sql/catalog/func.go index 1a92ea7116..279810e208 100644 --- a/internal/sql/catalog/func.go +++ b/internal/sql/catalog/func.go @@ -30,6 +30,7 @@ func (c *Catalog) createFunction(stmt *ast.CreateFunctionStmt) error { fn.Args[i] = &Argument{ Name: name, Type: arg.Type, + Mode: arg.Mode, HasDefault: arg.DefExpr != nil, } types[i] = arg.Type diff --git a/internal/sql/catalog/public.go b/internal/sql/catalog/public.go index d3bbd2b1fb..738bbca2ff 100644 --- a/internal/sql/catalog/public.go +++ b/internal/sql/catalog/public.go @@ -38,7 +38,9 @@ func (c *Catalog) GetFuncN(rel *ast.FuncName, n int) (Function, error) { if s.Funcs[i].Name != rel.Name { continue } - if len(s.Funcs[i].Args) == n { + + args := s.Funcs[i].InArgs() + if len(args) == n { return *s.Funcs[i], nil } } diff --git a/internal/sql/validate/func_call.go b/internal/sql/validate/func_call.go index 1d3e28e3e1..289babcdfd 100644 --- a/internal/sql/validate/func_call.go +++ b/internal/sql/validate/func_call.go @@ -39,7 +39,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { args = len(call.Args.Items) } for _, fun := range funs { - if len(fun.Args) == args { + if len(fun.InArgs()) == args { return v } }