Skip to content

Commit 400715f

Browse files
feat(completions): complete policies (#397)
1 parent 4e57995 commit 400715f

File tree

18 files changed

+1086
-147
lines changed

18 files changed

+1086
-147
lines changed

.github/workflows/pull_request.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ jobs:
184184
uses: ./.github/actions/free-disk-space
185185
- name: Install toolchain
186186
uses: moonrepo/setup-rust@v1
187+
with:
188+
cache-base: main
187189
- name: Build main binary
188190
run: cargo build -p pgt_cli --release
189191
- name: Setup Bun
@@ -222,6 +224,10 @@ jobs:
222224
cache-base: main
223225
env:
224226
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
227+
- name: Ensure RustFMT on nightly toolchain
228+
run: rustup component add rustfmt --toolchain nightly
229+
- name: echo toolchain
230+
run: rustup show
225231
- name: Run the analyser codegen
226232
run: cargo run -p xtask_codegen -- analyser
227233
- name: Run the configuration codegen

crates/pgt_completions/src/complete.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ use crate::{
44
builder::CompletionBuilder,
55
context::CompletionContext,
66
item::CompletionItem,
7-
providers::{complete_columns, complete_functions, complete_schemas, complete_tables},
7+
providers::{
8+
complete_columns, complete_functions, complete_policies, complete_schemas, complete_tables,
9+
},
810
sanitization::SanitizedCompletionParams,
911
};
1012

@@ -33,6 +35,7 @@ pub fn complete(params: CompletionParams) -> Vec<CompletionItem> {
3335
complete_functions(&ctx, &mut builder);
3436
complete_columns(&ctx, &mut builder);
3537
complete_schemas(&ctx, &mut builder);
38+
complete_policies(&ctx, &mut builder);
3639

3740
builder.finish()
3841
}

crates/pgt_completions/src/context.rs renamed to crates/pgt_completions/src/context/mod.rs

Lines changed: 156 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
mod policy_parser;
2+
13
use std::collections::{HashMap, HashSet};
24

35
use pgt_schema_cache::SchemaCache;
6+
use pgt_text_size::TextRange;
47
use pgt_treesitter_queries::{
58
TreeSitterQueriesExecutor,
69
queries::{self, QueryResult},
710
};
811

9-
use crate::sanitization::SanitizedCompletionParams;
12+
use crate::{
13+
NodeText,
14+
context::policy_parser::{PolicyParser, PolicyStmtKind},
15+
sanitization::SanitizedCompletionParams,
16+
};
1017

1118
#[derive(Debug, PartialEq, Eq, Hash)]
1219
pub enum WrappingClause<'a> {
@@ -18,12 +25,8 @@ pub enum WrappingClause<'a> {
1825
},
1926
Update,
2027
Delete,
21-
}
22-
23-
#[derive(PartialEq, Eq, Debug)]
24-
pub(crate) enum NodeText<'a> {
25-
Replaced,
26-
Original(&'a str),
28+
PolicyName,
29+
ToRoleAssignment,
2730
}
2831

2932
#[derive(PartialEq, Eq, Hash, Debug)]
@@ -47,6 +50,45 @@ pub enum WrappingNode {
4750
Assignment,
4851
}
4952

53+
#[derive(Debug)]
54+
pub(crate) enum NodeUnderCursor<'a> {
55+
TsNode(tree_sitter::Node<'a>),
56+
CustomNode {
57+
text: NodeText,
58+
range: TextRange,
59+
kind: String,
60+
},
61+
}
62+
63+
impl NodeUnderCursor<'_> {
64+
pub fn start_byte(&self) -> usize {
65+
match self {
66+
NodeUnderCursor::TsNode(node) => node.start_byte(),
67+
NodeUnderCursor::CustomNode { range, .. } => range.start().into(),
68+
}
69+
}
70+
71+
pub fn end_byte(&self) -> usize {
72+
match self {
73+
NodeUnderCursor::TsNode(node) => node.end_byte(),
74+
NodeUnderCursor::CustomNode { range, .. } => range.end().into(),
75+
}
76+
}
77+
78+
pub fn kind(&self) -> &str {
79+
match self {
80+
NodeUnderCursor::TsNode(node) => node.kind(),
81+
NodeUnderCursor::CustomNode { kind, .. } => kind.as_str(),
82+
}
83+
}
84+
}
85+
86+
impl<'a> From<tree_sitter::Node<'a>> for NodeUnderCursor<'a> {
87+
fn from(node: tree_sitter::Node<'a>) -> Self {
88+
NodeUnderCursor::TsNode(node)
89+
}
90+
}
91+
5092
impl TryFrom<&str> for WrappingNode {
5193
type Error = String;
5294

@@ -77,7 +119,7 @@ impl TryFrom<String> for WrappingNode {
77119
}
78120

79121
pub(crate) struct CompletionContext<'a> {
80-
pub node_under_cursor: Option<tree_sitter::Node<'a>>,
122+
pub node_under_cursor: Option<NodeUnderCursor<'a>>,
81123

82124
pub tree: &'a tree_sitter::Tree,
83125
pub text: &'a str,
@@ -137,12 +179,49 @@ impl<'a> CompletionContext<'a> {
137179
is_in_error_node: false,
138180
};
139181

140-
ctx.gather_tree_context();
141-
ctx.gather_info_from_ts_queries();
182+
// policy handling is important to Supabase, but they are a PostgreSQL specific extension,
183+
// so the tree_sitter_sql language does not support it.
184+
// We infer the context manually.
185+
if PolicyParser::looks_like_policy_stmt(&params.text) {
186+
ctx.gather_policy_context();
187+
} else {
188+
ctx.gather_tree_context();
189+
ctx.gather_info_from_ts_queries();
190+
}
142191

143192
ctx
144193
}
145194

