diff --git a/go.mod b/go.mod index 1dcbf5dad8..8ede37f5de 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/jinzhu/inflection v1.0.0 github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 - github.com/pganalyze/pg_query_go/v2 v2.2.0 + github.com/pganalyze/pg_query_go/v4 v4.2.0 github.com/spf13/cobra v1.6.1 github.com/spf13/pflag v1.0.5 golang.org/x/sync v0.1.0 diff --git a/go.sum b/go.sum index 058bbd2959..00dbd5f3a7 100644 --- a/go.sum +++ b/go.sum @@ -118,8 +118,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= -github.com/pganalyze/pg_query_go/v2 v2.2.0 h1:OW+reH+ZY7jdEuPyuLGlf1m7dLbE+fDudKXhLs0Ttpk= -github.com/pganalyze/pg_query_go/v2 v2.2.0/go.mod h1:XAxmVqz1tEGqizcQ3YSdN90vCOHBWjJi8URL1er5+cA= +github.com/pganalyze/pg_query_go/v4 v4.2.0 h1:67hSBZXYfABNYisEu/Xfu6R2gupnQwaoRhQicy0HSnQ= +github.com/pganalyze/pg_query_go/v4 v4.2.0/go.mod h1:aEkDNOXNM5j0YGzaAapwJ7LB3dLNj+bvbWcLv1hOVqA= github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.5-0.20210425183316-da1aaba5fb63 h1:+FZIDR/D97YOPik4N4lPDaUcLDF/EQPogxtlHB2ZZRM= github.com/pingcap/errors v0.11.5-0.20210425183316-da1aaba5fb63/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= diff --git a/internal/compiler/compat.go b/internal/compiler/compat.go index ef8c522541..77b735df87 100644 --- a/internal/compiler/compat.go +++ b/internal/compiler/compat.go @@ -33,6 +33,11 @@ type Relation struct { func parseRelation(node ast.Node) (*Relation, error) { switch n := node.(type) { + case *ast.Boolean: + return &Relation{ + Name: "bool", + }, nil + case *ast.List: parts := stringSlice(n) switch len(parts) { diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 516b03963c..656f2daab3 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -156,6 +156,12 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { col := toColumn(tc.TypeName) col.Name = name cols = append(cols, col) + } else if aconst, ok := n.Defresult.(*ast.A_Const); ok { + tn, err := ParseTypeName(aconst.Val) + if err != nil { + return nil, err + } + cols = append(cols, &Column{Name: name, DataType: dataType(tn), NotNull: true}) } else { cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false}) } diff --git a/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v4/go/query.sql.go b/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v4/go/query.sql.go index 99356b3997..bbec8640c3 100644 --- a/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v4/go/query.sql.go +++ b/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v4/go/query.sql.go @@ -35,7 +35,7 @@ func (q *Queries) GetAll(ctx context.Context) ([]*Foo, error) { } const getAllAByB = `-- name: GetAllAByB :many -SELECT a FROM foo WHERE b = ? +SELECT a FROM foo WHERE b = $1 ` func (q *Queries) GetAllAByB(ctx context.Context, b sql.NullInt32) ([]sql.NullInt32, error) { @@ -59,7 +59,7 @@ func (q *Queries) GetAllAByB(ctx context.Context, b sql.NullInt32) ([]sql.NullIn } const getOne = `-- name: GetOne :one -SELECT a, b FROM foo WHERE a = ? AND b = ? LIMIT 1 +SELECT a, b FROM foo WHERE a = $1 AND b = $2 LIMIT 1 ` type GetOneParams struct { diff --git a/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v4/query.sql b/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v4/query.sql index 7fc9b6b946..1c60a1dfae 100644 --- a/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v4/query.sql +++ b/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v4/query.sql @@ -7,10 +7,10 @@ ON CONFLICT DO NOTHING RETURNING *; -- name: GetOne :one -SELECT * FROM foo WHERE a = ? AND b = ? LIMIT 1; +SELECT * FROM foo WHERE a = $1 AND b = $2 LIMIT 1; -- name: GetAll :many SELECT * FROM foo; -- name: GetAllAByB :many -SELECT a FROM foo WHERE b = ?; +SELECT a FROM foo WHERE b = $1; diff --git a/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v5/go/query.sql.go b/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v5/go/query.sql.go index 0dc08bebc5..52947e5563 100644 --- a/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v5/go/query.sql.go +++ b/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v5/go/query.sql.go @@ -36,7 +36,7 @@ func (q *Queries) GetAll(ctx context.Context) ([]*Foo, error) { } const getAllAByB = `-- name: GetAllAByB :many -SELECT a FROM foo WHERE b = ? +SELECT a FROM foo WHERE b = $1 ` func (q *Queries) GetAllAByB(ctx context.Context, b pgtype.Int4) ([]pgtype.Int4, error) { @@ -60,7 +60,7 @@ func (q *Queries) GetAllAByB(ctx context.Context, b pgtype.Int4) ([]pgtype.Int4, } const getOne = `-- name: GetOne :one -SELECT a, b FROM foo WHERE a = ? AND b = ? LIMIT 1 +SELECT a, b FROM foo WHERE a = $1 AND b = $2 LIMIT 1 ` type GetOneParams struct { diff --git a/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v5/query.sql b/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v5/query.sql index 7fc9b6b946..1c60a1dfae 100644 --- a/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v5/query.sql +++ b/internal/endtoend/testdata/emit_result_and_params_struct_pointers/postgresql/pgx/v5/query.sql @@ -7,10 +7,10 @@ ON CONFLICT DO NOTHING RETURNING *; -- name: GetOne :one -SELECT * FROM foo WHERE a = ? AND b = ? LIMIT 1; +SELECT * FROM foo WHERE a = $1 AND b = $2 LIMIT 1; -- name: GetAll :many SELECT * FROM foo; -- name: GetAllAByB :many -SELECT a FROM foo WHERE b = ?; +SELECT a FROM foo WHERE b = $1; diff --git a/internal/endtoend/testdata/invalid_func_args/pgx/v4/query.sql b/internal/endtoend/testdata/invalid_func_args/pgx/v4/query.sql index 02ff72d1f3..46dbe9fa2d 100644 --- a/internal/endtoend/testdata/invalid_func_args/pgx/v4/query.sql +++ b/internal/endtoend/testdata/invalid_func_args/pgx/v4/query.sql @@ -1,2 +1 @@ SELECT random(1); -SELECT position(); diff --git a/internal/endtoend/testdata/invalid_func_args/pgx/v4/stderr.txt b/internal/endtoend/testdata/invalid_func_args/pgx/v4/stderr.txt index df8159d449..09d46d1f95 100644 --- a/internal/endtoend/testdata/invalid_func_args/pgx/v4/stderr.txt +++ b/internal/endtoend/testdata/invalid_func_args/pgx/v4/stderr.txt @@ -1,3 +1,2 @@ # package querytest query.sql:1:8: function random(unknown) does not exist -query.sql:2:8: function position() does not exist diff --git a/internal/endtoend/testdata/invalid_func_args/pgx/v5/query.sql b/internal/endtoend/testdata/invalid_func_args/pgx/v5/query.sql index 02ff72d1f3..46dbe9fa2d 100644 --- a/internal/endtoend/testdata/invalid_func_args/pgx/v5/query.sql +++ b/internal/endtoend/testdata/invalid_func_args/pgx/v5/query.sql @@ -1,2 +1 @@ SELECT random(1); -SELECT position(); diff --git a/internal/endtoend/testdata/invalid_func_args/pgx/v5/stderr.txt b/internal/endtoend/testdata/invalid_func_args/pgx/v5/stderr.txt index df8159d449..09d46d1f95 100644 --- a/internal/endtoend/testdata/invalid_func_args/pgx/v5/stderr.txt +++ b/internal/endtoend/testdata/invalid_func_args/pgx/v5/stderr.txt @@ -1,3 +1,2 @@ # package querytest query.sql:1:8: function random(unknown) does not exist -query.sql:2:8: function position() does not exist diff --git a/internal/endtoend/testdata/invalid_func_args/stdlib/query.sql b/internal/endtoend/testdata/invalid_func_args/stdlib/query.sql index 02ff72d1f3..46dbe9fa2d 100644 --- a/internal/endtoend/testdata/invalid_func_args/stdlib/query.sql +++ b/internal/endtoend/testdata/invalid_func_args/stdlib/query.sql @@ -1,2 +1 @@ SELECT random(1); -SELECT position(); diff --git a/internal/endtoend/testdata/invalid_func_args/stdlib/stderr.txt b/internal/endtoend/testdata/invalid_func_args/stdlib/stderr.txt index df8159d449..09d46d1f95 100644 --- a/internal/endtoend/testdata/invalid_func_args/stdlib/stderr.txt +++ b/internal/endtoend/testdata/invalid_func_args/stdlib/stderr.txt @@ -1,3 +1,2 @@ # package querytest query.sql:1:8: function random(unknown) does not exist -query.sql:2:8: function position() does not exist diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go index a37180dffd..5a5132cf0e 100644 --- a/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go +++ b/internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go @@ -9,23 +9,6 @@ import ( "context" ) -const countFour = `-- name: CountFour :one -SELECT count(1) FROM bar WHERE id > ? AND phone <> ? AND name <> ? -` - -type CountFourParams struct { - ID int32 - Phone string - Name string -} - -func (q *Queries) CountFour(ctx context.Context, arg CountFourParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countFour, arg.ID, arg.Phone, arg.Name) - var count int64 - err := row.Scan(&count) - return count, err -} - const countOne = `-- name: CountOne :one SELECT count(1) FROM bar WHERE id = $2 AND name <> $1 LIMIT $3 ` diff --git a/internal/endtoend/testdata/mix_param_types/postgresql/test.sql b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql index 9ec77fb270..411f99829f 100644 --- a/internal/endtoend/testdata/mix_param_types/postgresql/test.sql +++ b/internal/endtoend/testdata/mix_param_types/postgresql/test.sql @@ -12,6 +12,3 @@ SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name); -- name: CountThree :one SELECT count(1) FROM bar WHERE id > $2 AND phone <> sqlc.arg(phone) AND name <> $1; - --- name: CountFour :one -SELECT count(1) FROM bar WHERE id > ? AND phone <> sqlc.arg(phone) AND name <> ?; diff --git a/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/go/query.sql.go b/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/go/query.sql.go index cc5b48e563..7ab4aaa3e1 100644 --- a/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/go/query.sql.go +++ b/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/go/query.sql.go @@ -7,11 +7,10 @@ package querytest import ( "context" - "database/sql" ) const findByID = `-- name: FindByID :many -SELECT id, name FROM users WHERE ? = id +SELECT id, name FROM users WHERE $1 = id ` func (q *Queries) FindByID(ctx context.Context, id int32) ([]User, error) { @@ -38,16 +37,11 @@ func (q *Queries) FindByID(ctx context.Context, id int32) ([]User, error) { } const findByIDAndName = `-- name: FindByIDAndName :many -SELECT id, name FROM users WHERE ? = id AND ? = name +SELECT id, name FROM users WHERE $1 = id AND $1 = name ` -type FindByIDAndNameParams struct { - ID int32 - Name sql.NullString -} - -func (q *Queries) FindByIDAndName(ctx context.Context, arg FindByIDAndNameParams) ([]User, error) { - rows, err := q.db.QueryContext(ctx, findByIDAndName, arg.ID, arg.Name) +func (q *Queries) FindByIDAndName(ctx context.Context, id int32) ([]User, error) { + rows, err := q.db.QueryContext(ctx, findByIDAndName, id) if err != nil { return nil, err } diff --git a/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/query.sql b/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/query.sql index 01655c8b62..042d4e6f58 100644 --- a/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/query.sql +++ b/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/query.sql @@ -4,7 +4,7 @@ CREATE TABLE users ( ); -- name: FindByID :many -SELECT * FROM users WHERE ? = id; +SELECT * FROM users WHERE $1 = id; -- name: FindByIDAndName :many -SELECT * FROM users WHERE ? = id AND ? = name; +SELECT * FROM users WHERE $1 = id AND $1 = name; diff --git a/internal/engine/postgresql/convert.go b/internal/engine/postgresql/convert.go index 42fd4cbfc4..617454875c 100644 --- a/internal/engine/postgresql/convert.go +++ b/internal/engine/postgresql/convert.go @@ -6,7 +6,7 @@ package postgresql import ( "fmt" - pg "github.com/pganalyze/pg_query_go/v2" + pg "github.com/pganalyze/pg_query_go/v4" "github.com/kyleconroy/sqlc/internal/sql/ast" ) @@ -23,6 +23,8 @@ func convertFuncParamMode(m pg.FunctionParameterMode) (ast.FuncParamMode, error) return ast.FuncParamVariadic, nil case pg.FunctionParameterMode_FUNC_PARAM_TABLE: return ast.FuncParamTable, nil + case pg.FunctionParameterMode_FUNC_PARAM_DEFAULT: + return ast.FuncParamDefault, nil default: return -1, fmt.Errorf("parse func param: invalid mode %v", m) } @@ -112,8 +114,25 @@ func convertA_Const(n *pg.A_Const) *ast.A_Const { if n == nil { return nil } + var val ast.Node + if n.Isnull { + val = &ast.Null{} + } else { + switch v := n.Val.(type) { + case *pg.A_Const_Boolval: + val = convertBoolean(v.Boolval) + case *pg.A_Const_Bsval: + val = convertBitString(v.Bsval) + case *pg.A_Const_Fval: + val = convertFloat(v.Fval) + case *pg.A_Const_Ival: + val = convertInteger(v.Ival) + case *pg.A_Const_Sval: + val = convertString(v.Sval) + } + } return &ast.A_Const{ - Val: convertNode(n.Val), + Val: val, Location: int(n.Location), } } @@ -345,7 +364,7 @@ func convertAlterObjectDependsStmt(n *pg.AlterObjectDependsStmt) *ast.AlterObjec ObjectType: ast.ObjectType(n.ObjectType), Relation: convertRangeVar(n.Relation), Object: convertNode(n.Object), - Extname: convertNode(n.Extname), + Extname: convertString(n.Extname), } } @@ -416,9 +435,9 @@ func convertAlterPublicationStmt(n *pg.AlterPublicationStmt) *ast.AlterPublicati return &ast.AlterPublicationStmt{ Pubname: makeString(n.Pubname), Options: convertSlice(n.Options), - Tables: convertSlice(n.Tables), + Tables: convertSlice(n.Pubobjects), ForAllTables: n.ForAllTables, - TableAction: ast.DefElemAction(n.TableAction), + TableAction: ast.DefElemAction(n.Action), } } @@ -550,7 +569,7 @@ func convertAlterTableStmt(n *pg.AlterTableStmt) *ast.AlterTableStmt { return &ast.AlterTableStmt{ Relation: convertRangeVar(n.Relation), Cmds: convertSlice(n.Cmds), - Relkind: ast.ObjectType(n.Relkind), + Relkind: ast.ObjectType(n.Objtype), MissingOk: n.MissingOk, } } @@ -611,7 +630,7 @@ func convertBitString(n *pg.BitString) *ast.BitString { return nil } return &ast.BitString{ - Str: n.Str, + Str: n.Bsval, } } @@ -627,6 +646,15 @@ func convertBoolExpr(n *pg.BoolExpr) *ast.BoolExpr { } } +func convertBoolean(n *pg.Boolean) *ast.Boolean { + if n == nil { + return nil + } + return &ast.Boolean{ + Boolval: n.Boolval, + } +} + func convertBooleanTest(n *pg.BooleanTest) *ast.BooleanTest { if n == nil { return nil @@ -1165,7 +1193,7 @@ func convertCreatePublicationStmt(n *pg.CreatePublicationStmt) *ast.CreatePublic return &ast.CreatePublicationStmt{ Pubname: makeString(n.Pubname), Options: convertSlice(n.Options), - Tables: convertSlice(n.Tables), + Tables: convertSlice(n.Pubobjects), ForAllTables: n.ForAllTables, } } @@ -1267,7 +1295,7 @@ func convertCreateTableAsStmt(n *pg.CreateTableAsStmt) *ast.CreateTableAsStmt { res := &ast.CreateTableAsStmt{ Query: convertNode(n.Query), Into: convertIntoClause(n.Into), - Relkind: ast.ObjectType(n.Relkind), + Relkind: ast.ObjectType(n.Objtype), IsSelectInto: n.IsSelectInto, IfNotExists: n.IfNotExists, } @@ -1528,13 +1556,6 @@ func convertExplainStmt(n *pg.ExplainStmt) *ast.ExplainStmt { } } -func convertExpr(n *pg.Expr) *ast.Expr { - if n == nil { - return nil - } - return &ast.Expr{} -} - func convertFetchStmt(n *pg.FetchStmt) *ast.FetchStmt { if n == nil { return nil @@ -1579,7 +1600,7 @@ func convertFloat(n *pg.Float) *ast.Float { return nil } return &ast.Float{ - Str: n.Str, + Str: n.Fval, } } @@ -1948,13 +1969,6 @@ func convertNotifyStmt(n *pg.NotifyStmt) *ast.NotifyStmt { } } -func convertNull(n *pg.Null) *ast.Null { - if n == nil { - return nil - } - return &ast.Null{} -} - func convertNullTest(n *pg.NullTest) *ast.NullTest { if n == nil { return nil @@ -2353,7 +2367,7 @@ func convertReindexStmt(n *pg.ReindexStmt) *ast.ReindexStmt { Kind: ast.ReindexObjectType(n.Kind), Relation: convertRangeVar(n.Relation), Name: makeString(n.Name), - Options: int(n.Options), + // Options: int(n.Options), TODO: Support params } } @@ -2611,7 +2625,7 @@ func convertString(n *pg.String) *ast.String { return nil } return &ast.String{ - Str: n.Str, + Str: n.Sval, } } @@ -3118,6 +3132,9 @@ func convertNode(node *pg.Node) ast.Node { case *pg.Node_BoolExpr: return convertBoolExpr(n.BoolExpr) + case *pg.Node_Boolean: + return convertBoolean(n.Boolean) + case *pg.Node_BooleanTest: return convertBooleanTest(n.BooleanTest) @@ -3328,9 +3345,6 @@ func convertNode(node *pg.Node) ast.Node { case *pg.Node_ExplainStmt: return convertExplainStmt(n.ExplainStmt) - case *pg.Node_Expr: - return convertExpr(n.Expr) - case *pg.Node_FetchStmt: return convertFetchStmt(n.FetchStmt) @@ -3427,9 +3441,6 @@ func convertNode(node *pg.Node) ast.Node { case *pg.Node_NotifyStmt: return convertNotifyStmt(n.NotifyStmt) - case *pg.Node_Null: - return convertNull(n.Null) - case *pg.Node_NullTest: return convertNullTest(n.NullTest) diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index 4aba7360f0..11d6982466 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -9,7 +9,7 @@ import ( "io" "strings" - nodes "github.com/pganalyze/pg_query_go/v2" + nodes "github.com/pganalyze/pg_query_go/v4" "github.com/kyleconroy/sqlc/internal/metadata" "github.com/kyleconroy/sqlc/internal/sql/ast" @@ -19,7 +19,7 @@ func stringSlice(list *nodes.List) []string { items := []string{} for _, item := range list.Items { if n, ok := item.Node.(*nodes.Node_String_); ok { - items = append(items, n.String_.Str) + items = append(items, n.String_.Sval) } } return items @@ -29,7 +29,7 @@ func stringSliceFromNodes(s []*nodes.Node) []string { var items []string for _, item := range s { if n, ok := item.Node.(*nodes.Node_String_); ok { - items = append(items, n.String_.Str) + items = append(items, n.String_.Sval) } } return items @@ -334,7 +334,7 @@ func translate(node *nodes.Node) (ast.Node, error) { return nil, fmt.Errorf("COMMENT ON SCHEMA: unexpected node type: %T", n.Object) } return &ast.CommentOnSchemaStmt{ - Schema: &ast.String{Str: o.String_.Str}, + Schema: &ast.String{Str: o.String_.Sval}, Comment: makeString(n.Comment), }, nil @@ -391,7 +391,7 @@ func translate(node *nodes.Node) (ast.Node, error) { if item.Constraint.Contype == nodes.ConstrType_CONSTR_PRIMARY { for _, key := range item.Constraint.Keys { // FIXME: Possible nil pointer dereference - primaryKey[key.Node.(*nodes.Node_String_).String_.Str] = true + primaryKey[key.Node.(*nodes.Node_String_).String_.Sval] = true } } @@ -431,7 +431,7 @@ func translate(node *nodes.Node) (ast.Node, error) { switch v := val.Node.(type) { case *nodes.Node_String_: stmt.Vals.Items = append(stmt.Vals.Items, &ast.String{ - Str: v.String_.Str, + Str: v.String_.Sval, }) } } @@ -533,7 +533,7 @@ func translate(node *nodes.Node) (ast.Node, error) { if !ok { return nil, fmt.Errorf("nodes.DropStmt: SCHEMA: unknown type in objects list: %T", obj) } - drop.Schemas = append(drop.Schemas, &ast.String{Str: val.String_.Str}) + drop.Schemas = append(drop.Schemas, &ast.String{Str: val.String_.Sval}) } return drop, nil diff --git a/internal/engine/postgresql/utils.go b/internal/engine/postgresql/utils.go index 2f6396fec6..d0ddc392a1 100644 --- a/internal/engine/postgresql/utils.go +++ b/internal/engine/postgresql/utils.go @@ -4,7 +4,7 @@ package postgresql import ( - nodes "github.com/pganalyze/pg_query_go/v2" + nodes "github.com/pganalyze/pg_query_go/v4" ) func isArray(n *nodes.TypeName) bool { diff --git a/internal/sql/ast/boolean.go b/internal/sql/ast/boolean.go new file mode 100644 index 0000000000..cf193f2c12 --- /dev/null +++ b/internal/sql/ast/boolean.go @@ -0,0 +1,9 @@ +package ast + +type Boolean struct { + Boolval bool +} + +func (n *Boolean) Pos() int { + return 0 +} diff --git a/internal/sql/ast/func_param.go b/internal/sql/ast/func_param.go index faee6ede37..b5cf8cfcf0 100644 --- a/internal/sql/ast/func_param.go +++ b/internal/sql/ast/func_param.go @@ -8,6 +8,7 @@ const ( FuncParamInOut FuncParamVariadic FuncParamTable + FuncParamDefault ) type FuncParam struct { diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index d9591c04ec..e38bbdfb3d 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -508,6 +508,9 @@ func Walk(f Visitor, node ast.Node) { Walk(f, n.Args) } + case *ast.Boolean: + // pass + case *ast.BooleanTest: if n.Xpr != nil { Walk(f, n.Xpr)