Skip to content

Commit 76d0a78

Browse files
authored
feat(sqlite): Add support for UPDATE/DELETE with a LIMIT clause (#2384)
closes #2291
1 parent 3e0fca0 commit 76d0a78

File tree

3 files changed

+85
-12
lines changed

3 files changed

+85
-12
lines changed

internal/endtoend/testdata/limit/sqlite/go/query.sql.go

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/limit/sqlite/query.sql

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,9 @@ CREATE TABLE foo (bar bool not null);
22

33
-- name: LimitMe :many
44
SELECT bar FROM foo LIMIT ?;
5+
6+
-- name: UpdateLimit :exec
7+
UPDATE foo SET bar='baz' LIMIT ?;
8+
9+
-- name: DeleteLimit :exec
10+
DELETE FROM foo LIMIT ?;

internal/engine/sqlite/convert.go

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,15 @@ func (c *cc) convertCreate_view_stmtContext(n *parser.Create_view_stmtContext) a
144144
}
145145
}
146146

147-
func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node {
147+
type Delete_stmt interface {
148+
node
149+
150+
Qualified_table_name() parser.IQualified_table_nameContext
151+
WHERE_() antlr.TerminalNode
152+
Expr() parser.IExprContext
153+
}
154+
155+
func (c *cc) convertDelete_stmtContext(n Delete_stmt) ast.Node {
148156
if qualifiedName, ok := n.Qualified_table_name().(*parser.Qualified_table_nameContext); ok {
149157

150158
tableName := qualifiedName.Table_name().GetText()
@@ -167,15 +175,28 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node {
167175
relations.Items = append(relations.Items, relation)
168176

169177
delete := &ast.DeleteStmt{
170-
Relations: relations,
171-
ReturningList: c.convertReturning_caluseContext(n.Returning_clause()),
172-
WithClause: nil,
178+
Relations: relations,
179+
WithClause: nil,
173180
}
174181

175182
if n.WHERE_() != nil && n.Expr() != nil {
176183
delete.WhereClause = c.convert(n.Expr())
177184
}
178185

186+
if n, ok := n.(interface {
187+
Returning_clause() parser.IReturning_clauseContext
188+
}); ok {
189+
delete.ReturningList = c.convertReturning_caluseContext(n.Returning_clause())
190+
} else {
191+
delete.ReturningList = c.convertReturning_caluseContext(nil)
192+
}
193+
if n, ok := n.(interface {
194+
Limit_stmt() parser.ILimit_stmtContext
195+
}); ok {
196+
limitCount, _ := c.convertLimit_stmtContext(n.Limit_stmt())
197+
delete.LimitCount = limitCount
198+
}
199+
179200
return delete
180201
}
181202

@@ -796,7 +817,16 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
796817
return tables
797818
}
798819

799-
func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node {
820+
type Update_stmt interface {
821+
Qualified_table_name() parser.IQualified_table_nameContext
822+
GetStart() antlr.Token
823+
AllColumn_name() []parser.IColumn_nameContext
824+
WHERE_() antlr.TerminalNode
825+
Expr(i int) parser.IExprContext
826+
AllExpr() []parser.IExprContext
827+
}
828+
829+
func (c *cc) convertUpdate_stmtContext(n Update_stmt) ast.Node {
800830
if n == nil {
801831
return nil
802832
}
@@ -824,14 +854,27 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node {
824854
where = c.convert(n.Expr(len(n.AllExpr()) - 1))
825855
}
826856

827-
return &ast.UpdateStmt{
828-
Relations: relations,
829-
TargetList: list,
830-
WhereClause: where,
831-
ReturningList: c.convertReturning_caluseContext(n.Returning_clause()),
832-
FromClause: &ast.List{},
833-
WithClause: nil, // TODO: support with clause
857+
stmt := &ast.UpdateStmt{
858+
Relations: relations,
859+
TargetList: list,
860+
WhereClause: where,
861+
FromClause: &ast.List{},
862+
WithClause: nil, // TODO: support with clause
863+
}
864+
if n, ok := n.(interface {
865+
Returning_clause() parser.IReturning_clauseContext
866+
}); ok {
867+
stmt.ReturningList = c.convertReturning_caluseContext(n.Returning_clause())
868+
} else {
869+
stmt.ReturningList = c.convertReturning_caluseContext(nil)
834870
}
871+
if n, ok := n.(interface {
872+
Limit_stmt() parser.ILimit_stmtContext
873+
}); ok {
874+
limitCount, _ := c.convertLimit_stmtContext(n.Limit_stmt())
875+
stmt.LimitCount = limitCount
876+
}
877+
return stmt
835878
}
836879

837880
func (c *cc) convertBetweenExpr(n *parser.Expr_betweenContext) ast.Node {
@@ -865,6 +908,9 @@ func (c *cc) convert(node node) ast.Node {
865908
case *parser.Delete_stmtContext:
866909
return c.convertDelete_stmtContext(n)
867910

911+
case *parser.Delete_stmt_limitedContext:
912+
return c.convertDelete_stmtContext(n)
913+
868914
case *parser.ExprContext:
869915
return c.convertExprContext(n)
870916

@@ -917,6 +963,9 @@ func (c *cc) convert(node node) ast.Node {
917963
case *parser.Update_stmtContext:
918964
return c.convertUpdate_stmtContext(n)
919965

966+
case *parser.Update_stmt_limitedContext:
967+
return c.convertUpdate_stmtContext(n)
968+
920969
default:
921970
return todo("convert(case=default)", n)
922971
}

0 commit comments

Comments
 (0)