Skip to content

Improve support for SQL functions that return tables #1973

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 58 additions & 4 deletions internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,22 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
}
fun, err := qc.catalog.ResolveFuncCall(n)
if err == nil {
var tableCols []*catalog.Argument
for _, arg := range fun.Args {
if arg.Mode == ast.FuncParamTable {
tableCols = append(tableCols, arg)
}
}
var dt string
if len(tableCols) == 1 {
// A single column will later generate a scalar (or slice of scalar) return type.
dt = dataType(tableCols[0].Type)
} else {
dt = dataType(fun.ReturnType)
}
cols = append(cols, &Column{
Name: name,
DataType: dataType(fun.ReturnType),
DataType: dt,
NotNull: !fun.ReturnTypeNullable,
IsFuncCall: true,
})
Expand Down Expand Up @@ -553,12 +566,53 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro
if err != nil {
continue
}
var (
fnRetCatalog string
fnRetSchema string
fnRetName string
)
if fn.ReturnType != nil {
fnRetCatalog = fn.ReturnType.Catalog
fnRetSchema = fn.ReturnType.Schema
fnRetName = fn.ReturnType.Name
}
fnsWithName, err := c.catalog.ListFuncsByName(&ast.FuncName{
Catalog: fnRetCatalog,
Schema: fnRetSchema,
Name: fnRetName,
})
// If the function was found, build a table structure to hold the columns to output.
if err == nil && len(fnsWithName) == 1 {
fnWithName := fnsWithName[0]
rel := &ast.TableName{
Catalog: fnRetCatalog,
Schema: fnRetSchema,
Name: fnRetName,
}
var cols []*Column
for _, arg := range fnWithName.Args {
if arg.Mode == ast.FuncParamTable {
col := &catalog.Column{
Name: arg.Name,
Type: *arg.Type,
IsArray: arg.IsArray,
}
convertedCol := ConvertColumn(rel, col)
cols = append(cols, convertedCol)
}
}
tables = append(tables, &Table{
Rel: rel,
Columns: cols,
})
continue
}
var table *Table
if fn.ReturnType != nil {
table, err = qc.GetTable(&ast.TableName{
Catalog: fn.ReturnType.Catalog,
Schema: fn.ReturnType.Schema,
Name: fn.ReturnType.Name,
Catalog: fnRetCatalog,
Schema: fnRetSchema,
Name: fnRetName,
})
}
if table == nil || err != nil {
Expand Down
14 changes: 12 additions & 2 deletions internal/compiler/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,19 +366,27 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
continue
}

var paramName string
var paramType *ast.TypeName
var (
paramName string
paramType *ast.TypeName
isArray bool
arrayDims int
)

if argName == "" {
if i < len(fun.Args) {
paramName = fun.Args[i].Name
paramType = fun.Args[i].Type
isArray = fun.Args[i].IsArray
arrayDims = fun.Args[i].ArrayDims
}
} else {
paramName = argName
for _, arg := range fun.Args {
if arg.Name == argName {
paramType = arg.Type
isArray = arg.IsArray
arrayDims = arg.ArrayDims
}
}
if paramType == nil {
Expand All @@ -402,6 +410,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
NotNull: p.NotNull(),
IsNamedParam: isNamed,
IsSqlcSlice: p.IsSqlcSlice(),
IsArray: isArray,
ArrayDims: arrayDims,
},
})
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
-- name: Foo :one
SELECT * FROM register_account('a', 'b');

-- name: GetAccount :one
SELECT * FROM get_account($1, $2);
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,27 @@ BEGIN

RETURN NEXT;
END;
$$ LANGUAGE plpgsql;
$$
LANGUAGE plpgsql;

CREATE OR REPLACE FUNCTION get_account(
_account_id INTEGER,
_tags TEXT[][] -- test multidimensional array code generation
)
RETURNS TABLE(
account_id INTEGER,
username TEXT
)
AS $$
BEGIN
SELECT
account_id,
username
FROM
accounts
WHERE
account_id = _account_id;
END;
$$
LANGUAGE plpgsql;

8 changes: 5 additions & 3 deletions internal/engine/postgresql/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,11 @@ func translate(node *nodes.Node) (ast.Node, error) {
return nil, err
}
fp := &ast.FuncParam{
Name: &arg.Name,
Type: rel.TypeName(),
Mode: mode,
Name: &arg.Name,
Type: rel.TypeName(),
Mode: mode,
IsArray: isArray(arg.ArgType),
ArrayDims: len(arg.ArgType.ArrayBounds),
}
if arg.Defexpr != nil {
fp.DefExpr = &ast.TODO{}
Expand Down
10 changes: 6 additions & 4 deletions internal/sql/ast/func_param.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ const (
)

type FuncParam struct {
Name *string
Type *TypeName
DefExpr Node // Will always be &ast.TODO
Mode FuncParamMode
Name *string
Type *TypeName
DefExpr Node // Will always be &ast.TODO
Mode FuncParamMode
IsArray bool
ArrayDims int
}

func (n *FuncParam) Pos() int {
Expand Down
1 change: 1 addition & 0 deletions internal/sql/ast/function_parameter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ type FunctionParameter struct {
ArgType *TypeName
Mode FunctionParameterMode
Defexpr Node
IsArray bool
}

func (n *FunctionParameter) Pos() int {
Expand Down
21 changes: 21 additions & 0 deletions internal/sql/catalog/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ type Argument struct {
Type *ast.TypeName
HasDefault bool
Mode ast.FuncParamMode
IsArray bool
ArrayDims int
}

func (f *Function) InArgs() []*Argument {
Expand Down Expand Up @@ -65,21 +67,40 @@ func (c *Catalog) createFunction(stmt *ast.CreateFunctionStmt) error {
ReturnType: stmt.ReturnType,
}
types := make([]*ast.TypeName, len(stmt.Params.Items))
var cols []*ast.ColumnDef
for i, item := range stmt.Params.Items {
arg := item.(*ast.FuncParam)
var name string
if arg.Name != nil {
name = *arg.Name
}
if arg.Mode == ast.FuncParamTable {
cols = append(cols, &ast.ColumnDef{
Colname: name,
TypeName: arg.Type,
IsArray: arg.IsArray,
ArrayDims: arg.ArrayDims,
})
}
fn.Args[i] = &Argument{
Name: name,
Type: arg.Type,
Mode: arg.Mode,
HasDefault: arg.DefExpr != nil,
IsArray: arg.IsArray,
ArrayDims: arg.ArrayDims,
}
types[i] = arg.Type
}

if len(cols) > 0 {
fn.ReturnType = &ast.TypeName{
Name: stmt.Func.Name,
Schema: stmt.Func.Schema,
Catalog: stmt.Func.Catalog,
}
}

_, idx, err := s.getFunc(stmt.Func, types)
if err == nil && !stmt.Replace {
return sqlerr.RelationExists(stmt.Func.Name)
Expand Down