1
+ mod policy_parser;
2
+
1
3
use std:: collections:: { HashMap , HashSet } ;
2
4
3
5
use pgt_schema_cache:: SchemaCache ;
6
+ use pgt_text_size:: TextRange ;
4
7
use pgt_treesitter_queries:: {
5
8
TreeSitterQueriesExecutor ,
6
9
queries:: { self , QueryResult } ,
7
10
} ;
8
11
9
- use crate :: sanitization:: SanitizedCompletionParams ;
12
+ use crate :: {
13
+ NodeText ,
14
+ context:: policy_parser:: { PolicyParser , PolicyStmtKind } ,
15
+ sanitization:: SanitizedCompletionParams ,
16
+ } ;
10
17
11
18
#[ derive( Debug , PartialEq , Eq , Hash ) ]
12
19
pub enum WrappingClause < ' a > {
@@ -18,12 +25,8 @@ pub enum WrappingClause<'a> {
18
25
} ,
19
26
Update ,
20
27
Delete ,
21
- }
22
-
23
- #[ derive( PartialEq , Eq , Debug ) ]
24
- pub ( crate ) enum NodeText < ' a > {
25
- Replaced ,
26
- Original ( & ' a str ) ,
28
+ PolicyName ,
29
+ ToRoleAssignment ,
27
30
}
28
31
29
32
#[ derive( PartialEq , Eq , Hash , Debug ) ]
@@ -47,6 +50,45 @@ pub enum WrappingNode {
47
50
Assignment ,
48
51
}
49
52
53
+ #[ derive( Debug ) ]
54
+ pub ( crate ) enum NodeUnderCursor < ' a > {
55
+ TsNode ( tree_sitter:: Node < ' a > ) ,
56
+ CustomNode {
57
+ text : NodeText ,
58
+ range : TextRange ,
59
+ kind : String ,
60
+ } ,
61
+ }
62
+
63
+ impl NodeUnderCursor < ' _ > {
64
+ pub fn start_byte ( & self ) -> usize {
65
+ match self {
66
+ NodeUnderCursor :: TsNode ( node) => node. start_byte ( ) ,
67
+ NodeUnderCursor :: CustomNode { range, .. } => range. start ( ) . into ( ) ,
68
+ }
69
+ }
70
+
71
+ pub fn end_byte ( & self ) -> usize {
72
+ match self {
73
+ NodeUnderCursor :: TsNode ( node) => node. end_byte ( ) ,
74
+ NodeUnderCursor :: CustomNode { range, .. } => range. end ( ) . into ( ) ,
75
+ }
76
+ }
77
+
78
+ pub fn kind ( & self ) -> & str {
79
+ match self {
80
+ NodeUnderCursor :: TsNode ( node) => node. kind ( ) ,
81
+ NodeUnderCursor :: CustomNode { kind, .. } => kind. as_str ( ) ,
82
+ }
83
+ }
84
+ }
85
+
86
+ impl < ' a > From < tree_sitter:: Node < ' a > > for NodeUnderCursor < ' a > {
87
+ fn from ( node : tree_sitter:: Node < ' a > ) -> Self {
88
+ NodeUnderCursor :: TsNode ( node)
89
+ }
90
+ }
91
+
50
92
impl TryFrom < & str > for WrappingNode {
51
93
type Error = String ;
52
94
@@ -77,7 +119,7 @@ impl TryFrom<String> for WrappingNode {
77
119
}
78
120
79
121
pub ( crate ) struct CompletionContext < ' a > {
80
- pub node_under_cursor : Option < tree_sitter :: Node < ' a > > ,
122
+ pub node_under_cursor : Option < NodeUnderCursor < ' a > > ,
81
123
82
124
pub tree : & ' a tree_sitter:: Tree ,
83
125
pub text : & ' a str ,
@@ -137,12 +179,49 @@ impl<'a> CompletionContext<'a> {
137
179
is_in_error_node : false ,
138
180
} ;
139
181
140
- ctx. gather_tree_context ( ) ;
141
- ctx. gather_info_from_ts_queries ( ) ;
182
+ // policy handling is important to Supabase, but they are a PostgreSQL specific extension,
183
+ // so the tree_sitter_sql language does not support it.
184
+ // We infer the context manually.
185
+ if PolicyParser :: looks_like_policy_stmt ( & params. text ) {
186
+ ctx. gather_policy_context ( ) ;
187
+ } else {
188
+ ctx. gather_tree_context ( ) ;
189
+ ctx. gather_info_from_ts_queries ( ) ;
190
+ }
142
191
143
192
ctx
144
193
}
145
194
195
+ fn gather_policy_context ( & mut self ) {
196
+ let policy_context = PolicyParser :: get_context ( self . text , self . position ) ;
197
+
198
+ self . node_under_cursor = Some ( NodeUnderCursor :: CustomNode {
199
+ text : policy_context. node_text . into ( ) ,
200
+ range : policy_context. node_range ,
201
+ kind : policy_context. node_kind . clone ( ) ,
202
+ } ) ;
203
+
204
+ if policy_context. node_kind == "policy_table" {
205
+ self . schema_or_alias_name = policy_context. schema_name . clone ( ) ;
206
+ }
207
+
208
+ if policy_context. table_name . is_some ( ) {
209
+ let mut new = HashSet :: new ( ) ;
210
+ new. insert ( policy_context. table_name . unwrap ( ) ) ;
211
+ self . mentioned_relations
212
+ . insert ( policy_context. schema_name , new) ;
213
+ }
214
+
215
+ self . wrapping_clause_type = match policy_context. node_kind . as_str ( ) {
216
+ "policy_name" if policy_context. statement_kind != PolicyStmtKind :: Create => {
217
+ Some ( WrappingClause :: PolicyName )
218
+ }
219
+ "policy_role" => Some ( WrappingClause :: ToRoleAssignment ) ,
220
+ "policy_table" => Some ( WrappingClause :: From ) ,
221
+ _ => None ,
222
+ } ;
223
+ }
224
+
146
225
fn gather_info_from_ts_queries ( & mut self ) {
147
226
let stmt_range = self . wrapping_statement_range . as_ref ( ) ;
148
227
let sql = self . text ;
@@ -195,24 +274,30 @@ impl<'a> CompletionContext<'a> {
195
274
}
196
275
}
197
276
198
- pub fn get_ts_node_content ( & self , ts_node : tree_sitter:: Node < ' a > ) -> Option < NodeText < ' a > > {
277
+ fn get_ts_node_content ( & self , ts_node : & tree_sitter:: Node < ' a > ) -> Option < NodeText > {
199
278
let source = self . text ;
200
279
ts_node. utf8_text ( source. as_bytes ( ) ) . ok ( ) . map ( |txt| {
201
280
if SanitizedCompletionParams :: is_sanitized_token ( txt) {
202
281
NodeText :: Replaced
203
282
} else {
204
- NodeText :: Original ( txt)
283
+ NodeText :: Original ( txt. into ( ) )
205
284
}
206
285
} )
207
286
}
208
287
209
288
pub fn get_node_under_cursor_content ( & self ) -> Option < String > {
210
- self . node_under_cursor
211
- . and_then ( |n| self . get_ts_node_content ( n) )
212
- . and_then ( |txt| match txt {
289
+ match self . node_under_cursor . as_ref ( ) ? {
290
+ NodeUnderCursor :: TsNode ( node) => {
291
+ self . get_ts_node_content ( node) . and_then ( |nt| match nt {
292
+ NodeText :: Replaced => None ,
293
+ NodeText :: Original ( c) => Some ( c. to_string ( ) ) ,
294
+ } )
295
+ }
296
+ NodeUnderCursor :: CustomNode { text, .. } => match text {
213
297
NodeText :: Replaced => None ,
214
298
NodeText :: Original ( c) => Some ( c. to_string ( ) ) ,
215
- } )
299
+ } ,
300
+ }
216
301
}
217
302
218
303
fn gather_tree_context ( & mut self ) {
@@ -250,7 +335,7 @@ impl<'a> CompletionContext<'a> {
250
335
251
336
// prevent infinite recursion – this can happen if we only have a PROGRAM node
252
337
if current_node_kind == parent_node_kind {
253
- self . node_under_cursor = Some ( current_node) ;
338
+ self . node_under_cursor = Some ( NodeUnderCursor :: from ( current_node) ) ;
254
339
return ;
255
340
}
256
341
@@ -289,7 +374,7 @@ impl<'a> CompletionContext<'a> {
289
374
290
375
match current_node_kind {
291
376
"object_reference" | "field" => {
292
- let content = self . get_ts_node_content ( current_node) ;
377
+ let content = self . get_ts_node_content ( & current_node) ;
293
378
if let Some ( node_txt) = content {
294
379
match node_txt {
295
380
NodeText :: Original ( txt) => {
@@ -321,7 +406,7 @@ impl<'a> CompletionContext<'a> {
321
406
322
407
// We have arrived at the leaf node
323
408
if current_node. child_count ( ) == 0 {
324
- self . node_under_cursor = Some ( current_node) ;
409
+ self . node_under_cursor = Some ( NodeUnderCursor :: from ( current_node) ) ;
325
410
return ;
326
411
}
327
412
@@ -334,11 +419,11 @@ impl<'a> CompletionContext<'a> {
334
419
node : tree_sitter:: Node < ' a > ,
335
420
) -> Option < WrappingClause < ' a > > {
336
421
if node. kind ( ) . starts_with ( "keyword_" ) {
337
- if let Some ( txt) = self . get_ts_node_content ( node) . and_then ( |txt| match txt {
422
+ if let Some ( txt) = self . get_ts_node_content ( & node) . and_then ( |txt| match txt {
338
423
NodeText :: Original ( txt) => Some ( txt) ,
339
424
NodeText :: Replaced => None ,
340
425
} ) {
341
- match txt {
426
+ match txt. as_str ( ) {
342
427
"where" => return Some ( WrappingClause :: Where ) ,
343
428
"update" => return Some ( WrappingClause :: Update ) ,
344
429
"select" => return Some ( WrappingClause :: Select ) ,
@@ -388,11 +473,14 @@ impl<'a> CompletionContext<'a> {
388
473
#[ cfg( test) ]
389
474
mod tests {
390
475
use crate :: {
391
- context:: { CompletionContext , NodeText , WrappingClause } ,
476
+ NodeText ,
477
+ context:: { CompletionContext , WrappingClause } ,
392
478
sanitization:: SanitizedCompletionParams ,
393
479
test_helper:: { CURSOR_POS , get_text_and_position} ,
394
480
} ;
395
481
482
+ use super :: NodeUnderCursor ;
483
+
396
484
fn get_tree ( input : & str ) -> tree_sitter:: Tree {
397
485
let mut parser = tree_sitter:: Parser :: new ( ) ;
398
486
parser
@@ -551,17 +639,22 @@ mod tests {
551
639
552
640
let ctx = CompletionContext :: new ( & params) ;
553
641
554
- let node = ctx. node_under_cursor . unwrap ( ) ;
642
+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
555
643
556
- assert_eq ! (
557
- ctx. get_ts_node_content( node) ,
558
- Some ( NodeText :: Original ( "select" ) )
559
- ) ;
644
+ match node {
645
+ NodeUnderCursor :: TsNode ( node) => {
646
+ assert_eq ! (
647
+ ctx. get_ts_node_content( node) ,
648
+ Some ( NodeText :: Original ( "select" . into( ) ) )
649
+ ) ;
560
650
561
- assert_eq ! (
562
- ctx. wrapping_clause_type,
563
- Some ( crate :: context:: WrappingClause :: Select )
564
- ) ;
651
+ assert_eq ! (
652
+ ctx. wrapping_clause_type,
653
+ Some ( crate :: context:: WrappingClause :: Select )
654
+ ) ;
655
+ }
656
+ _ => unreachable ! ( ) ,
657
+ }
565
658
}
566
659
}
567
660
@@ -582,12 +675,17 @@ mod tests {
582
675
583
676
let ctx = CompletionContext :: new ( & params) ;
584
677
585
- let node = ctx. node_under_cursor . unwrap ( ) ;
678
+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
586
679
587
- assert_eq ! (
588
- ctx. get_ts_node_content( node) ,
589
- Some ( NodeText :: Original ( "from" ) )
590
- ) ;
680
+ match node {
681
+ NodeUnderCursor :: TsNode ( node) => {
682
+ assert_eq ! (
683
+ ctx. get_ts_node_content( node) ,
684
+ Some ( NodeText :: Original ( "from" . into( ) ) )
685
+ ) ;
686
+ }
687
+ _ => unreachable ! ( ) ,
688
+ }
591
689
}
592
690
593
691
#[ test]
@@ -607,10 +705,18 @@ mod tests {
607
705
608
706
let ctx = CompletionContext :: new ( & params) ;
609
707
610
- let node = ctx. node_under_cursor . unwrap ( ) ;
708
+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
611
709
612
- assert_eq ! ( ctx. get_ts_node_content( node) , Some ( NodeText :: Original ( "" ) ) ) ;
613
- assert_eq ! ( ctx. wrapping_clause_type, None ) ;
710
+ match node {
711
+ NodeUnderCursor :: TsNode ( node) => {
712
+ assert_eq ! (
713
+ ctx. get_ts_node_content( node) ,
714
+ Some ( NodeText :: Original ( "" . into( ) ) )
715
+ ) ;
716
+ assert_eq ! ( ctx. wrapping_clause_type, None ) ;
717
+ }
718
+ _ => unreachable ! ( ) ,
719
+ }
614
720
}
615
721
616
722
#[ test]
@@ -632,12 +738,17 @@ mod tests {
632
738
633
739
let ctx = CompletionContext :: new ( & params) ;
634
740
635
- let node = ctx. node_under_cursor . unwrap ( ) ;
741
+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
636
742
637
- assert_eq ! (
638
- ctx. get_ts_node_content( node) ,
639
- Some ( NodeText :: Original ( "fro" ) )
640
- ) ;
641
- assert_eq ! ( ctx. wrapping_clause_type, Some ( WrappingClause :: Select ) ) ;
743
+ match node {
744
+ NodeUnderCursor :: TsNode ( node) => {
745
+ assert_eq ! (
746
+ ctx. get_ts_node_content( node) ,
747
+ Some ( NodeText :: Original ( "fro" . into( ) ) )
748
+ ) ;
749
+ assert_eq ! ( ctx. wrapping_clause_type, Some ( WrappingClause :: Select ) ) ;
750
+ }
751
+ _ => unreachable ! ( ) ,
752
+ }
642
753
}
643
754
}
0 commit comments