diff --git a/internal/endtoend/testdata/experimental.sql b/internal/endtoend/testdata/experimental.sql new file mode 100644 index 0000000000..c61b15171d --- /dev/null +++ b/internal/endtoend/testdata/experimental.sql @@ -0,0 +1,12 @@ +CREATE TABLE foo ( + bar text NOT NULL +); + +CREATE TABLE bar ( + baz text NOT NULL +); + +SELECT bar FROM foo; + +DROP TABLE bar; +DROP TABLE IF EXISTS baz; diff --git a/internal/endtoend/testdata/experimental_dolphin/query.sql b/internal/endtoend/testdata/experimental_dolphin/query.sql deleted file mode 100644 index 72d89e29f7..0000000000 --- a/internal/endtoend/testdata/experimental_dolphin/query.sql +++ /dev/null @@ -1,8 +0,0 @@ -CREATE TABLE foo ( - bar text NOT NULL -); - -SELECT bar FROM foo; - --- DROP TABLE foo; --- DROP TABLE IF EXISTS bar; diff --git a/internal/endtoend/testdata/experimental_dolphin/sqlc.json b/internal/endtoend/testdata/experimental_dolphin/sqlc.json index 65b2f3b387..11441c22f9 100644 --- a/internal/endtoend/testdata/experimental_dolphin/sqlc.json +++ b/internal/endtoend/testdata/experimental_dolphin/sqlc.json @@ -5,8 +5,8 @@ "path": "go", "name": "querytest", "engine": "_dolphin", - "schema": "query.sql", - "queries": "query.sql" + "schema": "../experimental.sql", + "queries": "../experimental.sql" } ] } diff --git a/internal/endtoend/testdata/experimental_elephant/query.sql b/internal/endtoend/testdata/experimental_elephant/query.sql deleted file mode 100644 index 72d89e29f7..0000000000 --- a/internal/endtoend/testdata/experimental_elephant/query.sql +++ /dev/null @@ -1,8 +0,0 @@ -CREATE TABLE foo ( - bar text NOT NULL -); - -SELECT bar FROM foo; - --- DROP TABLE foo; --- DROP TABLE IF EXISTS bar; diff --git a/internal/endtoend/testdata/experimental_elephant/sqlc.json b/internal/endtoend/testdata/experimental_elephant/sqlc.json index 8759dc6a2c..4d019a3dcd 100644 --- a/internal/endtoend/testdata/experimental_elephant/sqlc.json +++ b/internal/endtoend/testdata/experimental_elephant/sqlc.json @@ -5,8 +5,8 @@ "path": "go", "name": "querytest", "engine": "_elephant", - "schema": "query.sql", - "queries": "query.sql" + "schema": "../experimental.sql", + "queries": "../experimental.sql" } ] } diff --git a/internal/endtoend/testdata/experimental_lemon/query.sql b/internal/endtoend/testdata/experimental_lemon/query.sql deleted file mode 100644 index 72d89e29f7..0000000000 --- a/internal/endtoend/testdata/experimental_lemon/query.sql +++ /dev/null @@ -1,8 +0,0 @@ -CREATE TABLE foo ( - bar text NOT NULL -); - -SELECT bar FROM foo; - --- DROP TABLE foo; --- DROP TABLE IF EXISTS bar; diff --git a/internal/endtoend/testdata/experimental_lemon/sqlc.json b/internal/endtoend/testdata/experimental_lemon/sqlc.json index 28b46e69de..548e63d707 100644 --- a/internal/endtoend/testdata/experimental_lemon/sqlc.json +++ b/internal/endtoend/testdata/experimental_lemon/sqlc.json @@ -5,8 +5,8 @@ "path": "go", "name": "querytest", "engine": "_lemon", - "schema": "query.sql", - "queries": "query.sql" + "schema": "../experimental.sql", + "queries": "../experimental.sql" } ] } diff --git a/internal/sqlite/listener.go b/internal/sqlite/listener.go new file mode 100644 index 0000000000..b83ab1593f --- /dev/null +++ b/internal/sqlite/listener.go @@ -0,0 +1,141 @@ +package sqlite + +import ( + "github.com/kyleconroy/sqlc/internal/sql/ast" + "github.com/kyleconroy/sqlc/internal/sqlite/parser" +) + +type listener struct { + *parser.BaseSQLiteListener + + stmt *ast.RawStmt + + stmts []ast.Statement +} + +// The Visitor code generated by Antlr doesn't currently work +// +// To make due, we mark the listener as "busy" if it's currently processing a +// node. This helps avoids scenarios where enter is called on nested +// statements. +func (l *listener) busy() bool { + return l.stmt != nil +} + +func (l *listener) EnterSql_stmt(c *parser.Sql_stmtContext) { + l.stmt = nil +} + +func (l *listener) ExitSql_stmt(c *parser.Sql_stmtContext) { + if l.stmt != nil { + l.stmts = append(l.stmts, ast.Statement{ + Raw: l.stmt, + }) + } +} + +func (l *listener) EnterCreate_table_stmt(c *parser.Create_table_stmtContext) { + if l.busy() { + return + } + + name := ast.TableName{ + Name: c.Table_name().GetText(), + } + + if c.Database_name() != nil { + name.Schema = c.Database_name().GetText() + } + + stmt := &ast.CreateTableStmt{ + Name: &name, + IfNotExists: c.K_EXISTS() != nil, + } + + for _, idef := range c.AllColumn_def() { + if def, ok := idef.(*parser.Column_defContext); ok { + stmt.Cols = append(stmt.Cols, &ast.ColumnDef{ + Colname: def.Column_name().GetText(), + TypeName: &ast.TypeName{ + Name: def.Type_name().GetText(), + }, + }) + } + } + + l.stmt = &ast.RawStmt{Stmt: stmt} +} + +func (l *listener) EnterDrop_table_stmt(c *parser.Drop_table_stmtContext) { + if l.busy() { + return + } + + drop := &ast.DropTableStmt{ + IfExists: c.K_EXISTS() != nil, + } + + name := ast.TableName{ + Name: c.Table_name().GetText(), + } + + if c.Database_name() != nil { + name.Schema = c.Database_name().GetText() + } + + drop.Tables = append(drop.Tables, &name) + l.stmt = &ast.RawStmt{Stmt: drop} +} + +func (l *listener) EnterFactored_select_stmt(c *parser.Factored_select_stmtContext) { + if l.busy() { + return + } + + var tables []ast.Node + var cols []ast.Node + for _, icore := range c.AllSelect_core() { + core, ok := icore.(*parser.Select_coreContext) + if !ok { + continue + } + for _, icol := range core.AllResult_column() { + col, ok := icol.(*parser.Result_columnContext) + if !ok { + continue + } + iexpr := col.Expr() + if iexpr == nil { + continue + } + expr, ok := iexpr.(*parser.ExprContext) + if !ok { + continue + } + cols = append(cols, &ast.ResTarget{ + Val: &ast.ColumnRef{ + Name: expr.Column_name().GetText(), + }, + }) + } + for _, ifrom := range core.AllTable_or_subquery() { + from, ok := ifrom.(*parser.Table_or_subqueryContext) + if !ok { + continue + } + name := ast.TableName{ + Name: from.Table_name().GetText(), + } + if from.Schema_name() != nil { + name.Schema = from.Schema_name().GetText() + } + tables = append(tables, &name) + } + } + + sel := &ast.SelectStmt{ + From: &ast.List{Items: tables}, + Fields: &ast.List{Items: cols}, + } + l.stmt = &ast.RawStmt{Stmt: sel} +} diff --git a/internal/sqlite/parse.go b/internal/sqlite/parse.go index 41e2db6dc1..8ac3aa6d64 100644 --- a/internal/sqlite/parse.go +++ b/internal/sqlite/parse.go @@ -11,55 +11,6 @@ import ( "github.com/kyleconroy/sqlc/internal/sqlite/parser" ) -type listener struct { - *parser.BaseSQLiteListener - - stmt *ast.RawStmt - - stmts []ast.Statement -} - -func (l *listener) EnterSql_stmt(c *parser.Sql_stmtContext) { - l.stmt = nil -} - -func (l *listener) ExitSql_stmt(c *parser.Sql_stmtContext) { - if l.stmt != nil { - l.stmts = append(l.stmts, ast.Statement{ - Raw: l.stmt, - }) - return - } -} - -func (l *listener) EnterCreate_table_stmt(c *parser.Create_table_stmtContext) { - name := ast.TableName{ - Name: c.Table_name().GetText(), - } - - if c.Database_name() != nil { - name.Schema = c.Database_name().GetText() - } - - stmt := &ast.CreateTableStmt{ - Name: &name, - IfNotExists: c.K_EXISTS() != nil, - } - - for _, idef := range c.AllColumn_def() { - if def, ok := idef.(*parser.Column_defContext); ok { - stmt.Cols = append(stmt.Cols, &ast.ColumnDef{ - Colname: def.Column_name().GetText(), - TypeName: &ast.TypeName{ - Name: def.Type_name().GetText(), - }, - }) - } - } - - l.stmt = &ast.RawStmt{Stmt: stmt} -} - type errorListener struct { *antlr.DefaultErrorListener @@ -98,11 +49,11 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { l := &listener{} el := &errorListener{} pp.AddErrorListener(el) + // pp.BuildParseTrees = true tree := pp.Parse() if el.err != "" { return nil, errors.New(el.err) } - // p.BuildParseTrees = true antlr.ParseTreeWalkerDefault.Walk(l, tree) return l.stmts, nil } diff --git a/internal/sqlite/visitor.go b/internal/sqlite/visitor.go new file mode 100644 index 0000000000..5222e601f1 --- /dev/null +++ b/internal/sqlite/visitor.go @@ -0,0 +1,347 @@ +package sqlite + +import ( + "fmt" + + "github.com/antlr/antlr4/runtime/Go/antlr" + "github.com/kyleconroy/sqlc/internal/sqlite/parser" +) + +type visitor struct { +} + +// ParseTreeVisitor interace +func (v *visitor) Visit(tree antlr.ParseTree) interface{} { return v } +func (v *visitor) VisitChildren(node antlr.RuleNode) interface{} { return v } +func (v *visitor) VisitTerminal(node antlr.TerminalNode) interface{} { return v } +func (v *visitor) VisitErrorNode(node antlr.ErrorNode) interface{} { return v } + +// SQLiteVisitor interface +func (v *visitor) VisitParse(ctx *parser.ParseContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitSql_stmt_list(ctx *parser.Sql_stmt_listContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitSql_stmt(ctx *parser.Sql_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitAlter_table_stmt(ctx *parser.Alter_table_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitAnalyze_stmt(ctx *parser.Analyze_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitAttach_stmt(ctx *parser.Attach_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitBegin_stmt(ctx *parser.Begin_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitCommit_stmt(ctx *parser.Commit_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitCompound_select_stmt(ctx *parser.Compound_select_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitCreate_index_stmt(ctx *parser.Create_index_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitCreate_table_stmt(ctx *parser.Create_table_stmtContext) interface{} { + fmt.Println("CREATE TABLE", ctx) + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitCreate_trigger_stmt(ctx *parser.Create_trigger_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitCreate_view_stmt(ctx *parser.Create_view_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitCreate_virtual_table_stmt(ctx *parser.Create_virtual_table_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitDelete_stmt(ctx *parser.Delete_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitDelete_stmt_limited(ctx *parser.Delete_stmt_limitedContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitDetach_stmt(ctx *parser.Detach_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitDrop_index_stmt(ctx *parser.Drop_index_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitDrop_table_stmt(ctx *parser.Drop_table_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitDrop_trigger_stmt(ctx *parser.Drop_trigger_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitDrop_view_stmt(ctx *parser.Drop_view_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitFactored_select_stmt(ctx *parser.Factored_select_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitInsert_stmt(ctx *parser.Insert_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitPragma_stmt(ctx *parser.Pragma_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitReindex_stmt(ctx *parser.Reindex_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitRelease_stmt(ctx *parser.Release_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitRollback_stmt(ctx *parser.Rollback_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitSavepoint_stmt(ctx *parser.Savepoint_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitSimple_select_stmt(ctx *parser.Simple_select_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitSelect_stmt(ctx *parser.Select_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitSelect_or_values(ctx *parser.Select_or_valuesContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitUpdate_stmt(ctx *parser.Update_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitUpdate_stmt_limited(ctx *parser.Update_stmt_limitedContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitVacuum_stmt(ctx *parser.Vacuum_stmtContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitColumn_def(ctx *parser.Column_defContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitType_name(ctx *parser.Type_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitColumn_constraint(ctx *parser.Column_constraintContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitConflict_clause(ctx *parser.Conflict_clauseContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitExpr(ctx *parser.ExprContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitForeign_key_clause(ctx *parser.Foreign_key_clauseContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitRaise_function(ctx *parser.Raise_functionContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitIndexed_column(ctx *parser.Indexed_columnContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitTable_constraint(ctx *parser.Table_constraintContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitWith_clause(ctx *parser.With_clauseContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitQualified_table_name(ctx *parser.Qualified_table_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitOrdering_term(ctx *parser.Ordering_termContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitPragma_value(ctx *parser.Pragma_valueContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitCommon_table_expression(ctx *parser.Common_table_expressionContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitResult_column(ctx *parser.Result_columnContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitTable_or_subquery(ctx *parser.Table_or_subqueryContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitJoin_clause(ctx *parser.Join_clauseContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitJoin_operator(ctx *parser.Join_operatorContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitJoin_constraint(ctx *parser.Join_constraintContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitSelect_core(ctx *parser.Select_coreContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitCompound_operator(ctx *parser.Compound_operatorContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitSigned_number(ctx *parser.Signed_numberContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitLiteral_value(ctx *parser.Literal_valueContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitUnary_operator(ctx *parser.Unary_operatorContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitError_message(ctx *parser.Error_messageContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitModule_argument(ctx *parser.Module_argumentContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitColumn_alias(ctx *parser.Column_aliasContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitKeyword(ctx *parser.KeywordContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitName(ctx *parser.NameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitFunction_name(ctx *parser.Function_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitDatabase_name(ctx *parser.Database_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitSchema_name(ctx *parser.Schema_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitTable_function_name(ctx *parser.Table_function_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitTable_name(ctx *parser.Table_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitTable_or_index_name(ctx *parser.Table_or_index_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitNew_table_name(ctx *parser.New_table_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitColumn_name(ctx *parser.Column_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitCollation_name(ctx *parser.Collation_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitForeign_table(ctx *parser.Foreign_tableContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitIndex_name(ctx *parser.Index_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitTrigger_name(ctx *parser.Trigger_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitView_name(ctx *parser.View_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitModule_name(ctx *parser.Module_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitPragma_name(ctx *parser.Pragma_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitSavepoint_name(ctx *parser.Savepoint_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitTable_alias(ctx *parser.Table_aliasContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitTransaction_name(ctx *parser.Transaction_nameContext) interface{} { + return v.VisitChildren(ctx) +} + +func (v *visitor) VisitAny_name(ctx *parser.Any_nameContext) interface{} { + return v.VisitChildren(ctx) +}