Skip to content

Commit 121431d

Browse files
ok
1 parent 27cca0e commit 121431d

File tree

6 files changed

+122
-37
lines changed

6 files changed

+122
-37
lines changed

crates/pgt_completions/src/context/context.rs

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ use pgt_treesitter_queries::{
88
};
99

1010
use crate::{
11-
NodeText, context::policy_parser::PolicyParser, sanitization::SanitizedCompletionParams,
11+
NodeText,
12+
context::policy_parser::{PolicyParser, PolicyStmtKind},
13+
sanitization::SanitizedCompletionParams,
1214
};
1315

1416
#[derive(Debug, PartialEq, Eq)]
@@ -40,6 +42,7 @@ pub enum WrappingNode {
4042
Assignment,
4143
}
4244

45+
#[derive(Debug)]
4346
pub(crate) enum NodeUnderCursor<'a> {
4447
TsNode(tree_sitter::Node<'a>),
4548
CustomNode {
@@ -64,6 +67,13 @@ impl<'a> NodeUnderCursor<'a> {
6467
}
6568
}
6669

70+
pub fn range(&self) -> TextRange {
71+
let start: u32 = self.start_byte().try_into().unwrap();
72+
let end: u32 = self.end_byte().try_into().unwrap();
73+
74+
TextRange::new(start.into(), end.into())
75+
}
76+
6777
pub fn kind(&self) -> &str {
6878
match self {
6979
NodeUnderCursor::TsNode(node) => node.kind(),
@@ -182,6 +192,10 @@ impl<'a> CompletionContext<'a> {
182192
kind: policy_context.node_kind.clone(),
183193
});
184194

195+
if policy_context.node_kind == "policy_table" {
196+
ctx.schema_or_alias_name = policy_context.schema_name.clone();
197+
}
198+
185199
if policy_context.table_name.is_some() {
186200
let mut new = HashSet::new();
187201
new.insert(policy_context.table_name.unwrap());
@@ -190,7 +204,9 @@ impl<'a> CompletionContext<'a> {
190204
}
191205

192206
ctx.wrapping_clause_type = match policy_context.node_kind.as_str() {
193-
"policy_name" => Some(WrappingClause::PolicyName),
207+
"policy_name" if policy_context.statement_kind != PolicyStmtKind::Create => {
208+
Some(WrappingClause::PolicyName)
209+
}
194210
"policy_role" => Some(WrappingClause::ToRole),
195211
"policy_table" => Some(WrappingClause::From),
196212
_ => None,
@@ -200,19 +216,11 @@ impl<'a> CompletionContext<'a> {
200216
ctx.gather_info_from_ts_queries();
201217
}
202218

203-
tracing::warn!("sql: {}", ctx.text);
204-
tracing::warn!("position: {}", ctx.position);
205-
tracing::warn!(
206-
"node range: {} - {}",
207-
ctx.node_under_cursor
208-
.as_ref()
209-
.map(|n| n.start_byte())
210-
.unwrap_or(0),
211-
ctx.node_under_cursor
212-
.as_ref()
213-
.map(|n| n.end_byte())
214-
.unwrap_or(0)
215-
);
219+
tracing::warn!("SQL: {}", ctx.text);
220+
tracing::warn!("Position: {}", ctx.position);
221+
tracing::warn!("Node: {:#?}", ctx.node_under_cursor);
222+
tracing::warn!("Relations: {:#?}", ctx.mentioned_relations);
223+
tracing::warn!("Clause: {:#?}", ctx.wrapping_clause_type);
216224

217225
ctx
218226
}

crates/pgt_completions/src/providers/helper.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,25 @@ pub(crate) fn find_matching_alias_for_table(
1414
None
1515
}
1616

17+
pub(crate) fn get_range_to_replace(ctx: &CompletionContext) -> TextRange {
18+
let start = ctx
19+
.node_under_cursor
20+
.as_ref()
21+
.map(|n| n.start_byte())
22+
.unwrap_or(0);
23+
24+
let end = ctx
25+
.get_node_under_cursor_content()
26+
.unwrap_or("".into())
27+
.len()
28+
+ start;
29+
30+
TextRange::new(
31+
TextSize::new(start.try_into().unwrap()),
32+
end.try_into().unwrap(),
33+
)
34+
}
35+
1736
pub(crate) fn get_completion_text_with_schema_or_alias(
1837
ctx: &CompletionContext,
1938
item_name: &str,
@@ -22,12 +41,7 @@ pub(crate) fn get_completion_text_with_schema_or_alias(
2241
if schema_or_alias_name == "public" || ctx.schema_or_alias_name.is_some() {
2342
None
2443
} else {
25-
let node = ctx.node_under_cursor.as_ref().unwrap();
26-
27-
let range = TextRange::new(
28-
TextSize::try_from(node.start_byte()).unwrap(),
29-
TextSize::try_from(node.end_byte()).unwrap(),
30-
);
44+
let range = get_range_to_replace(ctx);
3145

3246
Some(CompletionText {
3347
text: format!("{}.{}", schema_or_alias_name, item_name),
Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,60 @@
11
use crate::{
2-
CompletionItemKind,
2+
CompletionItemKind, CompletionText,
33
builder::{CompletionBuilder, PossibleCompletionItem},
44
context::CompletionContext,
55
relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore},
66
};
77

8+
use super::helper::get_range_to_replace;
9+
810
pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) {
911
let available_policies = &ctx.schema_cache.policies;
1012

1113
for pol in available_policies {
1214
let relevance = CompletionRelevanceData::Policy(pol);
1315

1416
let item = PossibleCompletionItem {
15-
label: pol.name.clone(),
17+
label: pol.name.chars().take(35).collect::<String>(),
1618
score: CompletionScore::from(relevance.clone()),
1719
filter: CompletionFilter::from(relevance),
18-
description: format!("Table: {}", pol.table_name),
20+
description: format!("{}", pol.table_name),
1921
kind: CompletionItemKind::Policy,
20-
completion_text: None,
22+
completion_text: Some(CompletionText {
23+
text: format!("\"{}\"", pol.name),
24+
range: get_range_to_replace(ctx),
25+
}),
2126
};
2227

2328
builder.add_item(item);
2429
}
2530
}
31+
32+
#[cfg(test)]
33+
mod tests {
34+
use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results};
35+
36+
#[tokio::test]
37+
async fn completes_within_quotation_marks() {
38+
let setup = r#"
39+
create table users (
40+
id serial primary key,
41+
email text
42+
);
43+
44+
create policy "should never have access" on users
45+
as restrictive
46+
for all
47+
to public
48+
using (false);
49+
"#;
50+
51+
assert_complete_results(
52+
format!("alter policy \"{}\" on users;", CURSOR_POS).as_str(),
53+
vec![CompletionAssertion::Label(
54+
"should never have access".into(),
55+
)],
56+
setup,
57+
)
58+
.await;
59+
}
60+
}

crates/pgt_completions/src/relevance/filtering.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,19 @@ impl CompletionFilter<'_> {
6767
fn check_clause(&self, ctx: &CompletionContext) -> Option<()> {
6868
let clause = ctx.wrapping_clause_type.as_ref();
6969

70+
let in_clause = |compare: WrappingClause| clause.is_some_and(|c| c == &compare);
71+
7072
match self.data {
7173
CompletionRelevanceData::Table(_) => {
72-
let in_select_clause = clause.is_some_and(|c| c == &WrappingClause::Select);
73-
let in_where_clause = clause.is_some_and(|c| c == &WrappingClause::Where);
74-
75-
if in_select_clause || in_where_clause {
74+
if in_clause(WrappingClause::Select)
75+
|| in_clause(WrappingClause::Where)
76+
|| in_clause(WrappingClause::PolicyName)
77+
{
7678
return None;
7779
};
7880
}
7981
CompletionRelevanceData::Column(_) => {
80-
let in_from_clause = clause.is_some_and(|c| c == &WrappingClause::From);
81-
if in_from_clause {
82+
if in_clause(WrappingClause::From) || in_clause(WrappingClause::PolicyName) {
8283
return None;
8384
}
8485

@@ -100,7 +101,16 @@ impl CompletionFilter<'_> {
100101
return None;
101102
}
102103
}
103-
_ => {}
104+
CompletionRelevanceData::Policy(_) => {
105+
if clause.is_none_or(|c| c != &WrappingClause::PolicyName) {
106+
return None;
107+
}
108+
}
109+
_ => {
110+
if in_clause(WrappingClause::PolicyName) {
111+
return None;
112+
}
113+
}
104114
}
105115

106116
Some(())
@@ -140,7 +150,7 @@ impl CompletionFilter<'_> {
140150
}
141151

142152
// no aliases and schemas for policies
143-
CompletionRelevanceData::Policy(_) => false,
153+
CompletionRelevanceData::Policy(p) => p.schema_name == p.schema_name,
144154
};
145155

146156
if !matches {

crates/pgt_completions/src/relevance/scoring.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,10 @@ impl CompletionScore<'_> {
119119
WrappingClause::Delete if !has_mentioned_schema => 15,
120120
_ => -50,
121121
},
122-
CompletionRelevanceData::Policy(_) => 0,
122+
CompletionRelevanceData::Policy(_) => match clause_type {
123+
WrappingClause::PolicyName => 25,
124+
_ => -50,
125+
},
123126
}
124127
}
125128

@@ -187,7 +190,7 @@ impl CompletionScore<'_> {
187190
CompletionRelevanceData::Table(t) => t.schema.as_str(),
188191
CompletionRelevanceData::Column(c) => c.schema_name.as_str(),
189192
CompletionRelevanceData::Schema(s) => s.name.as_str(),
190-
CompletionRelevanceData::Policy(p) => p.name.as_str(),
193+
CompletionRelevanceData::Policy(p) => p.schema_name.as_str(),
191194
}
192195
}
193196

crates/pgt_completions/src/sanitization.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ where
4949
|| cursor_prepared_to_write_token_after_last_node(params.tree, params.position)
5050
|| cursor_before_semicolon(params.tree, params.position)
5151
|| cursor_on_a_dot(&params.text, params.position)
52+
|| cursor_between_double_quotes(&params.text, params.position)
5253
{
5354
SanitizedCompletionParams::with_adjusted_sql(params)
5455
} else {
@@ -178,6 +179,12 @@ fn cursor_on_a_dot(sql: &str, position: TextSize) -> bool {
178179
sql.chars().nth(position - 1).is_some_and(|c| c == '.')
179180
}
180181

182+
fn cursor_between_double_quotes(sql: &str, position: TextSize) -> bool {
183+
let position: usize = position.into();
184+
let mut chars = sql.chars();
185+
chars.nth(position - 1).is_some_and(|c| c == '"') && chars.next().is_some_and(|c| c == '"')
186+
}
187+
181188
fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool {
182189
let mut cursor = tree.walk();
183190
let mut leaf_node = tree.root_node();
@@ -227,8 +234,8 @@ mod tests {
227234
use pgt_text_size::TextSize;
228235

229236
use crate::sanitization::{
230-
cursor_before_semicolon, cursor_inbetween_nodes, cursor_on_a_dot,
231-
cursor_prepared_to_write_token_after_last_node,
237+
cursor_before_semicolon, cursor_between_double_quotes, cursor_inbetween_nodes,
238+
cursor_on_a_dot, cursor_prepared_to_write_token_after_last_node,
232239
};
233240

234241
#[test]
@@ -339,4 +346,12 @@ mod tests {
339346
assert!(cursor_before_semicolon(&tree, TextSize::new(16)));
340347
assert!(cursor_before_semicolon(&tree, TextSize::new(17)));
341348
}
349+
350+
#[test]
351+
fn between_quotations() {
352+
let input = "select * from \"\"";
353+
354+
// select * from "|" <-- between quotations
355+
assert!(cursor_between_double_quotes(input, TextSize::new(15)));
356+
}
342357
}

0 commit comments

Comments
 (0)