195+
fn gather_policy_context(&mut self) {
196+
let policy_context = PolicyParser::get_context(self.text, self.position);
197+
198+
self.node_under_cursor = Some(NodeUnderCursor::CustomNode {
199+
text: policy_context.node_text.into(),
200+
range: policy_context.node_range,
201+
kind: policy_context.node_kind.clone(),
202+
});
203+
204+
if policy_context.node_kind == "policy_table" {
205+
self.schema_or_alias_name = policy_context.schema_name.clone();
206+
}
207+
208+
if policy_context.table_name.is_some() {
209+
let mut new = HashSet::new();
210+
new.insert(policy_context.table_name.unwrap());
211+
self.mentioned_relations
212+
.insert(policy_context.schema_name, new);
213+
}
214+
215+
self.wrapping_clause_type = match policy_context.node_kind.as_str() {
216+
"policy_name" if policy_context.statement_kind != PolicyStmtKind::Create => {
217+
Some(WrappingClause::PolicyName)
218+
}
219+
"policy_role" => Some(WrappingClause::ToRoleAssignment),
220+
"policy_table" => Some(WrappingClause::From),
221+
_ => None,
222+
};
223+
}
224+
146225
fn gather_info_from_ts_queries(&mut self) {
147226
let stmt_range = self.wrapping_statement_range.as_ref();
148227
let sql = self.text;
@@ -195,24 +274,30 @@ impl<'a> CompletionContext<'a> {
195274
}
196275
}
197276

198-
pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<NodeText<'a>> {
277+
fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option<NodeText> {
199278
let source = self.text;
200279
ts_node.utf8_text(source.as_bytes()).ok().map(|txt| {
201280
if SanitizedCompletionParams::is_sanitized_token(txt) {
202281
NodeText::Replaced
203282
} else {
204-
NodeText::Original(txt)
283+
NodeText::Original(txt.into())
205284
}
206285
})
207286
}
208287

209288
pub fn get_node_under_cursor_content(&self) -> Option<String> {
210-
self.node_under_cursor
211-
.and_then(|n| self.get_ts_node_content(n))
212-
.and_then(|txt| match txt {
289+
match self.node_under_cursor.as_ref()? {
290+
NodeUnderCursor::TsNode(node) => {
291+
self.get_ts_node_content(node).and_then(|nt| match nt {
292+
NodeText::Replaced => None,
293+
NodeText::Original(c) => Some(c.to_string()),
294+
})
295+
}
296+
NodeUnderCursor::CustomNode { text, .. } => match text {
213297
NodeText::Replaced => None,
214298
NodeText::Original(c) => Some(c.to_string()),
215-
})
299+
},
300+
}
216301
}
217302

