diff --git a/internal/endtoend/testdata/select_union/mysql/go/models.go b/internal/endtoend/testdata/select_union/mysql/go/models.go index c0cab4c642..635378fe82 100644 --- a/internal/endtoend/testdata/select_union/mysql/go/models.go +++ b/internal/endtoend/testdata/select_union/mysql/go/models.go @@ -8,6 +8,11 @@ import ( "database/sql" ) +type Bar struct { + A sql.NullString + B sql.NullString +} + type Foo struct { A sql.NullString B sql.NullString diff --git a/internal/endtoend/testdata/select_union/mysql/go/query.sql.go b/internal/endtoend/testdata/select_union/mysql/go/query.sql.go index f514165eef..3c937f9b88 100644 --- a/internal/endtoend/testdata/select_union/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/select_union/mysql/go/query.sql.go @@ -96,6 +96,35 @@ func (q *Queries) SelectUnion(ctx context.Context) ([]Foo, error) { return items, nil } +const selectUnionOther = `-- name: SelectUnionOther :many +SELECT a, b FROM foo +UNION +SELECT a, b FROM bar +` + +func (q *Queries) SelectUnionOther(ctx context.Context) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, selectUnionOther) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const selectUnionWithLimit = `-- name: SelectUnionWithLimit :many SELECT a, b FROM foo UNION diff --git a/internal/endtoend/testdata/select_union/mysql/query.sql b/internal/endtoend/testdata/select_union/mysql/query.sql index b3a6371161..f8aca8b150 100644 --- a/internal/endtoend/testdata/select_union/mysql/query.sql +++ b/internal/endtoend/testdata/select_union/mysql/query.sql @@ -1,4 +1,5 @@ CREATE TABLE foo (a text, b text); +CREATE TABLE bar (a text, b text); -- name: SelectUnion :many SELECT * FROM foo @@ -20,3 +21,8 @@ SELECT * FROM foo; SELECT * FROM foo INTERSECT SELECT * FROM foo; + +-- name: SelectUnionOther :many +SELECT * FROM foo +UNION +SELECT * FROM bar; \ No newline at end of file diff --git a/internal/endtoend/testdata/select_union/postgres/pgx/v4/go/models.go b/internal/endtoend/testdata/select_union/postgres/pgx/v4/go/models.go index c0cab4c642..635378fe82 100644 --- a/internal/endtoend/testdata/select_union/postgres/pgx/v4/go/models.go +++ b/internal/endtoend/testdata/select_union/postgres/pgx/v4/go/models.go @@ -8,6 +8,11 @@ import ( "database/sql" ) +type Bar struct { + A sql.NullString + B sql.NullString +} + type Foo struct { A sql.NullString B sql.NullString diff --git a/internal/endtoend/testdata/select_union/postgres/pgx/v4/go/query.sql.go b/internal/endtoend/testdata/select_union/postgres/pgx/v4/go/query.sql.go index 2b21c3e7fc..5977586a2a 100644 --- a/internal/endtoend/testdata/select_union/postgres/pgx/v4/go/query.sql.go +++ b/internal/endtoend/testdata/select_union/postgres/pgx/v4/go/query.sql.go @@ -87,6 +87,32 @@ func (q *Queries) SelectUnion(ctx context.Context) ([]Foo, error) { return items, nil } +const selectUnionOther = `-- name: SelectUnionOther :many +SELECT a, b FROM foo +UNION +SELECT a, b FROM bar +` + +func (q *Queries) SelectUnionOther(ctx context.Context) ([]Foo, error) { + rows, err := q.db.Query(ctx, selectUnionOther) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const selectUnionWithLimit = `-- name: SelectUnionWithLimit :many SELECT a, b FROM foo UNION diff --git a/internal/endtoend/testdata/select_union/postgres/pgx/v4/query.sql b/internal/endtoend/testdata/select_union/postgres/pgx/v4/query.sql index 9653e0707e..a2fafc7d0a 100644 --- a/internal/endtoend/testdata/select_union/postgres/pgx/v4/query.sql +++ b/internal/endtoend/testdata/select_union/postgres/pgx/v4/query.sql @@ -1,4 +1,5 @@ CREATE TABLE foo (a text, b text); +CREATE TABLE bar (a text, b text); -- name: SelectUnion :many SELECT * FROM foo @@ -20,3 +21,8 @@ SELECT * FROM foo; SELECT * FROM foo INTERSECT SELECT * FROM foo; + +-- name: SelectUnionOther :many +SELECT * FROM foo +UNION +SELECT * FROM bar; \ No newline at end of file diff --git a/internal/endtoend/testdata/select_union/postgres/pgx/v5/go/models.go b/internal/endtoend/testdata/select_union/postgres/pgx/v5/go/models.go index db5fe749f6..269271c124 100644 --- a/internal/endtoend/testdata/select_union/postgres/pgx/v5/go/models.go +++ b/internal/endtoend/testdata/select_union/postgres/pgx/v5/go/models.go @@ -8,6 +8,11 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +type Bar struct { + A pgtype.Text + B pgtype.Text +} + type Foo struct { A pgtype.Text B pgtype.Text diff --git a/internal/endtoend/testdata/select_union/postgres/pgx/v5/go/query.sql.go b/internal/endtoend/testdata/select_union/postgres/pgx/v5/go/query.sql.go index 2b21c3e7fc..5977586a2a 100644 --- a/internal/endtoend/testdata/select_union/postgres/pgx/v5/go/query.sql.go +++ b/internal/endtoend/testdata/select_union/postgres/pgx/v5/go/query.sql.go @@ -87,6 +87,32 @@ func (q *Queries) SelectUnion(ctx context.Context) ([]Foo, error) { return items, nil } +const selectUnionOther = `-- name: SelectUnionOther :many +SELECT a, b FROM foo +UNION +SELECT a, b FROM bar +` + +func (q *Queries) SelectUnionOther(ctx context.Context) ([]Foo, error) { + rows, err := q.db.Query(ctx, selectUnionOther) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const selectUnionWithLimit = `-- name: SelectUnionWithLimit :many SELECT a, b FROM foo UNION diff --git a/internal/endtoend/testdata/select_union/postgres/pgx/v5/query.sql b/internal/endtoend/testdata/select_union/postgres/pgx/v5/query.sql index 9653e0707e..a2fafc7d0a 100644 --- a/internal/endtoend/testdata/select_union/postgres/pgx/v5/query.sql +++ b/internal/endtoend/testdata/select_union/postgres/pgx/v5/query.sql @@ -1,4 +1,5 @@ CREATE TABLE foo (a text, b text); +CREATE TABLE bar (a text, b text); -- name: SelectUnion :many SELECT * FROM foo @@ -20,3 +21,8 @@ SELECT * FROM foo; SELECT * FROM foo INTERSECT SELECT * FROM foo; + +-- name: SelectUnionOther :many +SELECT * FROM foo +UNION +SELECT * FROM bar; \ No newline at end of file diff --git a/internal/endtoend/testdata/select_union/postgres/stdlib/go/models.go b/internal/endtoend/testdata/select_union/postgres/stdlib/go/models.go index c0cab4c642..635378fe82 100644 --- a/internal/endtoend/testdata/select_union/postgres/stdlib/go/models.go +++ b/internal/endtoend/testdata/select_union/postgres/stdlib/go/models.go @@ -8,6 +8,11 @@ import ( "database/sql" ) +type Bar struct { + A sql.NullString + B sql.NullString +} + type Foo struct { A sql.NullString B sql.NullString diff --git a/internal/endtoend/testdata/select_union/postgres/stdlib/go/query.sql.go b/internal/endtoend/testdata/select_union/postgres/stdlib/go/query.sql.go index 505bf72bc9..0b638e9497 100644 --- a/internal/endtoend/testdata/select_union/postgres/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/select_union/postgres/stdlib/go/query.sql.go @@ -96,6 +96,35 @@ func (q *Queries) SelectUnion(ctx context.Context) ([]Foo, error) { return items, nil } +const selectUnionOther = `-- name: SelectUnionOther :many +SELECT a, b FROM foo +UNION +SELECT a, b FROM bar +` + +func (q *Queries) SelectUnionOther(ctx context.Context) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, selectUnionOther) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const selectUnionWithLimit = `-- name: SelectUnionWithLimit :many SELECT a, b FROM foo UNION diff --git a/internal/endtoend/testdata/select_union/postgres/stdlib/query.sql b/internal/endtoend/testdata/select_union/postgres/stdlib/query.sql index 9653e0707e..a2fafc7d0a 100644 --- a/internal/endtoend/testdata/select_union/postgres/stdlib/query.sql +++ b/internal/endtoend/testdata/select_union/postgres/stdlib/query.sql @@ -1,4 +1,5 @@ CREATE TABLE foo (a text, b text); +CREATE TABLE bar (a text, b text); -- name: SelectUnion :many SELECT * FROM foo @@ -20,3 +21,8 @@ SELECT * FROM foo; SELECT * FROM foo INTERSECT SELECT * FROM foo; + +-- name: SelectUnionOther :many +SELECT * FROM foo +UNION +SELECT * FROM bar; \ No newline at end of file diff --git a/internal/endtoend/testdata/select_union/sqlite/go/db.go b/internal/endtoend/testdata/select_union/sqlite/go/db.go new file mode 100644 index 0000000000..57406b68e8 --- /dev/null +++ b/internal/endtoend/testdata/select_union/sqlite/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/select_union/sqlite/go/models.go b/internal/endtoend/testdata/select_union/sqlite/go/models.go new file mode 100644 index 0000000000..635378fe82 --- /dev/null +++ b/internal/endtoend/testdata/select_union/sqlite/go/models.go @@ -0,0 +1,19 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "database/sql" +) + +type Bar struct { + A sql.NullString + B sql.NullString +} + +type Foo struct { + A sql.NullString + B sql.NullString +} diff --git a/internal/endtoend/testdata/select_union/sqlite/go/query.sql.go b/internal/endtoend/testdata/select_union/sqlite/go/query.sql.go new file mode 100644 index 0000000000..a38fd5db9a --- /dev/null +++ b/internal/endtoend/testdata/select_union/sqlite/go/query.sql.go @@ -0,0 +1,161 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const selectExcept = `-- name: SelectExcept :many +SELECT a, b FROM foo +EXCEPT +SELECT a, b FROM foo +` + +func (q *Queries) SelectExcept(ctx context.Context) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, selectExcept) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectIntersect = `-- name: SelectIntersect :many +SELECT a, b FROM foo +INTERSECT +SELECT a, b FROM foo +` + +func (q *Queries) SelectIntersect(ctx context.Context) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, selectIntersect) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUnion = `-- name: SelectUnion :many +SELECT a, b FROM foo +UNION +SELECT a, b FROM foo +` + +func (q *Queries) SelectUnion(ctx context.Context) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, selectUnion) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUnionOther = `-- name: SelectUnionOther :many +SELECT a, b FROM foo +UNION +SELECT a, b FROM bar +` + +func (q *Queries) SelectUnionOther(ctx context.Context) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, selectUnionOther) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectUnionWithLimit = `-- name: SelectUnionWithLimit :many +SELECT a, b FROM foo +UNION +SELECT a, b FROM foo +LIMIT ? OFFSET ? +` + +type SelectUnionWithLimitParams struct { + Limit int64 + Offset int64 +} + +func (q *Queries) SelectUnionWithLimit(ctx context.Context, arg SelectUnionWithLimitParams) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, selectUnionWithLimit, arg.Limit, arg.Offset) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/select_union/sqlite/query.sql b/internal/endtoend/testdata/select_union/sqlite/query.sql new file mode 100644 index 0000000000..f8aca8b150 --- /dev/null +++ b/internal/endtoend/testdata/select_union/sqlite/query.sql @@ -0,0 +1,28 @@ +CREATE TABLE foo (a text, b text); +CREATE TABLE bar (a text, b text); + +-- name: SelectUnion :many +SELECT * FROM foo +UNION +SELECT * FROM foo; + +-- name: SelectUnionWithLimit :many +SELECT * FROM foo +UNION +SELECT * FROM foo +LIMIT ? OFFSET ?; + +-- name: SelectExcept :many +SELECT * FROM foo +EXCEPT +SELECT * FROM foo; + +-- name: SelectIntersect :many +SELECT * FROM foo +INTERSECT +SELECT * FROM foo; + +-- name: SelectUnionOther :many +SELECT * FROM foo +UNION +SELECT * FROM bar; \ No newline at end of file diff --git a/internal/endtoend/testdata/select_union/sqlite/sqlc.json b/internal/endtoend/testdata/select_union/sqlite/sqlc.json new file mode 100644 index 0000000000..fc58be5b0d --- /dev/null +++ b/internal/endtoend/testdata/select_union/sqlite/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "sqlite", + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index ad8ddc7c0e..37e3a3c394 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -375,13 +375,7 @@ func (c *cc) convertComparison(n *parser.Expr_comparisonContext) ast.Node { } func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.Node { - var tables []ast.Node - var cols []ast.Node - var where ast.Node - var groups = []ast.Node{} - var having ast.Node - var ctes []ast.Node - + var ctes ast.List if ct := n.Common_table_stmt(); ct != nil { recursive := ct.RECURSIVE_() != nil for _, cte := range ct.AllCommon_table_expression() { @@ -390,7 +384,7 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No for _, col := range cte.AllColumn_name() { cteCols.Items = append(cteCols.Items, NewIdentifier(col.GetText())) } - ctes = append(ctes, &ast.CommonTableExpr{ + ctes.Items = append(ctes.Items, &ast.CommonTableExpr{ Ctename: &tableName, Ctequery: c.convert(cte.Select_stmt()), Location: cte.GetStart().GetStart(), @@ -400,20 +394,24 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No } } - for _, icore := range n.AllSelect_core() { + var selectStmt *ast.SelectStmt + for s, icore := range n.AllSelect_core() { core, ok := icore.(*parser.Select_coreContext) if !ok { continue } - cols = append(cols, c.getCols(core)...) - tables = append(tables, c.getTables(core)...) + cols := c.getCols(core) + tables := c.getTables(core) + var where ast.Node i := 0 if core.WHERE_() != nil { where = c.convert(core.Expr(i)) i++ } + var groups ast.List + var having ast.Node if core.GROUP_() != nil { l := len(core.AllExpr()) - i if core.HAVING_() != nil { @@ -422,33 +420,103 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No } for i < l { - groups = append(groups, c.convert(core.Expr(i))) + groups.Items = append(groups.Items, c.convert(core.Expr(i))) i++ } } - } - - window := &ast.List{Items: []ast.Node{}} - if n.Order_by_stmt() != nil { - window.Items = append(window.Items, c.convert(n.Order_by_stmt())) + var window ast.List + if core.WINDOW_() != nil { + for w, windowNameCtx := range core.AllWindow_name() { + windowName := identifier(windowNameCtx.GetText()) + windowDef := core.Window_defn(w) + + _ = windowDef.Base_window_name() + var partitionBy ast.List + if windowDef.PARTITION_() != nil { + for _, e := range windowDef.AllExpr() { + partitionBy.Items = append(partitionBy.Items, c.convert(e)) + } + } + var orderBy ast.List + if windowDef.ORDER_() != nil { + for _, e := range windowDef.AllOrdering_term() { + oterm := e.(*parser.Ordering_termContext) + sortByDir := ast.SortByDirDefault + if ad := oterm.Asc_desc(); ad != nil { + if ad.ASC_() != nil { + sortByDir = ast.SortByDirAsc + } else { + sortByDir = ast.SortByDirDesc + } + } + sortByNulls := ast.SortByNullsDefault + if oterm.NULLS_() != nil { + if oterm.FIRST_() != nil { + sortByNulls = ast.SortByNullsFirst + } else { + sortByNulls = ast.SortByNullsLast + } + } + + orderBy.Items = append(orderBy.Items, &ast.SortBy{ + Node: c.convert(oterm.Expr()), + SortbyDir: sortByDir, + SortbyNulls: sortByNulls, + UseOp: &ast.List{}, + }) + } + } + window.Items = append(window.Items, &ast.WindowDef{ + Name: &windowName, + PartitionClause: &partitionBy, + OrderClause: &orderBy, + FrameOptions: 0, // todo + StartOffset: &ast.TODO{}, + EndOffset: &ast.TODO{}, + Location: windowNameCtx.GetStart().GetStart(), + }) + } + } + sel := &ast.SelectStmt{ + FromClause: &ast.List{Items: tables}, + TargetList: &ast.List{Items: cols}, + WhereClause: where, + GroupClause: &groups, + HavingClause: having, + WindowClause: &window, + ValuesLists: &ast.List{}, + } + if selectStmt == nil { + selectStmt = sel + } else { + co := n.Compound_operator(s - 1) + so := ast.None + all := false + switch { + case co.UNION_() != nil: + so = ast.Union + all = co.ALL_() != nil + case co.INTERSECT_() != nil: + so = ast.Intersect + case co.EXCEPT_() != nil: + so = ast.Except + } + selectStmt = &ast.SelectStmt{ + TargetList: &ast.List{}, + FromClause: &ast.List{}, + Op: so, + All: all, + Larg: selectStmt, + Rarg: sel, + } + } } limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt()) - - return &ast.SelectStmt{ - FromClause: &ast.List{Items: tables}, - TargetList: &ast.List{Items: cols}, - WhereClause: where, - GroupClause: &ast.List{Items: groups}, - HavingClause: having, - WindowClause: window, - LimitCount: limitCount, - LimitOffset: limitOffset, - ValuesLists: &ast.List{}, - WithClause: &ast.WithClause{ - Ctes: &ast.List{Items: ctes}, - }, - } + selectStmt.LimitCount = limitCount + selectStmt.LimitOffset = limitOffset + selectStmt.WithClause = &ast.WithClause{Ctes: &ctes} + return selectStmt } func (c *cc) convertExprListContext(n *parser.Expr_listContext) ast.Node { diff --git a/internal/sql/ast/sort_by_dir.go b/internal/sql/ast/sort_by_dir.go index f7f8d53950..3ebd212a79 100644 --- a/internal/sql/ast/sort_by_dir.go +++ b/internal/sql/ast/sort_by_dir.go @@ -5,3 +5,11 @@ type SortByDir uint func (n *SortByDir) Pos() int { return 0 } + +const ( + SortByDirUndefined SortByDir = 0 + SortByDirDefault SortByDir = 1 + SortByDirAsc SortByDir = 2 + SortByDirDesc SortByDir = 3 + SortByDirUsing SortByDir = 4 +) diff --git a/internal/sql/ast/sort_by_nulls.go b/internal/sql/ast/sort_by_nulls.go index c4d67c5d7f..512b5a14e1 100644 --- a/internal/sql/ast/sort_by_nulls.go +++ b/internal/sql/ast/sort_by_nulls.go @@ -5,3 +5,10 @@ type SortByNulls uint func (n *SortByNulls) Pos() int { return 0 } + +const ( + SortByNullsUndefined SortByNulls = 0 + SortByNullsDefault SortByNulls = 1 + SortByNullsFirst SortByNulls = 2 + SortByNullsLast SortByNulls = 3 +)