Skip to content

Commit 66f537c

Browse files
committed
Improve support for returning tables from postgresql functions.
1 parent a334908 commit 66f537c

File tree

9 files changed

+155
-18
lines changed

9 files changed

+155
-18
lines changed

internal/compiler/output_columns.go

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,22 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
318318
}
319319
fun, err := qc.catalog.ResolveFuncCall(n)
320320
if err == nil {
321+
var tableCols []*catalog.Argument
322+
for _, arg := range fun.Args {
323+
if arg.Mode == ast.FuncParamTable {
324+
tableCols = append(tableCols, arg)
325+
}
326+
}
327+
var dt string
328+
if len(tableCols) == 1 {
329+
// A single column will later generate a scalar (or slice of scalar) return type.
330+
dt = dataType(tableCols[0].Type)
331+
} else {
332+
dt = dataType(fun.ReturnType)
333+
}
321334
cols = append(cols, &Column{
322335
Name: name,
323-
DataType: dataType(fun.ReturnType),
336+
DataType: dt,
324337
NotNull: !fun.ReturnTypeNullable,
325338
IsFuncCall: true,
326339
})
@@ -553,12 +566,53 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro
553566
if err != nil {
554567
continue
555568
}
569+
var (
570+
fnRetCatalog string
571+
fnRetSchema string
572+
fnRetName string
573+
)
574+
if fn.ReturnType != nil {
575+
fnRetCatalog = fn.ReturnType.Catalog
576+
fnRetSchema = fn.ReturnType.Schema
577+
fnRetName = fn.ReturnType.Name
578+
}
579+
fnsWithName, err := c.catalog.ListFuncsByName(&ast.FuncName{
580+
Catalog: fnRetCatalog,
581+
Schema: fnRetSchema,
582+
Name: fnRetName,
583+
})
584+
// If the function was found, build a table structure to hold the columns to output.
585+
if err == nil && len(fnsWithName) == 1 {
586+
fnWithName := fnsWithName[0]
587+
rel := &ast.TableName{
588+
Catalog: fnRetCatalog,
589+
Schema: fnRetSchema,
590+
Name: fnRetName,
591+
}
592+
var cols []*Column
593+
for _, arg := range fnWithName.Args {
594+
if arg.Mode == ast.FuncParamTable {
595+
col := &catalog.Column{
596+
Name: arg.Name,
597+
Type: *arg.Type,
598+
IsArray: arg.IsArray,
599+
}
600+
convertedCol := ConvertColumn(rel, col)
601+
cols = append(cols, convertedCol)
602+
}
603+
}
604+
tables = append(tables, &Table{
605+
Rel: rel,
606+
Columns: cols,
607+
})
608+
continue
609+
}
556610
var table *Table
557611
if fn.ReturnType != nil {
558612
table, err = qc.GetTable(&ast.TableName{
559-
Catalog: fn.ReturnType.Catalog,
560-
Schema: fn.ReturnType.Schema,
561-
Name: fn.ReturnType.Name,
613+
Catalog: fnRetCatalog,
614+
Schema: fnRetSchema,
615+
Name: fnRetName,
562616
})
563617
}
564618
if table == nil || err != nil {

internal/compiler/resolve.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,19 +366,27 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
366366
continue
367367
}
368368

369-
var paramName string
370-
var paramType *ast.TypeName
369+
var (
370+
paramName string
371+
paramType *ast.TypeName
372+
isArray bool
373+
arrayDims int
374+
)
371375

372376
if argName == "" {
373377
if i < len(fun.Args) {
374378
paramName = fun.Args[i].Name
375379
paramType = fun.Args[i].Type
380+
isArray = fun.Args[i].IsArray
381+
arrayDims = fun.Args[i].ArrayDims
376382
}
377383
} else {
378384
paramName = argName
379385
for _, arg := range fun.Args {
380386
if arg.Name == argName {
381387
paramType = arg.Type
388+
isArray = arg.IsArray
389+
arrayDims = arg.ArrayDims
382390
}
383391
}
384392
if paramType == nil {
@@ -402,6 +410,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
402410
NotNull: p.NotNull(),
403411
IsNamedParam: isNamed,
404412
IsSqlcSlice: p.IsSqlcSlice(),
413+
IsArray: isArray,
414+
ArrayDims: arrayDims,
405415
},
406416
})
407417
}

internal/endtoend/testdata/func_return_table/postgresql/pgx/go/query.sql.go

