@@ -7,7 +7,9 @@ use pgt_treesitter_queries::{
7
7
queries:: { self , QueryResult } ,
8
8
} ;
9
9
10
- use crate :: sanitization:: SanitizedCompletionParams ;
10
+ use crate :: {
11
+ NodeText , context:: policy_parser:: PolicyParser , sanitization:: SanitizedCompletionParams ,
12
+ } ;
11
13
12
14
#[ derive( Debug , PartialEq , Eq ) ]
13
15
pub enum WrappingClause < ' a > {
@@ -19,12 +21,8 @@ pub enum WrappingClause<'a> {
19
21
} ,
20
22
Update ,
21
23
Delete ,
22
- }
23
-
24
- #[ derive( PartialEq , Eq , Debug ) ]
25
- pub ( crate ) enum NodeText < ' a > {
26
- Replaced ,
27
- Original ( & ' a str ) ,
24
+ PolicyName ,
25
+ ToRole ,
28
26
}
29
27
30
28
/// We can map a few nodes, such as the "update" node, to actual SQL clauses.
@@ -45,7 +43,7 @@ pub enum WrappingNode {
45
43
pub ( crate ) enum NodeUnderCursor < ' a > {
46
44
TsNode ( tree_sitter:: Node < ' a > ) ,
47
45
CustomNode {
48
- text : NodeText < ' a > ,
46
+ text : NodeText ,
49
47
range : TextRange ,
50
48
kind : String ,
51
49
} ,
@@ -172,14 +170,35 @@ impl<'a> CompletionContext<'a> {
172
170
// policy handling is important to Supabase, but they are a PostgreSQL specific extension,
173
171
// so the tree_sitter_sql language does not support it.
174
172
// We infer the context manually.
175
- // if params.text.to_lowercase().starts_with("create policy")
176
- // || params.text.to_lowercase().starts_with("alter policy")
177
- // || params.text.to_lowercase().starts_with("drop policy")
178
- // {
179
- // } else {
180
- ctx. gather_tree_context ( ) ;
181
- ctx. gather_info_from_ts_queries ( ) ;
182
- // }
173
+ if params. text . to_lowercase ( ) . starts_with ( "create policy" )
174
+ || params. text . to_lowercase ( ) . starts_with ( "alter policy" )
175
+ || params. text . to_lowercase ( ) . starts_with ( "drop policy" )
176
+ {
177
+ let policy_context = PolicyParser :: get_context ( & ctx. text , ctx. position ) ;
178
+
179
+ ctx. node_under_cursor = Some ( NodeUnderCursor :: CustomNode {
180
+ text : policy_context. node_text . into ( ) ,
181
+ range : policy_context. node_range ,
182
+ kind : policy_context. node_kind . clone ( ) ,
183
+ } ) ;
184
+
185
+ if policy_context. table_name . is_some ( ) {
186
+ let mut new = HashSet :: new ( ) ;
187
+ new. insert ( policy_context. table_name . unwrap ( ) ) ;
188
+ ctx. mentioned_relations
189
+ . insert ( policy_context. schema_name , new) ;
190
+ }
191
+
192
+ ctx. wrapping_clause_type = match policy_context. node_kind . as_str ( ) {
193
+ "policy_name" => Some ( WrappingClause :: PolicyName ) ,
194
+ "policy_role" => Some ( WrappingClause :: ToRole ) ,
195
+ "policy_table" => Some ( WrappingClause :: From ) ,
196
+ _ => None ,
197
+ } ;
198
+ } else {
199
+ ctx. gather_tree_context ( ) ;
200
+ ctx. gather_info_from_ts_queries ( ) ;
201
+ }
183
202
184
203
tracing:: warn!( "sql: {}" , ctx. text) ;
185
204
tracing:: warn!( "position: {}" , ctx. position) ;
@@ -237,13 +256,13 @@ impl<'a> CompletionContext<'a> {
237
256
}
238
257
}
239
258
240
- fn get_ts_node_content ( & self , ts_node : & tree_sitter:: Node < ' a > ) -> Option < NodeText < ' a > > {
259
+ fn get_ts_node_content ( & self , ts_node : & tree_sitter:: Node < ' a > ) -> Option < NodeText > {
241
260
let source = self . text ;
242
261
ts_node. utf8_text ( source. as_bytes ( ) ) . ok ( ) . map ( |txt| {
243
262
if SanitizedCompletionParams :: is_sanitized_token ( txt) {
244
263
NodeText :: Replaced
245
264
} else {
246
- NodeText :: Original ( txt)
265
+ NodeText :: Original ( txt. into ( ) )
247
266
}
248
267
} )
249
268
}
@@ -386,7 +405,7 @@ impl<'a> CompletionContext<'a> {
386
405
NodeText :: Original ( txt) => Some ( txt) ,
387
406
NodeText :: Replaced => None ,
388
407
} ) {
389
- match txt {
408
+ match txt. as_str ( ) {
390
409
"where" => return Some ( WrappingClause :: Where ) ,
391
410
"update" => return Some ( WrappingClause :: Update ) ,
392
411
"select" => return Some ( WrappingClause :: Select ) ,
@@ -436,7 +455,8 @@ impl<'a> CompletionContext<'a> {
436
455
#[ cfg( test) ]
437
456
mod tests {
438
457
use crate :: {
439
- context:: { CompletionContext , NodeText , WrappingClause } ,
458
+ NodeText ,
459
+ context:: { CompletionContext , WrappingClause } ,
440
460
sanitization:: SanitizedCompletionParams ,
441
461
test_helper:: { CURSOR_POS , get_text_and_position} ,
442
462
} ;
@@ -607,7 +627,7 @@ mod tests {
607
627
NodeUnderCursor :: TsNode ( node) => {
608
628
assert_eq ! (
609
629
ctx. get_ts_node_content( node) ,
610
- Some ( NodeText :: Original ( "select" ) )
630
+ Some ( NodeText :: Original ( "select" . into ( ) ) )
611
631
) ;
612
632
613
633
assert_eq ! (
@@ -643,7 +663,7 @@ mod tests {
643
663
NodeUnderCursor :: TsNode ( node) => {
644
664
assert_eq ! (
645
665
ctx. get_ts_node_content( & node) ,
646
- Some ( NodeText :: Original ( "from" ) )
666
+ Some ( NodeText :: Original ( "from" . into ( ) ) )
647
667
) ;
648
668
}
649
669
_ => unreachable ! ( ) ,
@@ -671,7 +691,10 @@ mod tests {
671
691
672
692
match node {
673
693
NodeUnderCursor :: TsNode ( node) => {
674
- assert_eq ! ( ctx. get_ts_node_content( & node) , Some ( NodeText :: Original ( "" ) ) ) ;
694
+ assert_eq ! (
695
+ ctx. get_ts_node_content( & node) ,
696
+ Some ( NodeText :: Original ( "" . into( ) ) )
697
+ ) ;
675
698
assert_eq ! ( ctx. wrapping_clause_type, None ) ;
676
699
}
677
700
_ => unreachable ! ( ) ,
@@ -703,7 +726,7 @@ mod tests {
703
726
NodeUnderCursor :: TsNode ( node) => {
704
727
assert_eq ! (
705
728
ctx. get_ts_node_content( & node) ,
706
- Some ( NodeText :: Original ( "fro" ) )
729
+ Some ( NodeText :: Original ( "fro" . into ( ) ) )
707
730
) ;
708
731
assert_eq ! ( ctx. wrapping_clause_type, Some ( WrappingClause :: Select ) ) ;
709
732
}
0 commit comments