218303
fn gather_tree_context(&mut self) {
@@ -250,7 +335,7 @@ impl<'a> CompletionContext<'a> {
250335

251336
// prevent infinite recursion – this can happen if we only have a PROGRAM node
252337
if current_node_kind == parent_node_kind {
253-
self.node_under_cursor = Some(current_node);
338+
self.node_under_cursor = Some(NodeUnderCursor::from(current_node));
254339
return;
255340
}
256341

@@ -289,7 +374,7 @@ impl<'a> CompletionContext<'a> {
289374

290375
match current_node_kind {
291376
"object_reference" | "field" => {
292-
let content = self.get_ts_node_content(current_node);
377+
let content = self.get_ts_node_content(&current_node);
293378
if let Some(node_txt) = content {
294379
match node_txt {
295380
NodeText::Original(txt) => {
@@ -321,7 +406,7 @@ impl<'a> CompletionContext<'a> {
321406

322407
// We have arrived at the leaf node
323408
if current_node.child_count() == 0 {
324-
self.node_under_cursor = Some(current_node);
409+
self.node_under_cursor = Some(NodeUnderCursor::from(current_node));
325410
return;
326411
}
327412

@@ -334,11 +419,11 @@ impl<'a> CompletionContext<'a> {
334419
node: tree_sitter::Node<'a>,
335420
) -> Option<WrappingClause<'a>> {
336421
if node.kind().starts_with("keyword_") {
337-
if let Some(txt) = self.get_ts_node_content(node).and_then(|txt| match txt {
422+
if let Some(txt) = self.get_ts_node_content(&node).and_then(|txt| match txt {
338423
NodeText::Original(txt) => Some(txt),
339424
NodeText::Replaced => None,
340425
}) {
341-
match txt {
426+
match txt.as_str() {
342427
"where" => return Some(WrappingClause::Where),
343428
"update" => return Some(WrappingClause::Update),
344429
"select" => return Some(WrappingClause::Select),
@@ -388,11 +473,14 @@ impl<'a> CompletionContext<'a> {
388473
#[cfg(test)]
389474
mod tests {
390475
use crate::{
391-
context::{CompletionContext, NodeText, WrappingClause},
476+
NodeText,
477+
context::{CompletionContext, WrappingClause},
392478
sanitization::SanitizedCompletionParams,
393479
test_helper::{CURSOR_POS, get_text_and_position},
394480
};
395481

482+
use super::NodeUnderCursor;
483+
396484
fn get_tree(input: &str) -> tree_sitter::Tree {
397485
let mut parser = tree_sitter::Parser::new();
398486
parser
@@ -551,17 +639,22 @@ mod tests {
551639

552640
let ctx = CompletionContext::new(&params);
553641

554-
let node = ctx.node_under_cursor.unwrap();
642+
let node = ctx.node_under_cursor.as_ref().unwrap();
555643

556-
assert_eq!(
557-
ctx.get_ts_node_content(node),
558-
Some(NodeText::Original("select"))
559-
);
644+
match node {
645+
NodeUnderCursor::TsNode(node) => {
646+
assert_eq!(
647+
ctx.get_ts_node_content(node),
648+
Some(NodeText::Original("select".into()))
649+
);
560650

561-
assert_eq!(
562-
ctx.wrapping_clause_type,
563-
Some(crate::context::WrappingClause::Select)
564-
);
651+
assert_eq!(
652+
ctx.wrapping_clause_type,
653+
Some(crate::context::WrappingClause::Select)
654+
);
655+
}
656+
_ => unreachable!(),
657+
}
565658
}
566659
}
567660

@@ -582,12 +675,17 @@ mod tests {
582675

583676
let ctx = CompletionContext::new(&params);
584677

585-
let node = ctx.node_under_cursor.unwrap();
678+
let node = ctx.node_under_cursor.as_ref().unwrap();
586679

587-
assert_eq!(
588-
ctx.get_ts_node_content(node),
589-
Some(NodeText::Original("from"))
590-
);
680+
match node {
681+
NodeUnderCursor::TsNode(node) => {
682+
assert_eq!(
683+
ctx.get_ts_node_content(node),
684+
Some(NodeText::Original("from".into()))
685+
);
686+
}
687+
_ => unreachable!(),
688+
}
591689
}
592690

593691
#[test]
@@ -607,10 +705,18 @@ mod tests {
607705

608706
let ctx = CompletionContext::new(&params);
609707

610-
let node = ctx.node_under_cursor.unwrap();
708+
let node = ctx.node_under_cursor.as_ref().unwrap();
611709

612-
assert_eq!(ctx.get_ts_node_content(node), Some(NodeText::Original("")));
613-
assert_eq!(ctx.wrapping_clause_type, None);
710+
match node {
711+
NodeUnderCursor::TsNode(node) => {
712+
assert_eq!(
713+
ctx.get_ts_node_content(node),
714+
Some(NodeText::Original("".into()))
715+
);
716+
assert_eq!(ctx.wrapping_clause_type, None);
717+
}
718+
_ => unreachable!(),
719+
}
614720
}
615721

616722
#[test]
@@ -632,12 +738,17 @@ mod tests {
632738

633739
let ctx = CompletionContext::new(&params);
634740

635-
let node = ctx.node_under_cursor.unwrap();
741+
let node = ctx.node_under_cursor.as_ref().unwrap();
636742

637-
assert_eq!(
638-
ctx.get_ts_node_content(node),
639-
Some(NodeText::Original("fro"))
640-
);
641-
assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select));
743+
match node {
744+
NodeUnderCursor::TsNode(node) => {
745+
assert_eq!(
746+
ctx.get_ts_node_content(node),
747+
Some(NodeText::Original("fro".into()))
748+
);
749+
assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select));
750+
}
751+
_ => unreachable!(),
752+
}
642753
}
643754
}

0 commit comments

Comments
 (0)