Skip to content

Commit 34df824

Browse files
committed
Improve support for returning tables from postgresql functions.
1 parent 08ecde7 commit 34df824

File tree

9 files changed

+142
-15
lines changed

9 files changed

+142
-15
lines changed

internal/compiler/output_columns.go

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,22 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
318318
}
319319
fun, err := qc.catalog.ResolveFuncCall(n)
320320
if err == nil {
321+
var tableCols []*catalog.Argument
322+
for _, arg := range fun.Args {
323+
if arg.Mode == ast.FuncParamTable {
324+
tableCols = append(tableCols, arg)
325+
}
326+
}
327+
var dt string
328+
if len(tableCols) == 1 {
329+
// A single column will later generate a scalar (or slice of scalar) return type.
330+
dt = dataType(tableCols[0].Type)
331+
} else {
332+
dt = dataType(fun.ReturnType)
333+
}
321334
cols = append(cols, &Column{
322335
Name: name,
323-
DataType: dataType(fun.ReturnType),
336+
DataType: dt,
324337
NotNull: !fun.ReturnTypeNullable,
325338
IsFuncCall: true,
326339
})
@@ -553,6 +566,37 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro
553566
if err != nil {
554567
continue
555568
}
569+
fnsWithName, err := c.catalog.ListFuncsByName(&ast.FuncName{
570+
Catalog: fn.ReturnType.Catalog,
571+
Schema: fn.ReturnType.Schema,
572+
Name: fn.ReturnType.Name,
573+
})
574+
// If the function was found, build a table structure to hold the columns to output.
575+
if err == nil && len(fnsWithName) == 1 {
576+
fnWithName := fnsWithName[0]
577+
rel := &ast.TableName{
578+
Catalog: fn.ReturnType.Catalog,
579+
Schema: fn.ReturnType.Schema,
580+
Name: fn.ReturnType.Name,
581+
}
582+
var cols []*Column
583+
for _, arg := range fnWithName.Args {
584+
if arg.Mode == ast.FuncParamTable {
585+
col := &catalog.Column{
586+
Name: arg.Name,
587+
Type: *arg.Type,
588+
IsArray: arg.IsArray,
589+
}
590+
convertedCol := ConvertColumn(rel, col)
591+
cols = append(cols, convertedCol)
592+
}
593+
}
594+
tables = append(tables, &Table{
595+
Rel: rel,
596+
Columns: cols,
597+
})
598+
continue
599+
}
556600
var table *Table
557601
if fn.ReturnType != nil {
558602
table, err = qc.GetTable(&ast.TableName{

internal/compiler/resolve.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,19 +366,27 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
366366
continue
367367
}
368368

369-
var paramName string
370-
var paramType *ast.TypeName
369+
var (
370+
paramName string
371+
paramType *ast.TypeName
372+
isArray bool
373+
arrayDims int
374+
)
371375

372376
if argName == "" {
373377
if i < len(fun.Args) {
374378
paramName = fun.Args[i].Name
375379
paramType = fun.Args[i].Type
380+
isArray = fun.Args[i].IsArray
381+
arrayDims = fun.Args[i].ArrayDims
376382
}
377383
} else {
378384
paramName = argName
379385
for _, arg := range fun.Args {
380386
if arg.Name == argName {
381387
paramType = arg.Type
388+
isArray = arg.IsArray
389+
arrayDims = arg.ArrayDims
382390
}
383391
}
384392
if paramType == nil {
@@ -402,6 +410,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
402410
NotNull: p.NotNull(),
403411
IsNamedParam: isNamed,
404412
IsSqlcSlice: p.IsSqlcSlice(),
413+
IsArray: isArray,
414+
ArrayDims: arrayDims,
405415
},
406416
})
407417
}

internal/endtoend/testdata/func_return_table/postgresql/pgx/go/query.sql.go

Lines changed: 25 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
-- name: Foo :one
22
SELECT * FROM register_account('a', 'b');
3+
4+
-- name: GetAccount :one
5+
SELECT * FROM get_account($1, $2);

internal/endtoend/testdata/func_return_table/postgresql/pgx/schema.sql

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,27 @@ BEGIN
2424

2525
RETURN NEXT;
2626
END;
27-
$$ LANGUAGE plpgsql;
27+
$$
28+
LANGUAGE plpgsql;
29+
30+
CREATE OR REPLACE FUNCTION get_account(
31+
_account_id INTEGER,
32+
_tags TEXT[][] -- test multidimensional array code generation
33+
)
34+
RETURNS TABLE(
35+
account_id INTEGER,
36+
username TEXT
37+
)
38+
AS $$
39+
BEGIN
40+
SELECT
41+
account_id,
42+
username
43+
FROM
44+
accounts
45+
WHERE
46+
account_id = _account_id;
47+
END;
48+
$$
49+
LANGUAGE plpgsql;
50+

internal/engine/postgresql/parse.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,11 @@ func translate(node *nodes.Node) (ast.Node, error) {
509509
return nil, err
510510
}
511511
fp := &ast.FuncParam{
512-
Name: &arg.Name,
513-
Type: rel.TypeName(),
514-
Mode: mode,
512+
Name: &arg.Name,
513+
Type: rel.TypeName(),
514+
Mode: mode,
515+
IsArray: isArray(arg.ArgType),
516+
ArrayDims: len(arg.ArgType.ArrayBounds),
515517
}
516518
if arg.Defexpr != nil {
517519
fp.DefExpr = &ast.TODO{}

internal/sql/ast/func_param.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ const (
1212
)
1313

1414
type FuncParam struct {
15-
Name *string
16-
Type *TypeName
17-
DefExpr Node // Will always be &ast.TODO
18-
Mode FuncParamMode
15+
Name *string
16+
Type *TypeName
17+
DefExpr Node // Will always be &ast.TODO
18+
Mode FuncParamMode
19+
IsArray bool
20+
ArrayDims int
1921
}
2022

2123
func (n *FuncParam) Pos() int {

internal/sql/ast/function_parameter.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ type FunctionParameter struct {
55
ArgType *TypeName
66
Mode FunctionParameterMode
77
Defexpr Node
8+
IsArray bool
89
}
910

1011
func (n *FunctionParameter) Pos() int {

internal/sql/catalog/func.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ type Argument struct {
2424
Type *ast.TypeName
2525
HasDefault bool
2626
Mode ast.FuncParamMode
27+
IsArray bool
28+
ArrayDims int
2729
}
2830

2931
func (f *Function) InArgs() []*Argument {
@@ -65,21 +67,40 @@ func (c *Catalog) createFunction(stmt *ast.CreateFunctionStmt) error {
6567
ReturnType: stmt.ReturnType,
6668
}
6769
types := make([]*ast.TypeName, len(stmt.Params.Items))
70+
var cols []*ast.ColumnDef
6871
for i, item := range stmt.Params.Items {
6972
arg := item.(*ast.FuncParam)
7073
var name string
7174
if arg.Name != nil {
7275
name = *arg.Name
7376
}
77+
if arg.Mode == ast.FuncParamTable {
78+
cols = append(cols, &ast.ColumnDef{
79+
Colname: name,
80+
TypeName: arg.Type,
81+
IsArray: arg.IsArray,
82+
ArrayDims: arg.ArrayDims,
83+
})
84+
}
7485
fn.Args[i] = &Argument{
7586
Name: name,
7687
Type: arg.Type,
7788
Mode: arg.Mode,
7889
HasDefault: arg.DefExpr != nil,
90+
IsArray: arg.IsArray,
91+
ArrayDims: arg.ArrayDims,
7992
}
8093
types[i] = arg.Type
8194
}
8295

96+
if len(cols) > 0 {
97+
fn.ReturnType = &ast.TypeName{
98+
Name: stmt.Func.Name,
99+
Schema: stmt.Func.Schema,
100+
Catalog: stmt.Func.Catalog,
101+
}
102+
}
103+
83104
_, idx, err := s.getFunc(stmt.Func, types)
84105
if err == nil && !stmt.Replace {
85106
return sqlerr.RelationExists(stmt.Func.Name)

0 commit comments

Comments
 (0)