diff --git a/internal/endtoend/testdata/limit/sqlite/go/query.sql.go b/internal/endtoend/testdata/limit/sqlite/go/query.sql.go index afa456a8e0..ec6f47c09c 100644 --- a/internal/endtoend/testdata/limit/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/limit/sqlite/go/query.sql.go @@ -9,6 +9,15 @@ import ( "context" ) +const deleteLimit = `-- name: DeleteLimit :exec +DELETE FROM foo LIMIT ? +` + +func (q *Queries) DeleteLimit(ctx context.Context, limit int64) error { + _, err := q.db.ExecContext(ctx, deleteLimit, limit) + return err +} + const limitMe = `-- name: LimitMe :many SELECT bar FROM foo LIMIT ? ` @@ -35,3 +44,12 @@ func (q *Queries) LimitMe(ctx context.Context, limit int64) ([]bool, error) { } return items, nil } + +const updateLimit = `-- name: UpdateLimit :exec +UPDATE foo SET bar='baz' LIMIT ? +` + +func (q *Queries) UpdateLimit(ctx context.Context, limit int64) error { + _, err := q.db.ExecContext(ctx, updateLimit, limit) + return err +} diff --git a/internal/endtoend/testdata/limit/sqlite/query.sql b/internal/endtoend/testdata/limit/sqlite/query.sql index e7e373d13b..99862ad760 100644 --- a/internal/endtoend/testdata/limit/sqlite/query.sql +++ b/internal/endtoend/testdata/limit/sqlite/query.sql @@ -2,3 +2,9 @@ CREATE TABLE foo (bar bool not null); -- name: LimitMe :many SELECT bar FROM foo LIMIT ?; + +-- name: UpdateLimit :exec +UPDATE foo SET bar='baz' LIMIT ?; + +-- name: DeleteLimit :exec +DELETE FROM foo LIMIT ?; diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index a3a75b56e7..a730fabdb6 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -144,7 +144,15 @@ func (c *cc) convertCreate_view_stmtContext(n *parser.Create_view_stmtContext) a } } -func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { +type Delete_stmt interface { + node + + Qualified_table_name() parser.IQualified_table_nameContext + WHERE_() antlr.TerminalNode + Expr() parser.IExprContext +} + +func (c *cc) convertDelete_stmtContext(n Delete_stmt) ast.Node { if qualifiedName, ok := n.Qualified_table_name().(*parser.Qualified_table_nameContext); ok { tableName := qualifiedName.Table_name().GetText() @@ -167,15 +175,28 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { relations.Items = append(relations.Items, relation) delete := &ast.DeleteStmt{ - Relations: relations, - ReturningList: c.convertReturning_caluseContext(n.Returning_clause()), - WithClause: nil, + Relations: relations, + WithClause: nil, } if n.WHERE_() != nil && n.Expr() != nil { delete.WhereClause = c.convert(n.Expr()) } + if n, ok := n.(interface { + Returning_clause() parser.IReturning_clauseContext + }); ok { + delete.ReturningList = c.convertReturning_caluseContext(n.Returning_clause()) + } else { + delete.ReturningList = c.convertReturning_caluseContext(nil) + } + if n, ok := n.(interface { + Limit_stmt() parser.ILimit_stmtContext + }); ok { + limitCount, _ := c.convertLimit_stmtContext(n.Limit_stmt()) + delete.LimitCount = limitCount + } + return delete } @@ -796,7 +817,16 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast return tables } -func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { +type Update_stmt interface { + Qualified_table_name() parser.IQualified_table_nameContext + GetStart() antlr.Token + AllColumn_name() []parser.IColumn_nameContext + WHERE_() antlr.TerminalNode + Expr(i int) parser.IExprContext + AllExpr() []parser.IExprContext +} + +func (c *cc) convertUpdate_stmtContext(n Update_stmt) ast.Node { if n == nil { return nil } @@ -824,14 +854,27 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { where = c.convert(n.Expr(len(n.AllExpr()) - 1)) } - return &ast.UpdateStmt{ - Relations: relations, - TargetList: list, - WhereClause: where, - ReturningList: c.convertReturning_caluseContext(n.Returning_clause()), - FromClause: &ast.List{}, - WithClause: nil, // TODO: support with clause + stmt := &ast.UpdateStmt{ + Relations: relations, + TargetList: list, + WhereClause: where, + FromClause: &ast.List{}, + WithClause: nil, // TODO: support with clause + } + if n, ok := n.(interface { + Returning_clause() parser.IReturning_clauseContext + }); ok { + stmt.ReturningList = c.convertReturning_caluseContext(n.Returning_clause()) + } else { + stmt.ReturningList = c.convertReturning_caluseContext(nil) } + if n, ok := n.(interface { + Limit_stmt() parser.ILimit_stmtContext + }); ok { + limitCount, _ := c.convertLimit_stmtContext(n.Limit_stmt()) + stmt.LimitCount = limitCount + } + return stmt } func (c *cc) convertBetweenExpr(n *parser.Expr_betweenContext) ast.Node { @@ -865,6 +908,9 @@ func (c *cc) convert(node node) ast.Node { case *parser.Delete_stmtContext: return c.convertDelete_stmtContext(n) + case *parser.Delete_stmt_limitedContext: + return c.convertDelete_stmtContext(n) + case *parser.ExprContext: return c.convertExprContext(n) @@ -917,6 +963,9 @@ func (c *cc) convert(node node) ast.Node { case *parser.Update_stmtContext: return c.convertUpdate_stmtContext(n) + case *parser.Update_stmt_limitedContext: + return c.convertUpdate_stmtContext(n) + default: return todo("convert(case=default)", n) }