diff --git a/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..35e5f4a4b6 --- /dev/null +++ b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..4972835855 --- /dev/null +++ b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/models.go @@ -0,0 +1,90 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 + +package db + +import ( + "database/sql/driver" + "fmt" +) + +type EnumType string + +const ( + EnumTypeBeforefirst EnumType = "beforefirst" + EnumTypeFirst EnumType = "first" + EnumTypeSecond EnumType = "second" + EnumTypeThird EnumType = "third" + EnumTypeFourth EnumType = "fourth" + EnumTypeFifth EnumType = "fifth" + EnumTypeLast EnumType = "last" + EnumTypeAfterlast EnumType = "afterlast" +) + +func (e *EnumType) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = EnumType(s) + case string: + *e = EnumType(s) + default: + return fmt.Errorf("unsupported scan type for EnumType: %T", src) + } + return nil +} + +type NullEnumType struct { + EnumType EnumType + Valid bool // Valid is true if EnumType is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullEnumType) Scan(value interface{}) error { + if value == nil { + ns.EnumType, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.EnumType.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullEnumType) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.EnumType), nil +} + +func (e EnumType) Valid() bool { + switch e { + case EnumTypeBeforefirst, + EnumTypeFirst, + EnumTypeSecond, + EnumTypeThird, + EnumTypeFourth, + EnumTypeFifth, + EnumTypeLast, + EnumTypeAfterlast: + return true + } + return false +} + +func AllEnumTypeValues() []EnumType { + return []EnumType{ + EnumTypeBeforefirst, + EnumTypeFirst, + EnumTypeSecond, + EnumTypeThird, + EnumTypeFourth, + EnumTypeFifth, + EnumTypeLast, + EnumTypeAfterlast, + } +} + +type Foo struct { + ID int32 +} diff --git a/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/querier.go b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/querier.go new file mode 100644 index 0000000000..44b8ca8e6d --- /dev/null +++ b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/querier.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 + +package db + +import ( + "context" +) + +type Querier interface { + GetAll(ctx context.Context) ([]int32, error) +} + +var _ Querier = (*Queries)(nil) diff --git a/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..76543c04bd --- /dev/null +++ b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,37 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 +// source: query.sql + +package db + +import ( + "context" +) + +const getAll = `-- name: GetAll :many +SELECT id FROM foo +` + +func (q *Queries) GetAll(ctx context.Context) ([]int32, error) { + rows, err := q.db.QueryContext(ctx, getAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []int32 + for rows.Next() { + var id int32 + if err := rows.Scan(&id); err != nil { + return nil, err + } + items = append(items, id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/query.sql b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..202b352652 --- /dev/null +++ b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/query.sql @@ -0,0 +1,2 @@ +-- name: GetAll :many +SELECT * FROM foo; \ No newline at end of file diff --git a/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..679091891e --- /dev/null +++ b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/schema.sql @@ -0,0 +1,11 @@ +CREATE TYPE enum_type AS ENUM ('first', 'last'); +ALTER TYPE enum_type ADD VALUE 'afterlast' AFTER 'last'; +ALTER TYPE enum_type ADD VALUE 'third' AFTER 'first'; +ALTER TYPE enum_type ADD VALUE 'fourth' BEFORE 'last'; +ALTER TYPE enum_type ADD VALUE 'fifth' AFTER 'fourth'; +ALTER TYPE enum_type ADD VALUE 'second' BEFORE 'third'; +ALTER TYPE enum_type ADD VALUE 'beforefirst' BEFORE 'first'; + +CREATE TABLE foo ( + id SERIAL PRIMARY KEY +); \ No newline at end of file diff --git a/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/sqlc.json b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/sqlc.json new file mode 100644 index 0000000000..b0d556848e --- /dev/null +++ b/internal/endtoend/testdata/enum_ordering/postgresql/stdlib/sqlc.json @@ -0,0 +1,19 @@ +{ + "version": "2", + "sql": [ + { + "engine": "postgresql", + "schema": "schema.sql", + "queries": "query.sql", + "gen": { + "go": { + "out" : "go", + "package" : "db", + "emit_interface": true, + "emit_all_enum_values": true, + "emit_enum_valid_method": true + } + } + } + ] +} diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index 4aba7360f0..aeaecbc6c6 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -211,6 +211,9 @@ func translate(node *nodes.Node) (ast.Node, error) { return &ast.AlterTypeAddValueStmt{ Type: rel.TypeName(), NewValue: makeString(n.NewVal), + NewValHasNeighbor: len(n.NewValNeighbor) > 0, + NewValNeighbor: makeString(n.NewValNeighbor), + NewValIsAfter: n.NewValIsAfter, SkipIfNewValExists: n.SkipIfNewValExists, }, nil } diff --git a/internal/sql/ast/alter_type_add_value_stmt.go b/internal/sql/ast/alter_type_add_value_stmt.go index 58085bbac3..56ae7dd9b7 100644 --- a/internal/sql/ast/alter_type_add_value_stmt.go +++ b/internal/sql/ast/alter_type_add_value_stmt.go @@ -3,6 +3,9 @@ package ast type AlterTypeAddValueStmt struct { Type *TypeName NewValue *string + NewValHasNeighbor bool + NewValNeighbor *string + NewValIsAfter bool SkipIfNewValExists bool } diff --git a/internal/sql/catalog/types.go b/internal/sql/catalog/types.go index 8b9c656411..9f1b7f54d7 100644 --- a/internal/sql/catalog/types.go +++ b/internal/sql/catalog/types.go @@ -3,7 +3,6 @@ package catalog import ( "errors" "fmt" - "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) @@ -197,20 +196,47 @@ func (c *Catalog) alterTypeAddValue(stmt *ast.AlterTypeAddValueStmt) error { return fmt.Errorf("type is not an enum: %T", stmt.Type) } - newIndex := -1 + existingIndex := -1 for i, val := range enum.Vals { if val == *stmt.NewValue { - newIndex = i + existingIndex = i } } - if newIndex >= 0 { + + if existingIndex >= 0 { if !stmt.SkipIfNewValExists { - return fmt.Errorf("type %T already has value %s", stmt.Type, *stmt.NewValue) + return fmt.Errorf("enum %s already has value %s", enum.Name, *stmt.NewValue) } else { return nil } } - enum.Vals = append(enum.Vals, *stmt.NewValue) + + if stmt.NewValHasNeighbor { + insertIndex := -1 + for i, val := range enum.Vals { + if val == *stmt.NewValNeighbor { + if stmt.NewValIsAfter { + insertIndex = i + 1 + } else { + insertIndex = i + } + } + } + + if insertIndex == -1 { + return fmt.Errorf("enum %s unable to find existing neighbor value %s for new value %s", enum.Name, *stmt.NewValNeighbor, *stmt.NewValue) + } + + if insertIndex == len(enum.Vals) { + enum.Vals = append(enum.Vals, *stmt.NewValue) + } else { + enum.Vals = append(enum.Vals[:insertIndex+1], enum.Vals[insertIndex:]...) + enum.Vals[insertIndex] = *stmt.NewValue + } + } else { + enum.Vals = append(enum.Vals, *stmt.NewValue) + } + return nil }