Skip to content

Commit 86e579d

Browse files
committed
dolphin: add support for union query
1 parent 3a58ee9 commit 86e579d

File tree

11 files changed

+186
-1
lines changed

11 files changed

+186
-1
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"engine": "mysql",
6+
"path": "go",
7+
"name": "querytest",
8+
"schema": "query.sql",
9+
"queries": "query.sql"
10+
}
11+
]
12+
}

internal/endtoend/testdata/select_union/postgres/go/db.go

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/select_union/postgres/go/models.go

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/select_union/postgres/go/query.sql.go

Lines changed: 37 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
CREATE TABLE foo (a text, b text);
2+
3+
-- name: SelectUnion :many
4+
SELECT * FROM foo
5+
UNION
6+
SELECT * FROM foo;

internal/engine/dolphin/convert.go

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,10 +442,13 @@ func (c *cc) convertSelectField(n *pcast.SelectField) *ast.ResTarget {
442442
}
443443

444444
func (c *cc) convertSelectStmt(n *pcast.SelectStmt) *ast.SelectStmt {
445+
op, all := c.convertSetOprType(n.AfterSetOperator)
445446
stmt := &ast.SelectStmt{
446447
TargetList: c.convertFieldList(n.Fields),
447448
FromClause: c.convertTableRefsClause(n.From),
448449
WhereClause: c.convert(n.Where),
450+
Op: op,
451+
All: all,
449452
}
450453
if n.Limit != nil {
451454
stmt.LimitCount = c.convert(n.Limit.Count)
@@ -937,11 +940,97 @@ func (c *cc) convertSetDefaultRoleStmt(n *pcast.SetDefaultRoleStmt) ast.Node {
937940
return todo(n)
938941
}
939942

943+
func (c *cc) convertSetOprType(n *pcast.SetOprType) (op ast.SetOperation, all bool) {
944+
if n == nil {
945+
return
946+
}
947+
948+
switch *n {
949+
case pcast.Union:
950+
op = ast.Union
951+
case pcast.UnionAll:
952+
op = ast.Union
953+
all = true
954+
case pcast.Intersect:
955+
op = ast.Intersect
956+
case pcast.IntersectAll:
957+
op = ast.Intersect
958+
all = true
959+
case pcast.Except:
960+
op = ast.Except
961+
case pcast.ExceptAll:
962+
op = ast.Except
963+
all = true
964+
}
965+
return
966+
}
967+
968+
// convertSetOprSelectList converts a list of SELECT from the Pingcap parser
969+
// into a tree. It is called for UNION, INTERSECT or EXCLUDE operation.
970+
//
971+
// Given an union with the following nodes:
972+
// [Select{1}, Select{2}, Select{3}, Select{4}]
973+
//
974+
// The function will return:
975+
// Select{
976+
// Larg: Select{
977+
// Larg: Select{
978+
// Larg: Select{1},
979+
// Rarg: Select{2},
980+
// Op: Union
981+
// },
982+
// Rarg: Select{3},
983+
// Op: Union,
984+
// },
985+
// Rarg: Select{4},
986+
// Op: Union,
987+
// }
940988
func (c *cc) convertSetOprSelectList(n *pcast.SetOprSelectList) ast.Node {
941-
return todo(n)
989+
selectStmts := make([]*ast.SelectStmt, len(n.Selects))
990+
for i, node := range n.Selects {
991+
selectStmts[i] = c.convertSelectStmt(node.(*pcast.SelectStmt))
992+
}
993+
994+
op, all := c.convertSetOprType(n.AfterSetOperator)
995+
tree := &ast.SelectStmt{
996+
TargetList: &ast.List{},
997+
FromClause: &ast.List{},
998+
WhereClause: nil,
999+
Op: op,
1000+
All: all,
1001+
}
1002+
for _, stmt := range selectStmts {
1003+
// We move Op and All from the child to the parent.
1004+
op, all := stmt.Op, stmt.All
1005+
stmt.Op, stmt.All = ast.None, false
1006+
1007+
switch {
1008+
case tree.Larg == nil:
1009+
tree.Larg = stmt
1010+
case tree.Rarg == nil:
1011+
tree.Rarg = stmt
1012+
tree.Op = op
1013+
tree.All = all
1014+
default:
1015+
tree = &ast.SelectStmt{
1016+
TargetList: &ast.List{},
1017+
FromClause: &ast.List{},
1018+
WhereClause: nil,
1019+
Larg: tree,
1020+
Rarg: stmt,
1021+
Op: op,
1022+
All: all,
1023+
}
1024+
}
1025+
}
1026+
1027+
return tree
9421028
}
9431029

9441030
func (c *cc) convertSetOprStmt(n *pcast.SetOprStmt) ast.Node {
1031+
if n.SelectList != nil {
1032+
return c.convertSetOprSelectList(n.SelectList)
1033+
}
9451034
return todo(n)
9461035
}
9471036

0 commit comments

Comments
 (0)