Skip to content

Commit 27cca0e

Browse files
build context…
1 parent 7622409 commit 27cca0e

File tree

3 files changed

+89
-37
lines changed

3 files changed

+89
-37
lines changed

crates/pgt_completions/src/context/context.rs

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ use pgt_treesitter_queries::{
77
queries::{self, QueryResult},
88
};
99

10-
use crate::sanitization::SanitizedCompletionParams;
10+
use crate::{
11+
NodeText, context::policy_parser::PolicyParser, sanitization::SanitizedCompletionParams,
12+
};
1113

1214
#[derive(Debug, PartialEq, Eq)]
1315
pub enum WrappingClause<'a> {
@@ -19,12 +21,8 @@ pub enum WrappingClause<'a> {
1921
},
2022
Update,
2123
Delete,
22-
}
23-
24-
#[derive(PartialEq, Eq, Debug)]
25-
pub(crate) enum NodeText<'a> {
26-
Replaced,
27-
Original(&'a str),
24+
PolicyName,
25+
ToRole,
2826
}
2927

3028
/// We can map a few nodes, such as the "update" node, to actual SQL clauses.
@@ -45,7 +43,7 @@ pub enum WrappingNode {
4543
pub(crate) enum NodeUnderCursor<'a> {
4644
TsNode(tree_sitter::Node<'a>),
4745
CustomNode {
48-
text: NodeText<'a>,
46+
text: NodeText,
4947
range: TextRange,
5048
kind: String,
5149
},
@@ -172,14 +170,35 @@ impl<'a> CompletionContext<'a> {
172170
// policy handling is important to Supabase, but they are a PostgreSQL specific extension,
173171
// so the tree_sitter_sql language does not support it.
174172
// We infer the context manually.
175-
// if params.text.to_lowercase().starts_with("create policy")
176-
// || params.text.to_lowercase().starts_with("alter policy")
177-
// || params.text.to_lowercase().starts_with("drop policy")
178-
// {
179-
// } else {
180-
ctx.gather_tree_context();
181-
ctx.gather_info_from_ts_queries();
182-
// }
173+
if params.text.to_lowercase().starts_with("create policy")
174+
|| params.text.to_lowercase().starts_with("alter policy")
175+
|| params.text.to_lowercase().starts_with("drop policy")
176+
{
177+
let policy_context = PolicyParser::get_context(&ctx.text, ctx.position);
178+
179+
ctx.node_under_cursor = Some(NodeUnderCursor::CustomNode {
180+
text: policy_context.node_text.into(),
181+
range: policy_context.node_range,
182+
kind: policy_context.node_kind.clone(),
183+
});
184+
185+
if policy_context.table_name.is_some() {
186+
let mut new = HashSet::new();
187+
new.insert(policy_context.table_name.unwrap());
188+
ctx.mentioned_relations
189+
.insert(policy_context.schema_name, new);
190+
}
191+
192+
ctx.wrapping_clause_type = match policy_context.node_kind.as_str() {
193+
"policy_name" => Some(WrappingClause::PolicyName),
194+
"policy_role" => Some(WrappingClause::ToRole),
195+
"policy_table" => Some(WrappingClause::From),
196+
_ => None,
197+
};
198+
} else {
199+
ctx.gather_tree_context();
200+
ctx.gather_info_from_ts_queries();
201+
}
183202

184203
tracing::warn!("sql: {}", ctx.text);
185204
tracing::warn!("position: {}", ctx.position);
@@ -237,13 +256,13 @@ impl<'a> CompletionContext<'a> {
237256
}
238257
}
239258

240-
fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option<NodeText<'a>> {
259+
fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option<NodeText> {
241260
let source = self.text;
242261
ts_node.utf8_text(source.as_bytes()).ok().map(|txt| {
243262
if SanitizedCompletionParams::is_sanitized_token(txt) {
244263
NodeText::Replaced
245264
} else {
246-
NodeText::Original(txt)
265+
NodeText::Original(txt.into())
247266
}
248267
})
249268
}
@@ -386,7 +405,7 @@ impl<'a> CompletionContext<'a> {
386405
NodeText::Original(txt) => Some(txt),
387406
NodeText::Replaced => None,
388407
}) {
389-
match txt {
408+
match txt.as_str() {
390409
"where" => return Some(WrappingClause::Where),
391410
"update" => return Some(WrappingClause::Update),
392411
"select" => return Some(WrappingClause::Select),
@@ -436,7 +455,8 @@ impl<'a> CompletionContext<'a> {
436455
#[cfg(test)]
437456
mod tests {
438457
use crate::{
439-
context::{CompletionContext, NodeText, WrappingClause},
458+
NodeText,
459+
context::{CompletionContext, WrappingClause},
440460
sanitization::SanitizedCompletionParams,
441461
test_helper::{CURSOR_POS, get_text_and_position},
442462
};
@@ -607,7 +627,7 @@ mod tests {
607627
NodeUnderCursor::TsNode(node) => {
608628
assert_eq!(
609629
ctx.get_ts_node_content(node),
610-
Some(NodeText::Original("select"))
630+
Some(NodeText::Original("select".into()))
611631
);
612632

613633
assert_eq!(
@@ -643,7 +663,7 @@ mod tests {
643663
NodeUnderCursor::TsNode(node) => {
644664
assert_eq!(
645665
ctx.get_ts_node_content(&node),
646-
Some(NodeText::Original("from"))
666+
Some(NodeText::Original("from".into()))
647667
);
648668
}
649669
_ => unreachable!(),
@@ -671,7 +691,10 @@ mod tests {
671691

672692
match node {
673693
NodeUnderCursor::TsNode(node) => {
674-
assert_eq!(ctx.get_ts_node_content(&node), Some(NodeText::Original("")));
694+
assert_eq!(
695+
ctx.get_ts_node_content(&node),
696+
Some(NodeText::Original("".into()))
697+
);
675698
assert_eq!(ctx.wrapping_clause_type, None);
676699
}
677700
_ => unreachable!(),
@@ -703,7 +726,7 @@ mod tests {
703726
NodeUnderCursor::TsNode(node) => {
704727
assert_eq!(
705728
ctx.get_ts_node_content(&node),
706-
Some(NodeText::Original("fro"))
729+
Some(NodeText::Original("fro".into()))
707730
);
708731
assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select));
709732
}

crates/pgt_completions/src/context/policy_parser.rs

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
use std::{env::current_exe, iter::Peekable};
1+
use std::iter::Peekable;
22

33
use pgt_text_size::{TextRange, TextSize};
44

55
#[derive(Default, Debug, PartialEq, Eq)]
6-
pub enum PolicyStmtKind {
6+
pub(crate) enum PolicyStmtKind {
77
#[default]
88
Create,
99

@@ -90,17 +90,17 @@ fn sql_to_words(sql: &str) -> Result<Vec<WordWithIndex>, String> {
9090
}
9191

9292
#[derive(Default, Debug, PartialEq, Eq)]
93-
pub struct PolicyContext {
94-
policy_name: Option<String>,
95-
table_name: Option<String>,
96-
schema_name: Option<String>,
97-
statement_kind: PolicyStmtKind,
98-
node_text: String,
99-
node_range: TextRange,
100-
node_kind: String,
93+
pub(crate) struct PolicyContext {
94+
pub policy_name: Option<String>,
95+
pub table_name: Option<String>,
96+
pub schema_name: Option<String>,
97+
pub statement_kind: PolicyStmtKind,
98+
pub node_text: String,
99+
pub node_range: TextRange,
100+
pub node_kind: String,
101101
}
102102

103-
pub struct PolicyParser {
103+
pub(crate) struct PolicyParser {
104104
tokens: Peekable<std::vec::IntoIter<WordWithIndex>>,
105105
previous_token: Option<WordWithIndex>,
106106
current_token: Option<WordWithIndex>,
@@ -110,6 +110,13 @@ pub struct PolicyParser {
110110

111111
impl PolicyParser {
112112
pub(crate) fn get_context(sql: &str, cursor_position: usize) -> PolicyContext {
113+
assert!(
114+
sql.starts_with("create policy")
115+
|| sql.starts_with("drop policy")
116+
|| sql.starts_with("alter policy"),
117+
"PolicyParser should only be used for policy statements. Developer error!"
118+
);
119+
113120
match sql_to_words(sql) {
114121
Ok(tokens) => {
115122
let parser = PolicyParser {

crates/pgt_completions/src/sanitization.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use pgt_text_size::TextSize;
44

55
use crate::CompletionParams;
66

7+
static SANITIZED_TOKEN: &str = "REPLACED_TOKEN";
8+
79
pub(crate) struct SanitizedCompletionParams<'a> {
810
pub position: TextSize,
911
pub text: String,
@@ -16,6 +18,28 @@ pub fn benchmark_sanitization(params: CompletionParams) -> String {
1618
params.text
1719
}
1820

21+
#[derive(PartialEq, Eq, Debug)]
22+
pub(crate) enum NodeText {
23+
Replaced,
24+
Original(String),
25+
}
26+
27+
impl From<&str> for NodeText {
28+
fn from(value: &str) -> Self {
29+
if value == SANITIZED_TOKEN {
30+
NodeText::Replaced
31+
} else {
32+
NodeText::Original(value.into())
33+
}
34+
}
35+
}
36+
37+
impl From<String> for NodeText {
38+
fn from(value: String) -> Self {
39+
NodeText::from(value.as_str())
40+
}
41+
}
42+
1943
impl<'larger, 'smaller> From<CompletionParams<'larger>> for SanitizedCompletionParams<'smaller>
2044
where
2145
'larger: 'smaller,
@@ -33,8 +57,6 @@ where
3357
}
3458
}
3559

36-
static SANITIZED_TOKEN: &str = "REPLACED_TOKEN";
37-
3860
impl<'larger, 'smaller> SanitizedCompletionParams<'smaller>
3961
where
4062
'larger: 'smaller,

0 commit comments

Comments
 (0)