diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 5b96a08567..d9e0b52738 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -318,9 +318,22 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er } fun, err := qc.catalog.ResolveFuncCall(n) if err == nil { + var tableCols []*catalog.Argument + for _, arg := range fun.Args { + if arg.Mode == ast.FuncParamTable { + tableCols = append(tableCols, arg) + } + } + var dt string + if len(tableCols) == 1 { + // A single column will later generate a scalar (or slice of scalar) return type. + dt = dataType(tableCols[0].Type) + } else { + dt = dataType(fun.ReturnType) + } cols = append(cols, &Column{ Name: name, - DataType: dataType(fun.ReturnType), + DataType: dt, NotNull: !fun.ReturnTypeNullable, IsFuncCall: true, }) @@ -553,12 +566,53 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro if err != nil { continue } + var ( + fnRetCatalog string + fnRetSchema string + fnRetName string + ) + if fn.ReturnType != nil { + fnRetCatalog = fn.ReturnType.Catalog + fnRetSchema = fn.ReturnType.Schema + fnRetName = fn.ReturnType.Name + } + fnsWithName, err := c.catalog.ListFuncsByName(&ast.FuncName{ + Catalog: fnRetCatalog, + Schema: fnRetSchema, + Name: fnRetName, + }) + // If the function was found, build a table structure to hold the columns to output. + if err == nil && len(fnsWithName) == 1 { + fnWithName := fnsWithName[0] + rel := &ast.TableName{ + Catalog: fnRetCatalog, + Schema: fnRetSchema, + Name: fnRetName, + } + var cols []*Column + for _, arg := range fnWithName.Args { + if arg.Mode == ast.FuncParamTable { + col := &catalog.Column{ + Name: arg.Name, + Type: *arg.Type, + IsArray: arg.IsArray, + } + convertedCol := ConvertColumn(rel, col) + cols = append(cols, convertedCol) + } + } + tables = append(tables, &Table{ + Rel: rel, + Columns: cols, + }) + continue + } var table *Table if fn.ReturnType != nil { table, err = qc.GetTable(&ast.TableName{ - Catalog: fn.ReturnType.Catalog, - Schema: fn.ReturnType.Schema, - Name: fn.ReturnType.Name, + Catalog: fnRetCatalog, + Schema: fnRetSchema, + Name: fnRetName, }) } if table == nil || err != nil { diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index eb4315a47f..dd4061afd5 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -366,19 +366,27 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } - var paramName string - var paramType *ast.TypeName + var ( + paramName string + paramType *ast.TypeName + isArray bool + arrayDims int + ) if argName == "" { if i < len(fun.Args) { paramName = fun.Args[i].Name paramType = fun.Args[i].Type + isArray = fun.Args[i].IsArray + arrayDims = fun.Args[i].ArrayDims } } else { paramName = argName for _, arg := range fun.Args { if arg.Name == argName { paramType = arg.Type + isArray = arg.IsArray + arrayDims = arg.ArrayDims } } if paramType == nil { @@ -402,6 +410,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, NotNull: p.NotNull(), IsNamedParam: isNamed, IsSqlcSlice: p.IsSqlcSlice(), + IsArray: isArray, + ArrayDims: arrayDims, }, }) } diff --git a/internal/endtoend/testdata/func_return_table/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/func_return_table/postgresql/pgx/go/query.sql.go index 7b34360e60..aa0427ffd5 100644 --- a/internal/endtoend/testdata/func_return_table/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/func_return_table/postgresql/pgx/go/query.sql.go @@ -12,12 +12,33 @@ import ( ) const foo = `-- name: Foo :one -SELECT register_account FROM register_account('a', 'b') +SELECT account_id FROM register_account('a', 'b') ` func (q *Queries) Foo(ctx context.Context) (pgtype.Int4, error) { row := q.db.QueryRow(ctx, foo) - var register_account pgtype.Int4 - err := row.Scan(®ister_account) - return register_account, err + var account_id pgtype.Int4 + err := row.Scan(&account_id) + return account_id, err +} + +const getAccount = `-- name: GetAccount :one +SELECT account_id, username FROM get_account($1, $2) +` + +type GetAccountParams struct { + AccountID int32 + Tags [][]string +} + +type GetAccountRow struct { + AccountID pgtype.Int4 + Username pgtype.Text +} + +func (q *Queries) GetAccount(ctx context.Context, arg GetAccountParams) (GetAccountRow, error) { + row := q.db.QueryRow(ctx, getAccount, arg.AccountID, arg.Tags) + var i GetAccountRow + err := row.Scan(&i.AccountID, &i.Username) + return i, err } diff --git a/internal/endtoend/testdata/func_return_table/postgresql/pgx/query.sql b/internal/endtoend/testdata/func_return_table/postgresql/pgx/query.sql index 8a3db0f6d8..a8a2712639 100644 --- a/internal/endtoend/testdata/func_return_table/postgresql/pgx/query.sql +++ b/internal/endtoend/testdata/func_return_table/postgresql/pgx/query.sql @@ -1,2 +1,5 @@ -- name: Foo :one SELECT * FROM register_account('a', 'b'); + +-- name: GetAccount :one +SELECT * FROM get_account($1, $2); diff --git a/internal/endtoend/testdata/func_return_table/postgresql/pgx/schema.sql b/internal/endtoend/testdata/func_return_table/postgresql/pgx/schema.sql index e9ebf5e423..8f967da48d 100644 --- a/internal/endtoend/testdata/func_return_table/postgresql/pgx/schema.sql +++ b/internal/endtoend/testdata/func_return_table/postgresql/pgx/schema.sql @@ -24,4 +24,27 @@ BEGIN RETURN NEXT; END; -$$ LANGUAGE plpgsql; +$$ +LANGUAGE plpgsql; + +CREATE OR REPLACE FUNCTION get_account( + _account_id INTEGER, + _tags TEXT[][] -- test multidimensional array code generation +) +RETURNS TABLE( + account_id INTEGER, + username TEXT +) +AS $$ +BEGIN + SELECT + account_id, + username + FROM + accounts + WHERE + account_id = _account_id; +END; +$$ +LANGUAGE plpgsql; + diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index 296a14e858..52d9c4754e 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -509,9 +509,11 @@ func translate(node *nodes.Node) (ast.Node, error) { return nil, err } fp := &ast.FuncParam{ - Name: &arg.Name, - Type: rel.TypeName(), - Mode: mode, + Name: &arg.Name, + Type: rel.TypeName(), + Mode: mode, + IsArray: isArray(arg.ArgType), + ArrayDims: len(arg.ArgType.ArrayBounds), } if arg.Defexpr != nil { fp.DefExpr = &ast.TODO{} diff --git a/internal/sql/ast/func_param.go b/internal/sql/ast/func_param.go index b5cf8cfcf0..c8a64136a5 100644 --- a/internal/sql/ast/func_param.go +++ b/internal/sql/ast/func_param.go @@ -12,10 +12,12 @@ const ( ) type FuncParam struct { - Name *string - Type *TypeName - DefExpr Node // Will always be &ast.TODO - Mode FuncParamMode + Name *string + Type *TypeName + DefExpr Node // Will always be &ast.TODO + Mode FuncParamMode + IsArray bool + ArrayDims int } func (n *FuncParam) Pos() int { diff --git a/internal/sql/ast/function_parameter.go b/internal/sql/ast/function_parameter.go index 54262f6130..8f2961b346 100644 --- a/internal/sql/ast/function_parameter.go +++ b/internal/sql/ast/function_parameter.go @@ -5,6 +5,7 @@ type FunctionParameter struct { ArgType *TypeName Mode FunctionParameterMode Defexpr Node + IsArray bool } func (n *FunctionParameter) Pos() int { diff --git a/internal/sql/catalog/func.go b/internal/sql/catalog/func.go index e170777311..738a0e2886 100644 --- a/internal/sql/catalog/func.go +++ b/internal/sql/catalog/func.go @@ -24,6 +24,8 @@ type Argument struct { Type *ast.TypeName HasDefault bool Mode ast.FuncParamMode + IsArray bool + ArrayDims int } func (f *Function) InArgs() []*Argument { @@ -65,21 +67,40 @@ func (c *Catalog) createFunction(stmt *ast.CreateFunctionStmt) error { ReturnType: stmt.ReturnType, } types := make([]*ast.TypeName, len(stmt.Params.Items)) + var cols []*ast.ColumnDef for i, item := range stmt.Params.Items { arg := item.(*ast.FuncParam) var name string if arg.Name != nil { name = *arg.Name } + if arg.Mode == ast.FuncParamTable { + cols = append(cols, &ast.ColumnDef{ + Colname: name, + TypeName: arg.Type, + IsArray: arg.IsArray, + ArrayDims: arg.ArrayDims, + }) + } fn.Args[i] = &Argument{ Name: name, Type: arg.Type, Mode: arg.Mode, HasDefault: arg.DefExpr != nil, + IsArray: arg.IsArray, + ArrayDims: arg.ArrayDims, } types[i] = arg.Type } + if len(cols) > 0 { + fn.ReturnType = &ast.TypeName{ + Name: stmt.Func.Name, + Schema: stmt.Func.Schema, + Catalog: stmt.Func.Catalog, + } + } + _, idx, err := s.getFunc(stmt.Func, types) if err == nil && !stmt.Replace { return sqlerr.RelationExists(stmt.Func.Name)