1
- use std:: collections:: { HashMap , HashSet } ;
1
+ use std:: {
2
+ cmp,
3
+ collections:: { HashMap , HashSet } ,
4
+ } ;
2
5
3
6
use pgt_schema_cache:: SchemaCache ;
4
7
use pgt_treesitter_queries:: {
@@ -8,7 +11,7 @@ use pgt_treesitter_queries::{
8
11
9
12
use crate :: sanitization:: SanitizedCompletionParams ;
10
13
11
- #[ derive( Debug , PartialEq , Eq , Hash ) ]
14
+ #[ derive( Debug , PartialEq , Eq , Hash , Clone ) ]
12
15
pub enum WrappingClause < ' a > {
13
16
Select ,
14
17
Where ,
@@ -19,6 +22,7 @@ pub enum WrappingClause<'a> {
19
22
Update ,
20
23
Delete ,
21
24
ColumnDefinitions ,
25
+ Insert ,
22
26
}
23
27
24
28
#[ derive( PartialEq , Eq , Debug ) ]
@@ -46,6 +50,7 @@ pub enum WrappingNode {
46
50
Relation ,
47
51
BinaryExpression ,
48
52
Assignment ,
53
+ List ,
49
54
}
50
55
51
56
impl TryFrom < & str > for WrappingNode {
@@ -56,6 +61,7 @@ impl TryFrom<&str> for WrappingNode {
56
61
"relation" => Ok ( Self :: Relation ) ,
57
62
"assignment" => Ok ( Self :: Assignment ) ,
58
63
"binary_expression" => Ok ( Self :: BinaryExpression ) ,
64
+ "list" => Ok ( Self :: List ) ,
59
65
_ => {
60
66
let message = format ! ( "Unimplemented Relation: {}" , value) ;
61
67
@@ -142,13 +148,6 @@ impl<'a> CompletionContext<'a> {
142
148
ctx. gather_tree_context ( ) ;
143
149
ctx. gather_info_from_ts_queries ( ) ;
144
150
145
- if cfg ! ( test) {
146
- println ! ( "{:#?}" , ctx. wrapping_clause_type) ;
147
- println ! ( "{:#?}" , ctx. wrapping_node_kind) ;
148
- println ! ( "{:#?}" , ctx. is_in_error_node) ;
149
- println ! ( "{:#?}" , ctx. text) ;
150
- }
151
-
152
151
ctx
153
152
}
154
153
@@ -240,10 +239,20 @@ impl<'a> CompletionContext<'a> {
240
239
* `select * from use {}` becomes `select * from use{}`.
241
240
*/
242
241
let current_node = cursor. node ( ) ;
243
- while cursor. goto_first_child_for_byte ( self . position ) . is_none ( ) && self . position > 0 {
244
- self . position -= 1 ;
242
+
243
+ let mut chars = self . text . chars ( ) ;
244
+
245
+ if chars
246
+ . nth ( self . position )
247
+ . is_some_and ( |c| !c. is_ascii_whitespace ( ) )
248
+ {
249
+ self . position = cmp:: min ( self . position + 1 , self . text . len ( ) ) ;
250
+ } else {
251
+ self . position = cmp:: min ( self . position , self . text . len ( ) ) ;
245
252
}
246
253
254
+ cursor. goto_first_child_for_byte ( self . position ) ;
255
+
247
256
self . gather_context_from_node ( cursor, current_node) ;
248
257
}
249
258
@@ -276,23 +285,11 @@ impl<'a> CompletionContext<'a> {
276
285
277
286
// try to gather context from the siblings if we're within an error node.
278
287
if self . is_in_error_node {
279
- let mut next_sibling = current_node. next_named_sibling ( ) ;
280
- while let Some ( n) = next_sibling {
281
- if let Some ( clause_type) = self . get_wrapping_clause_from_keyword_node ( n) {
282
- self . wrapping_clause_type = Some ( clause_type) ;
283
- break ;
284
- } else {
285
- next_sibling = n. next_named_sibling ( ) ;
286
- }
288
+ if let Some ( clause_type) = self . get_wrapping_clause_from_siblings ( current_node) {
289
+ self . wrapping_clause_type = Some ( clause_type) ;
287
290
}
288
- let mut prev_sibling = current_node. prev_named_sibling ( ) ;
289
- while let Some ( n) = prev_sibling {
290
- if let Some ( clause_type) = self . get_wrapping_clause_from_keyword_node ( n) {
291
- self . wrapping_clause_type = Some ( clause_type) ;
292
- break ;
293
- } else {
294
- prev_sibling = n. prev_named_sibling ( ) ;
295
- }
291
+ if let Some ( wrapping_node) = self . get_wrapping_node_from_siblings ( current_node) {
292
+ self . wrapping_node_kind = Some ( wrapping_node)
296
293
}
297
294
}
298
295
@@ -317,7 +314,7 @@ impl<'a> CompletionContext<'a> {
317
314
self . get_wrapping_clause_from_current_node ( current_node, & mut cursor) ;
318
315
}
319
316
320
- "relation" | "binary_expression" | "assignment" => {
317
+ "relation" | "binary_expression" | "assignment" | "list" => {
321
318
self . wrapping_node_kind = current_node_kind. try_into ( ) . ok ( ) ;
322
319
}
323
320
@@ -338,31 +335,89 @@ impl<'a> CompletionContext<'a> {
338
335
self . gather_context_from_node ( cursor, current_node) ;
339
336
}
340
337
341
- fn get_wrapping_clause_from_keyword_node (
338
+ fn get_first_sibling ( & self , node : tree_sitter:: Node < ' a > ) -> tree_sitter:: Node < ' a > {
339
+ let mut first_sibling = node;
340
+ while let Some ( n) = first_sibling. prev_sibling ( ) {
341
+ first_sibling = n;
342
+ }
343
+ first_sibling
344
+ }
345
+
346
+ fn get_wrapping_node_from_siblings ( & self , node : tree_sitter:: Node < ' a > ) -> Option < WrappingNode > {
347
+ self . wrapping_clause_type
348
+ . as_ref ( )
349
+ . and_then ( |clause| match clause {
350
+ WrappingClause :: Insert => {
351
+ if node. prev_sibling ( ) . is_some_and ( |n| n. kind ( ) == "(" )
352
+ || node. next_sibling ( ) . is_some_and ( |n| n. kind ( ) == ")" )
353
+ {
354
+ Some ( WrappingNode :: List )
355
+ } else {
356
+ None
357
+ }
358
+ }
359
+ _ => None ,
360
+ } )
361
+ }
362
+
363
+ fn get_wrapping_clause_from_siblings (
342
364
& self ,
343
365
node : tree_sitter:: Node < ' a > ,
344
366
) -> Option < WrappingClause < ' a > > {
345
- if node. kind ( ) . starts_with ( "keyword_" ) {
346
- if let Some ( txt) = self . get_ts_node_content ( node) . and_then ( |txt| match txt {
347
- NodeText :: Original ( txt) => Some ( txt) ,
348
- NodeText :: Replaced => None ,
349
- } ) {
350
- match txt {
351
- "where" => return Some ( WrappingClause :: Where ) ,
352
- "update" => return Some ( WrappingClause :: Update ) ,
353
- "select" => return Some ( WrappingClause :: Select ) ,
354
- "delete" => return Some ( WrappingClause :: Delete ) ,
355
- "from" => return Some ( WrappingClause :: From ) ,
356
- "join" => {
357
- // TODO: not sure if we can infer it here.
358
- return Some ( WrappingClause :: Join { on_node : None } ) ;
367
+ let clause_combinations: Vec < ( WrappingClause , & [ & ' static str ] ) > = vec ! [
368
+ ( WrappingClause :: Where , & [ "where" ] ) ,
369
+ ( WrappingClause :: Update , & [ "update" ] ) ,
370
+ ( WrappingClause :: Select , & [ "select" ] ) ,
371
+ ( WrappingClause :: Delete , & [ "delete" ] ) ,
372
+ ( WrappingClause :: Insert , & [ "insert" , "into" ] ) ,
373
+ ( WrappingClause :: From , & [ "from" ] ) ,
374
+ ( WrappingClause :: Join { on_node: None } , & [ "join" ] ) ,
375
+ ] ;
376
+
377
+ let first_sibling = self . get_first_sibling ( node) ;
378
+
379
+ /*
380
+ * For each clause, we'll iterate from first_sibling to the next ones,
381
+ * either until the end or until we land on the node under the cursor.
382
+ * We'll score the `WrappingClause` by how many tokens it matches in order.
383
+ */
384
+ let mut clauses_with_score: Vec < ( WrappingClause , usize ) > = clause_combinations
385
+ . into_iter ( )
386
+ . map ( |( clause, tokens) | {
387
+ let mut idx = 0 ;
388
+
389
+ let mut sibling = Some ( first_sibling) ;
390
+ while let Some ( sib) = sibling {
391
+ if sib. end_byte ( ) >= node. end_byte ( ) || idx >= tokens. len ( ) {
392
+ break ;
359
393
}
360
- _ => { }
394
+
395
+ if let Some ( sibling_content) =
396
+ self . get_ts_node_content ( sib) . and_then ( |txt| match txt {
397
+ NodeText :: Original ( txt) => Some ( txt) ,
398
+ NodeText :: Replaced => None ,
399
+ } )
400
+ {
401
+ if sibling_content == tokens[ idx] {
402
+ idx += 1 ;
403
+ }
404
+ } else {
405
+ break ;
406
+ }
407
+
408
+ sibling = sib. next_sibling ( ) ;
361
409
}
362
- } ;
363
- }
364
410
365
- None
411
+ ( clause, idx)
412
+ } )
413
+ . collect ( ) ;
414
+
415
+ clauses_with_score. sort_by ( |( _, score_a) , ( _, score_b) | score_b. cmp ( score_a) ) ;
416
+ clauses_with_score
417
+ . iter ( )
418
+ . filter ( |( _, score) | * score > 0 )
419
+ . next ( )
420
+ . map ( |c| c. 0 . clone ( ) )
366
421
}
367
422
368
423
fn get_wrapping_clause_from_current_node (
@@ -377,6 +432,7 @@ impl<'a> CompletionContext<'a> {
377
432
"delete" => Some ( WrappingClause :: Delete ) ,
378
433
"from" => Some ( WrappingClause :: From ) ,
379
434
"column_definitions" => Some ( WrappingClause :: ColumnDefinitions ) ,
435
+ "insert" => Some ( WrappingClause :: Insert ) ,
380
436
"join" => {
381
437
// sadly, we need to manually iterate over the children –
382
438
// `node.child_by_field_id(..)` does not work as expected
0 commit comments