Skip to content

Commit 0478ca8

Browse files
catalog: support functions with table parameters (#541)
Co-authored-by: Kyle Conroy <kyle@conroy.org>
1 parent e0666b5 commit 0478ca8

File tree

9 files changed

+75
-5
lines changed

9 files changed

+75
-5
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"experimental_parser_only": true
3+
}
4+

internal/endtoend/testdata/func_args/go/query.sql.go

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/func_args/query.sql

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ CREATE FUNCTION plus(a integer, b integer) RETURNS integer AS $$
44
END;
55
$$ LANGUAGE plpgsql;
66

7+
CREATE FUNCTION table_args(x INT) RETURNS TABLE (y INT) AS 'SELECT x' LANGUAGE sql;
8+
79
-- name: Plus :one
810
SELECT plus(b => $2, a => $1);
911

@@ -16,4 +18,5 @@ SELECT make_interval(days => $1::int);
1618
-- name: MakeIntervalMonths :one
1719
SELECT make_interval(months => sqlc.arg('months')::int);
1820

19-
21+
-- name: TableArgs :one
22+
SELECT table_args(x => $1);

internal/postgresql/parse.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,23 @@ func parseFuncName(node nodes.Node) (*ast.FuncName, error) {
4242
}, nil
4343
}
4444

45+
func parseFuncParamMode(m nodes.FunctionParameterMode) (ast.FuncParamMode, error) {
46+
switch m {
47+
case 'i':
48+
return ast.FuncParamIn, nil
49+
case 'o':
50+
return ast.FuncParamOut, nil
51+
case 'b':
52+
return ast.FuncParamInOut, nil
53+
case 'v':
54+
return ast.FuncParamVariadic, nil
55+
case 't':
56+
return ast.FuncParamTable, nil
57+
default:
58+
return -1, fmt.Errorf("parse func param: invalid mode %v", m)
59+
}
60+
}
61+
4562
func parseTypeName(node nodes.Node) (*ast.TypeName, error) {
4663
rel, err := parseRelation(node)
4764
if err != nil {
@@ -434,9 +451,14 @@ func translate(node nodes.Node) (ast.Node, error) {
434451
if err != nil {
435452
return nil, err
436453
}
454+
mode, err := parseFuncParamMode(arg.Mode)
455+
if err != nil {
456+
return nil, err
457+
}
437458
fp := &ast.FuncParam{
438459
Name: arg.Name,
439460
Type: tn,
461+
Mode: mode,
440462
}
441463
if arg.Defexpr != nil {
442464
fp.DefExpr = &ast.TODO{}

internal/sql/ast/func_param.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
package ast
22

3+
type FuncParamMode int
4+
5+
const (
6+
FuncParamIn FuncParamMode = iota
7+
FuncParamOut
8+
FuncParamInOut
9+
FuncParamVariadic
10+
FuncParamTable
11+
)
12+
313
type FuncParam struct {
414
Name *string
515
Type *TypeName
616
DefExpr Node // Will always be &ast.TODO
17+
Mode FuncParamMode
718
}
819

920
func (n *FuncParam) Pos() int {

internal/sql/catalog/catalog.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,13 @@ func (s *Schema) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int
106106
if s.Funcs[i].Name != rel.Name {
107107
continue
108108
}
109-
if len(s.Funcs[i].Args) != len(tns) {
109+
110+
args := s.Funcs[i].InArgs()
111+
if len(args) != len(tns) {
110112
continue
111113
}
112114
found := true
113-
for j := range s.Funcs[i].Args {
115+
for j := range args {
114116
if !sameType(s.Funcs[i].Args[j].Type, tns[j]) {
115117
found = false
116118
break
@@ -215,10 +217,24 @@ type Function struct {
215217
Desc string
216218
}
217219

220+
func (f *Function) InArgs() []*Argument {
221+
var args []*Argument
222+
for _, a := range f.Args {
223+
switch a.Mode {
224+
case ast.FuncParamTable, ast.FuncParamOut:
225+
continue
226+
default:
227+
args = append(args, a)
228+
}
229+
}
230+
return args
231+
}
232+
218233
type Argument struct {
219234
Name string
220235
Type *ast.TypeName
221236
HasDefault bool
237+
Mode ast.FuncParamMode
222238
}
223239

224240
func New(def string) *Catalog {

internal/sql/catalog/func.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ func (c *Catalog) createFunction(stmt *ast.CreateFunctionStmt) error {
3030
fn.Args[i] = &Argument{
3131
Name: name,
3232
Type: arg.Type,
33+
Mode: arg.Mode,
3334
HasDefault: arg.DefExpr != nil,
3435
}
3536
types[i] = arg.Type

internal/sql/catalog/public.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ func (c *Catalog) GetFuncN(rel *ast.FuncName, n int) (Function, error) {
3838
if s.Funcs[i].Name != rel.Name {
3939
continue
4040
}
41-
if len(s.Funcs[i].Args) == n {
41+
42+
args := s.Funcs[i].InArgs()
43+
if len(args) == n {
4244
return *s.Funcs[i], nil
4345
}
4446
}

internal/sql/validate/func_call.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor {
3939
args = len(call.Args.Items)
4040
}
4141
for _, fun := range funs {
42-
if len(fun.Args) == args {
42+
if len(fun.InArgs()) == args {
4343
return v
4444
}
4545
}

0 commit comments

Comments
 (0)