Lines changed: 25 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
-- name: Foo :one
22
SELECT * FROM register_account('a', 'b');
3+
4+
-- name: GetAccount :one
5+
SELECT * FROM get_account($1, $2);

internal/endtoend/testdata/func_return_table/postgresql/pgx/schema.sql

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,27 @@ BEGIN
2424

2525
RETURN NEXT;
2626
END;
27-
$$ LANGUAGE plpgsql;
27+
$$
28+
LANGUAGE plpgsql;
29+
30+
CREATE OR REPLACE FUNCTION get_account(
31+
_account_id INTEGER,
32+
_tags TEXT[][] -- test multidimensional array code generation
33+
)
34+
RETURNS TABLE(
35+
account_id INTEGER,
36+
username TEXT
37+
)
38+
AS $$
39+
BEGIN
40+
SELECT
41+
account_id,
42+
username
43+
FROM
44+
accounts
45+
WHERE
46+
account_id = _account_id;
47+
END;
48+
$$
49+
LANGUAGE plpgsql;
50+

internal/engine/postgresql/parse.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,11 @@ func translate(node *nodes.Node) (ast.Node, error) {
509509
return nil, err
510510
}
511511
fp := &ast.FuncParam{
512-
Name: &arg.Name,
513-
Type: rel.TypeName(),
514-
Mode: mode,
512+
Name: &arg.Name,
513+
Type: rel.TypeName(),
514+
Mode: mode,
515+
IsArray: isArray(arg.ArgType),
516+
ArrayDims: len(arg.ArgType.ArrayBounds),
515517
}
516518
if arg.Defexpr != nil {
517519
fp.DefExpr = &ast.TODO{}

internal/sql/ast/func_param.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ const (
1212
)
1313

1414
type FuncParam struct {
15-
Name *string
16-
Type *TypeName
17-
DefExpr Node // Will always be &ast.TODO
18-
Mode FuncParamMode
15+
Name *string
16+
Type *TypeName
17+
DefExpr Node // Will always be &ast.TODO
18+
Mode FuncParamMode
19+
IsArray bool
20+
ArrayDims int
1921
}
2022

2123
func (n *FuncParam) Pos() int {

internal/sql/ast/function_parameter.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ type FunctionParameter struct {
55
ArgType *TypeName
66
Mode FunctionParameterMode
77
Defexpr Node
8+
IsArray bool
89
}
910

1011
func (n *FunctionParameter) Pos() int {

internal/sql/catalog/func.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ type Argument struct {
2424
Type *ast.TypeName
2525
HasDefault bool
2626
Mode ast.FuncParamMode
27+
IsArray bool
28+
ArrayDims int
2729
}
2830

2931
func (f *Function) InArgs() []*Argument {
@@ -65,21 +67,40 @@ func (c *Catalog) createFunction(stmt *ast.CreateFunctionStmt) error {
6567
ReturnType: stmt.ReturnType,
6668
}
6769
types := make([]*ast.TypeName, len(stmt.Params.Items))
70+
var cols []*ast.ColumnDef
6871
for i, item := range stmt.Params.Items {
6972
arg := item.(*ast.FuncParam)
7073
var name string
7174
if arg.Name != nil {
7275
name = *arg.Name
7376
}
77+
if arg.Mode == ast.FuncParamTable {
78+
cols = append(cols, &ast.ColumnDef{
79+
Colname: name,
80+
TypeName: arg.Type,
81+
IsArray: arg.IsArray,
82+
ArrayDims: arg.ArrayDims,
83+
})
84+
}
7485
fn.Args[i] = &Argument{
7586
Name: name,
7687
Type: arg.Type,
7788
Mode: arg.Mode,
7889
HasDefault: arg.DefExpr != nil,
90+
IsArray: arg.IsArray,
91+
ArrayDims: arg.ArrayDims,
7992
}
8093
types[i] = arg.Type
8194
}
8295

96+
if len(cols) > 0 {
97+
fn.ReturnType = &ast.TypeName{
98+
Name: stmt.Func.Name,
99+
Schema: stmt.Func.Schema,
100+
Catalog: stmt.Func.Catalog,
101+
}
102+
}
103+
83104
_, idx, err := s.getFunc(stmt.Func, types)
84105
if err == nil && !stmt.Replace {
85106
return sqlerr.RelationExists(stmt.Func.Name)

0 commit comments

Comments
 (0)