@@ -144,7 +144,15 @@ func (c *cc) convertCreate_view_stmtContext(n *parser.Create_view_stmtContext) a
144
144
}
145
145
}
146
146
147
- func (c * cc ) convertDelete_stmtContext (n * parser.Delete_stmtContext ) ast.Node {
147
+ type Delete_stmt interface {
148
+ node
149
+
150
+ Qualified_table_name () parser.IQualified_table_nameContext
151
+ WHERE_ () antlr.TerminalNode
152
+ Expr () parser.IExprContext
153
+ }
154
+
155
+ func (c * cc ) convertDelete_stmtContext (n Delete_stmt ) ast.Node {
148
156
if qualifiedName , ok := n .Qualified_table_name ().(* parser.Qualified_table_nameContext ); ok {
149
157
150
158
tableName := qualifiedName .Table_name ().GetText ()
@@ -167,15 +175,28 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node {
167
175
relations .Items = append (relations .Items , relation )
168
176
169
177
delete := & ast.DeleteStmt {
170
- Relations : relations ,
171
- ReturningList : c .convertReturning_caluseContext (n .Returning_clause ()),
172
- WithClause : nil ,
178
+ Relations : relations ,
179
+ WithClause : nil ,
173
180
}
174
181
175
182
if n .WHERE_ () != nil && n .Expr () != nil {
176
183
delete .WhereClause = c .convert (n .Expr ())
177
184
}
178
185
186
+ if n , ok := n .(interface {
187
+ Returning_clause () parser.IReturning_clauseContext
188
+ }); ok {
189
+ delete .ReturningList = c .convertReturning_caluseContext (n .Returning_clause ())
190
+ } else {
191
+ delete .ReturningList = c .convertReturning_caluseContext (nil )
192
+ }
193
+ if n , ok := n .(interface {
194
+ Limit_stmt () parser.ILimit_stmtContext
195
+ }); ok {
196
+ limitCount , _ := c .convertLimit_stmtContext (n .Limit_stmt ())
197
+ delete .LimitCount = limitCount
198
+ }
199
+
179
200
return delete
180
201
}
181
202
@@ -796,7 +817,16 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
796
817
return tables
797
818
}
798
819
799
- func (c * cc ) convertUpdate_stmtContext (n * parser.Update_stmtContext ) ast.Node {
820
+ type Update_stmt interface {
821
+ Qualified_table_name () parser.IQualified_table_nameContext
822
+ GetStart () antlr.Token
823
+ AllColumn_name () []parser.IColumn_nameContext
824
+ WHERE_ () antlr.TerminalNode
825
+ Expr (i int ) parser.IExprContext
826
+ AllExpr () []parser.IExprContext
827
+ }
828
+
829
+ func (c * cc ) convertUpdate_stmtContext (n Update_stmt ) ast.Node {
800
830
if n == nil {
801
831
return nil
802
832
}
@@ -824,14 +854,27 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node {
824
854
where = c .convert (n .Expr (len (n .AllExpr ()) - 1 ))
825
855
}
826
856
827
- return & ast.UpdateStmt {
828
- Relations : relations ,
829
- TargetList : list ,
830
- WhereClause : where ,
831
- ReturningList : c .convertReturning_caluseContext (n .Returning_clause ()),
832
- FromClause : & ast.List {},
833
- WithClause : nil , // TODO: support with clause
857
+ stmt := & ast.UpdateStmt {
858
+ Relations : relations ,
859
+ TargetList : list ,
860
+ WhereClause : where ,
861
+ FromClause : & ast.List {},
862
+ WithClause : nil , // TODO: support with clause
863
+ }
864
+ if n , ok := n .(interface {
865
+ Returning_clause () parser.IReturning_clauseContext
866
+ }); ok {
867
+ stmt .ReturningList = c .convertReturning_caluseContext (n .Returning_clause ())
868
+ } else {
869
+ stmt .ReturningList = c .convertReturning_caluseContext (nil )
834
870
}
871
+ if n , ok := n .(interface {
872
+ Limit_stmt () parser.ILimit_stmtContext
873
+ }); ok {
874
+ limitCount , _ := c .convertLimit_stmtContext (n .Limit_stmt ())
875
+ stmt .LimitCount = limitCount
876
+ }
877
+ return stmt
835
878
}
836
879
837
880
func (c * cc ) convertBetweenExpr (n * parser.Expr_betweenContext ) ast.Node {
@@ -865,6 +908,9 @@ func (c *cc) convert(node node) ast.Node {
865
908
case * parser.Delete_stmtContext :
866
909
return c .convertDelete_stmtContext (n )
867
910
911
+ case * parser.Delete_stmt_limitedContext :
912
+ return c .convertDelete_stmtContext (n )
913
+
868
914
case * parser.ExprContext :
869
915
return c .convertExprContext (n )
870
916
@@ -917,6 +963,9 @@ func (c *cc) convert(node node) ast.Node {
917
963
case * parser.Update_stmtContext :
918
964
return c .convertUpdate_stmtContext (n )
919
965
966
+ case * parser.Update_stmt_limitedContext :
967
+ return c .convertUpdate_stmtContext (n )
968
+
920
969
default :
921
970
return todo ("convert(case=default)" , n )
922
971
}
0 commit comments