From 6cc483dc74f33ebad21bebe2e44a34d79601c045 Mon Sep 17 00:00:00 2001 From: Robert Holt Date: Thu, 10 Jun 2021 16:21:05 -0400 Subject: [PATCH] Add support for SQL Views --- internal/compiler/compile.go | 7 +-- internal/compiler/engine.go | 2 +- internal/compiler/output_columns.go | 26 ++++++++ .../testdata/create_view/mysql/go/db.go | 29 +++++++++ .../testdata/create_view/mysql/go/models.go | 21 +++++++ .../create_view/mysql/go/query.sql.go | 62 ++++++++++++++++++ .../testdata/create_view/mysql/query.sql | 5 ++ .../testdata/create_view/mysql/schema.sql | 10 +++ .../testdata/create_view/mysql/sqlc.json | 12 ++++ .../testdata/create_view/postgresql/go/db.go | 29 +++++++++ .../create_view/postgresql/go/models.go | 21 +++++++ .../create_view/postgresql/go/query.sql.go | 63 +++++++++++++++++++ .../testdata/create_view/postgresql/query.sql | 5 ++ .../create_view/postgresql/schema.sql | 10 +++ .../testdata/create_view/postgresql/sqlc.json | 12 ++++ internal/engine/dolphin/convert.go | 13 ++-- internal/engine/postgresql/parse.go | 2 +- internal/sql/catalog/catalog.go | 14 ++++- internal/sql/catalog/view.go | 52 +++++++++++++++ 19 files changed, 382 insertions(+), 13 deletions(-) create mode 100644 internal/endtoend/testdata/create_view/mysql/go/db.go create mode 100644 internal/endtoend/testdata/create_view/mysql/go/models.go create mode 100644 internal/endtoend/testdata/create_view/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/create_view/mysql/query.sql create mode 100644 internal/endtoend/testdata/create_view/mysql/schema.sql create mode 100644 internal/endtoend/testdata/create_view/mysql/sqlc.json create mode 100644 internal/endtoend/testdata/create_view/postgresql/go/db.go create mode 100644 internal/endtoend/testdata/create_view/postgresql/go/models.go create mode 100644 internal/endtoend/testdata/create_view/postgresql/go/query.sql.go create mode 100644 internal/endtoend/testdata/create_view/postgresql/query.sql create mode 100644 internal/endtoend/testdata/create_view/postgresql/schema.sql create mode 100644 internal/endtoend/testdata/create_view/postgresql/sqlc.json create mode 100644 internal/sql/catalog/view.go diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index 171bc9a6c1..2f0d0b1828 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -14,7 +14,6 @@ import ( "github.com/kyleconroy/sqlc/internal/multierr" "github.com/kyleconroy/sqlc/internal/opts" "github.com/kyleconroy/sqlc/internal/sql/ast" - "github.com/kyleconroy/sqlc/internal/sql/catalog" "github.com/kyleconroy/sqlc/internal/sql/sqlerr" "github.com/kyleconroy/sqlc/internal/sql/sqlpath" ) @@ -54,7 +53,7 @@ func enumValueName(value string) string { } // end copypasta -func parseCatalog(p Parser, c *catalog.Catalog, schemas []string) error { +func (c *Compiler) parseCatalog(schemas []string) error { files, err := sqlpath.Glob(schemas) if err != nil { return err @@ -67,13 +66,13 @@ func parseCatalog(p Parser, c *catalog.Catalog, schemas []string) error { continue } contents := migrations.RemoveRollbackStatements(string(blob)) - stmts, err := p.Parse(strings.NewReader(contents)) + stmts, err := c.parser.Parse(strings.NewReader(contents)) if err != nil { merr.Add(filename, contents, 0, err) continue } for i := range stmts { - if err := c.Update(stmts[i]); err != nil { + if err := c.catalog.Update(stmts[i], c); err != nil { merr.Add(filename, contents, stmts[i].Pos(), err) continue } diff --git a/internal/compiler/engine.go b/internal/compiler/engine.go index 2c90bc315b..d0f73acbcc 100644 --- a/internal/compiler/engine.go +++ b/internal/compiler/engine.go @@ -42,7 +42,7 @@ func (c *Compiler) Catalog() *catalog.Catalog { } func (c *Compiler) ParseCatalog(schema []string) error { - return parseCatalog(c.parser, c.catalog, schema) + return c.parseCatalog(schema) } func (c *Compiler) ParseQueries(queries []string, o opts.Parser) error { diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 56e236cecc..8330f3f269 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -3,6 +3,7 @@ package compiler import ( "errors" "fmt" + "github.com/kyleconroy/sqlc/internal/sql/catalog" "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/astutils" @@ -10,6 +11,31 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) +// OutputColumns determines which columns a statement will output +func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) { + qc, err := buildQueryCatalog(c.catalog, stmt) + if err != nil { + return nil, err + } + cols, err := outputColumns(qc, stmt) + if err != nil { + return nil, err + } + + catCols := make([]*catalog.Column, 0, len(cols)) + for _, col := range cols { + catCols = append(catCols, &catalog.Column{ + Name: col.Name, + Type: ast.TypeName{Name: col.DataType}, + IsNotNull: col.NotNull, + IsArray: col.IsArray, + Comment: col.Comment, + Length: col.Length, + }) + } + return catCols, nil +} + func hasStarRef(cf *ast.ColumnRef) bool { for _, item := range cf.Fields.Items { if _, ok := item.(*ast.A_Star); ok { diff --git a/internal/endtoend/testdata/create_view/mysql/go/db.go b/internal/endtoend/testdata/create_view/mysql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/create_view/mysql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/create_view/mysql/go/models.go b/internal/endtoend/testdata/create_view/mysql/go/models.go new file mode 100644 index 0000000000..409c0fb965 --- /dev/null +++ b/internal/endtoend/testdata/create_view/mysql/go/models.go @@ -0,0 +1,21 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type FirstView struct { + Val string +} + +type Foo struct { + Val string + Val2 sql.NullInt32 +} + +type SecondView struct { + Val string + Val2 sql.NullInt32 +} diff --git a/internal/endtoend/testdata/create_view/mysql/go/query.sql.go b/internal/endtoend/testdata/create_view/mysql/go/query.sql.go new file mode 100644 index 0000000000..f2d47c7b6d --- /dev/null +++ b/internal/endtoend/testdata/create_view/mysql/go/query.sql.go @@ -0,0 +1,62 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" +) + +const getFirst = `-- name: GetFirst :many +SELECT val FROM first_view +` + +func (q *Queries) GetFirst(ctx context.Context) ([]string, error) { + rows, err := q.db.QueryContext(ctx, getFirst) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var val string + if err := rows.Scan(&val); err != nil { + return nil, err + } + items = append(items, val) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSecond = `-- name: GetSecond :many +SELECT val, val2 FROM second_view WHERE val2 = $1 +` + +func (q *Queries) GetSecond(ctx context.Context) ([]SecondView, error) { + rows, err := q.db.QueryContext(ctx, getSecond) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SecondView + for rows.Next() { + var i SecondView + if err := rows.Scan(&i.Val, &i.Val2); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/create_view/mysql/query.sql b/internal/endtoend/testdata/create_view/mysql/query.sql new file mode 100644 index 0000000000..1063db8740 --- /dev/null +++ b/internal/endtoend/testdata/create_view/mysql/query.sql @@ -0,0 +1,5 @@ +-- name: GetFirst :many +SELECT * FROM first_view; + +-- name: GetSecond :many +SELECT * FROM second_view WHERE val2 = $1; diff --git a/internal/endtoend/testdata/create_view/mysql/schema.sql b/internal/endtoend/testdata/create_view/mysql/schema.sql new file mode 100644 index 0000000000..3f94440dec --- /dev/null +++ b/internal/endtoend/testdata/create_view/mysql/schema.sql @@ -0,0 +1,10 @@ +CREATE TABLE foo (val text not null); + +CREATE VIEW first_view AS SELECT * FROM foo; +CREATE VIEW second_view AS SELECT * FROM foo; +CREATE VIEW third_view AS SELECT * FROM foo; + +ALTER TABLE foo ADD COLUMN val2 integer; +CREATE OR REPLACE VIEW second_view AS SELECT * FROM foo; + +DROP VIEW third_view; diff --git a/internal/endtoend/testdata/create_view/mysql/sqlc.json b/internal/endtoend/testdata/create_view/mysql/sqlc.json new file mode 100644 index 0000000000..974aa9ff9e --- /dev/null +++ b/internal/endtoend/testdata/create_view/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "mysql", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/create_view/postgresql/go/db.go b/internal/endtoend/testdata/create_view/postgresql/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/create_view/postgresql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/create_view/postgresql/go/models.go b/internal/endtoend/testdata/create_view/postgresql/go/models.go new file mode 100644 index 0000000000..409c0fb965 --- /dev/null +++ b/internal/endtoend/testdata/create_view/postgresql/go/models.go @@ -0,0 +1,21 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type FirstView struct { + Val string +} + +type Foo struct { + Val string + Val2 sql.NullInt32 +} + +type SecondView struct { + Val string + Val2 sql.NullInt32 +} diff --git a/internal/endtoend/testdata/create_view/postgresql/go/query.sql.go b/internal/endtoend/testdata/create_view/postgresql/go/query.sql.go new file mode 100644 index 0000000000..aeed8ad5b6 --- /dev/null +++ b/internal/endtoend/testdata/create_view/postgresql/go/query.sql.go @@ -0,0 +1,63 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const getFirst = `-- name: GetFirst :many +SELECT val FROM first_view +` + +func (q *Queries) GetFirst(ctx context.Context) ([]string, error) { + rows, err := q.db.QueryContext(ctx, getFirst) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var val string + if err := rows.Scan(&val); err != nil { + return nil, err + } + items = append(items, val) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSecond = `-- name: GetSecond :many +SELECT val, val2 FROM second_view WHERE val2 = $1 +` + +func (q *Queries) GetSecond(ctx context.Context, val2 sql.NullInt32) ([]SecondView, error) { + rows, err := q.db.QueryContext(ctx, getSecond, val2) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SecondView + for rows.Next() { + var i SecondView + if err := rows.Scan(&i.Val, &i.Val2); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/create_view/postgresql/query.sql b/internal/endtoend/testdata/create_view/postgresql/query.sql new file mode 100644 index 0000000000..1063db8740 --- /dev/null +++ b/internal/endtoend/testdata/create_view/postgresql/query.sql @@ -0,0 +1,5 @@ +-- name: GetFirst :many +SELECT * FROM first_view; + +-- name: GetSecond :many +SELECT * FROM second_view WHERE val2 = $1; diff --git a/internal/endtoend/testdata/create_view/postgresql/schema.sql b/internal/endtoend/testdata/create_view/postgresql/schema.sql new file mode 100644 index 0000000000..3f94440dec --- /dev/null +++ b/internal/endtoend/testdata/create_view/postgresql/schema.sql @@ -0,0 +1,10 @@ +CREATE TABLE foo (val text not null); + +CREATE VIEW first_view AS SELECT * FROM foo; +CREATE VIEW second_view AS SELECT * FROM foo; +CREATE VIEW third_view AS SELECT * FROM foo; + +ALTER TABLE foo ADD COLUMN val2 integer; +CREATE OR REPLACE VIEW second_view AS SELECT * FROM foo; + +DROP VIEW third_view; diff --git a/internal/endtoend/testdata/create_view/postgresql/sqlc.json b/internal/endtoend/testdata/create_view/postgresql/sqlc.json new file mode 100644 index 0000000000..cd518671ac --- /dev/null +++ b/internal/endtoend/testdata/create_view/postgresql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "postgresql", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 2f4db8342c..94d9c1b8f9 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -302,10 +302,6 @@ func (c *cc) convertDeleteStmt(n *pcast.DeleteStmt) *ast.DeleteStmt { } func (c *cc) convertDropTableStmt(n *pcast.DropTableStmt) ast.Node { - // TODO: Remove once views are supported. - if n.IsView { - return todo(n) - } drop := &ast.DropTableStmt{IfExists: n.IfExists} for _, name := range n.Tables { drop.Tables = append(drop.Tables, parseTableName(name)) @@ -667,7 +663,14 @@ func (c *cc) convertCreateUserStmt(n *pcast.CreateUserStmt) ast.Node { } func (c *cc) convertCreateViewStmt(n *pcast.CreateViewStmt) ast.Node { - return todo(n) + return &ast.ViewStmt{ + View: c.convertTableName(n.ViewName), + Aliases: &ast.List{}, + Query: c.convert(n.Select), + Replace: n.OrReplace, + Options: &ast.List{}, + WithCheckOption: ast.ViewCheckOption(n.CheckOption), + } } func (c *cc) convertDeallocateStmt(n *pcast.DeallocateStmt) ast.Node { diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index 675d0b6022..02b990a31c 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -514,7 +514,7 @@ func translate(node *nodes.Node) (ast.Node, error) { } return drop, nil - case nodes.ObjectType_OBJECT_TABLE: + case nodes.ObjectType_OBJECT_TABLE, nodes.ObjectType_OBJECT_VIEW, nodes.ObjectType_OBJECT_MATVIEW: drop := &ast.DropTableStmt{ IfExists: n.MissingOk, } diff --git a/internal/sql/catalog/catalog.go b/internal/sql/catalog/catalog.go index 460131baaa..36c466180a 100644 --- a/internal/sql/catalog/catalog.go +++ b/internal/sql/catalog/catalog.go @@ -267,14 +267,21 @@ func New(def string) *Catalog { func (c *Catalog) Build(stmts []ast.Statement) error { for i := range stmts { - if err := c.Update(stmts[i]); err != nil { + if err := c.Update(stmts[i], nil); err != nil { return err } } return nil } -func (c *Catalog) Update(stmt ast.Statement) error { +// An interface is used to resolve a circular import between the catalog and compiler packages. +// The createView function requires access to functions in the compiler package to parse the SELECT +// statement that defines the view. +type columnGenerator interface { + OutputColumns(node ast.Node) ([]*Column, error) +} + +func (c *Catalog) Update(stmt ast.Statement, colGen columnGenerator) error { if stmt.Raw == nil { return nil } @@ -322,6 +329,9 @@ func (c *Catalog) Update(stmt ast.Statement) error { case *ast.CreateTableStmt: err = c.createTable(n) + case *ast.ViewStmt: + err = c.createView(n, colGen) + case *ast.DropFunctionStmt: err = c.dropFunction(n) diff --git a/internal/sql/catalog/view.go b/internal/sql/catalog/view.go new file mode 100644 index 0000000000..d119894095 --- /dev/null +++ b/internal/sql/catalog/view.go @@ -0,0 +1,52 @@ +package catalog + +import ( + "github.com/kyleconroy/sqlc/internal/sql/ast" + "github.com/kyleconroy/sqlc/internal/sql/sqlerr" +) + +func (c *Catalog) createView(stmt *ast.ViewStmt, colGen columnGenerator) error { + cols, err := colGen.OutputColumns(stmt.Query) + if err != nil { + return err + } + + catName := "" + if stmt.View.Catalogname != nil { + catName = *stmt.View.Catalogname + } + schemaName := "" + if stmt.View.Schemaname != nil { + schemaName = *stmt.View.Schemaname + } + + tbl := Table{ + Rel: &ast.TableName{ + Catalog: catName, + Schema: schemaName, + Name: *stmt.View.Relname, + }, + Columns: cols, + } + + ns := tbl.Rel.Schema + if ns == "" { + ns = c.DefaultSchema + } + schema, err := c.getSchema(ns) + if err != nil { + return err + } + _, existingIdx, err := schema.getTable(tbl.Rel) + if err == nil && !stmt.Replace { + return sqlerr.RelationExists(tbl.Rel.Name) + } + + if stmt.Replace && err == nil { + schema.Tables[existingIdx] = &tbl + } else { + schema.Tables = append(schema.Tables, &tbl) + } + + return nil +}