Skip to content

Commit d9f1c8d

Browse files
not sure about this
1 parent bd034e7 commit d9f1c8d

File tree

4 files changed

+173
-50
lines changed

4 files changed

+173
-50
lines changed

crates/pgt_completions/src/context.rs

Lines changed: 103 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use std::collections::{HashMap, HashSet};
1+
use std::{
2+
cmp,
3+
collections::{HashMap, HashSet},
4+
};
25

36
use pgt_schema_cache::SchemaCache;
47
use pgt_treesitter_queries::{
@@ -8,7 +11,7 @@ use pgt_treesitter_queries::{
811

912
use crate::sanitization::SanitizedCompletionParams;
1013

11-
#[derive(Debug, PartialEq, Eq, Hash)]
14+
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
1215
pub enum WrappingClause<'a> {
1316
Select,
1417
Where,
@@ -19,6 +22,7 @@ pub enum WrappingClause<'a> {
1922
Update,
2023
Delete,
2124
ColumnDefinitions,
25+
Insert,
2226
}
2327

2428
#[derive(PartialEq, Eq, Debug)]
@@ -46,6 +50,7 @@ pub enum WrappingNode {
4650
Relation,
4751
BinaryExpression,
4852
Assignment,
53+
List,
4954
}
5055

5156
impl TryFrom<&str> for WrappingNode {
@@ -56,6 +61,7 @@ impl TryFrom<&str> for WrappingNode {
5661
"relation" => Ok(Self::Relation),
5762
"assignment" => Ok(Self::Assignment),
5863
"binary_expression" => Ok(Self::BinaryExpression),
64+
"list" => Ok(Self::List),
5965
_ => {
6066
let message = format!("Unimplemented Relation: {}", value);
6167

@@ -142,13 +148,6 @@ impl<'a> CompletionContext<'a> {
142148
ctx.gather_tree_context();
143149
ctx.gather_info_from_ts_queries();
144150

145-
if cfg!(test) {
146-
println!("{:#?}", ctx.wrapping_clause_type);
147-
println!("{:#?}", ctx.wrapping_node_kind);
148-
println!("{:#?}", ctx.is_in_error_node);
149-
println!("{:#?}", ctx.text);
150-
}
151-
152151
ctx
153152
}
154153

@@ -240,10 +239,20 @@ impl<'a> CompletionContext<'a> {
240239
* `select * from use {}` becomes `select * from use{}`.
241240
*/
242241
let current_node = cursor.node();
243-
while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 {
244-
self.position -= 1;
242+
243+
let mut chars = self.text.chars();
244+
245+
if chars
246+
.nth(self.position)
247+
.is_some_and(|c| !c.is_ascii_whitespace())
248+
{
249+
self.position = cmp::min(self.position + 1, self.text.len());
250+
} else {
251+
self.position = cmp::min(self.position, self.text.len());
245252
}
246253

254+
cursor.goto_first_child_for_byte(self.position);
255+
247256
self.gather_context_from_node(cursor, current_node);
248257
}
249258

@@ -276,23 +285,11 @@ impl<'a> CompletionContext<'a> {
276285

277286
// try to gather context from the siblings if we're within an error node.
278287
if self.is_in_error_node {
279-
let mut next_sibling = current_node.next_named_sibling();
280-
while let Some(n) = next_sibling {
281-
if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) {
282-
self.wrapping_clause_type = Some(clause_type);
283-
break;
284-
} else {
285-
next_sibling = n.next_named_sibling();
286-
}
288+
if let Some(clause_type) = self.get_wrapping_clause_from_siblings(current_node) {
289+
self.wrapping_clause_type = Some(clause_type);
287290
}
288-
let mut prev_sibling = current_node.prev_named_sibling();
289-
while let Some(n) = prev_sibling {
290-
if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) {
291-
self.wrapping_clause_type = Some(clause_type);
292-
break;
293-
} else {
294-
prev_sibling = n.prev_named_sibling();
295-
}
291+
if let Some(wrapping_node) = self.get_wrapping_node_from_siblings(current_node) {
292+
self.wrapping_node_kind = Some(wrapping_node)
296293
}
297294
}
298295

@@ -317,7 +314,7 @@ impl<'a> CompletionContext<'a> {
317314
self.get_wrapping_clause_from_current_node(current_node, &mut cursor);
318315
}
319316

320-
"relation" | "binary_expression" | "assignment" => {
317+
"relation" | "binary_expression" | "assignment" | "list" => {
321318
self.wrapping_node_kind = current_node_kind.try_into().ok();
322319
}
323320

@@ -338,31 +335,89 @@ impl<'a> CompletionContext<'a> {
338335
self.gather_context_from_node(cursor, current_node);
339336
}
340337

341-
fn get_wrapping_clause_from_keyword_node(
338+
fn get_first_sibling(&self, node: tree_sitter::Node<'a>) -> tree_sitter::Node<'a> {
339+
let mut first_sibling = node;
340+
while let Some(n) = first_sibling.prev_sibling() {
341+
first_sibling = n;
342+
}
343+
first_sibling
344+
}
345+
346+
fn get_wrapping_node_from_siblings(&self, node: tree_sitter::Node<'a>) -> Option<WrappingNode> {
347+
self.wrapping_clause_type
348+
.as_ref()
349+
.and_then(|clause| match clause {
350+
WrappingClause::Insert => {
351+
if node.prev_sibling().is_some_and(|n| n.kind() == "(")
352+
|| node.next_sibling().is_some_and(|n| n.kind() == ")")
353+
{
354+
Some(WrappingNode::List)
355+
} else {
356+
None
357+
}
358+
}
359+
_ => None,
360+
})
361+
}
362+
363+
fn get_wrapping_clause_from_siblings(
342364
&self,
343365
node: tree_sitter::Node<'a>,
344366
) -> Option<WrappingClause<'a>> {
345-
if node.kind().starts_with("keyword_") {
346-
if let Some(txt) = self.get_ts_node_content(node).and_then(|txt| match txt {
347-
NodeText::Original(txt) => Some(txt),
348-
NodeText::Replaced => None,
349-
}) {
350-
match txt {
351-
"where" => return Some(WrappingClause::Where),
352-
"update" => return Some(WrappingClause::Update),
353-
"select" => return Some(WrappingClause::Select),
354-
"delete" => return Some(WrappingClause::Delete),
355-
"from" => return Some(WrappingClause::From),
356-
"join" => {
357-
// TODO: not sure if we can infer it here.
358-
return Some(WrappingClause::Join { on_node: None });
367+
let clause_combinations: Vec<(WrappingClause, &[&'static str])> = vec![
368+
(WrappingClause::Where, &["where"]),
369+
(WrappingClause::Update, &["update"]),
370+
(WrappingClause::Select, &["select"]),
371+
(WrappingClause::Delete, &["delete"]),
372+
(WrappingClause::Insert, &["insert", "into"]),
373+
(WrappingClause::From, &["from"]),
374+
(WrappingClause::Join { on_node: None }, &["join"]),
375+
];
376+
377+
let first_sibling = self.get_first_sibling(node);
378+
379+
/*
380+
* For each clause, we'll iterate from first_sibling to the next ones,
381+
* either until the end or until we land on the node under the cursor.
382+
* We'll score the `WrappingClause` by how many tokens it matches in order.
383+
*/
384+
let mut clauses_with_score: Vec<(WrappingClause, usize)> = clause_combinations
385+
.into_iter()
386+
.map(|(clause, tokens)| {
387+
let mut idx = 0;
388+
389+
let mut sibling = Some(first_sibling);
390+
while let Some(sib) = sibling {
391+
if sib.end_byte() >= node.end_byte() || idx >= tokens.len() {
392+
break;
359393
}
360-
_ => {}
394+
395+
if let Some(sibling_content) =
396+
self.get_ts_node_content(sib).and_then(|txt| match txt {
397+
NodeText::Original(txt) => Some(txt),
398+
NodeText::Replaced => None,
399+
})
400+
{
401+
if sibling_content == tokens[idx] {
402+
idx += 1;
403+
}
404+
} else {
405+
break;
406+
}
407+
408+
sibling = sib.next_sibling();
361409
}
362-
};
363-
}
364410

365-
None
411+
(clause, idx)
412+
})
413+
.collect();
414+
415+
clauses_with_score.sort_by(|(_, score_a), (_, score_b)| score_b.cmp(score_a));
416+
clauses_with_score
417+
.iter()
418+
.filter(|(_, score)| *score > 0)
419+
.next()
420+
.map(|c| c.0.clone())
366421
}
367422

368423
fn get_wrapping_clause_from_current_node(
@@ -377,6 +432,7 @@ impl<'a> CompletionContext<'a> {
377432
"delete" => Some(WrappingClause::Delete),
378433
"from" => Some(WrappingClause::From),
379434
"column_definitions" => Some(WrappingClause::ColumnDefinitions),
435+
"insert" => Some(WrappingClause::Insert),
380436
"join" => {
381437
// sadly, we need to manually iterate over the children –
382438
// `node.child_by_field_id(..)` does not work as expected

crates/pgt_completions/src/providers/columns.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,4 +573,24 @@ mod tests {
573573
)
574574
.await;
575575
}
576+
577+
#[tokio::test]
578+
async fn suggests_columns_in_insert_clause() {
579+
let setup = r#"
580+
create table instruments (
581+
id bigint primary key generated always as identity,
582+
name text not null
583+
);
584+
"#;
585+
586+
assert_complete_results(
587+
format!("insert into instruments ({})", CURSOR_POS).as_str(),
588+
vec![
589+
CompletionAssertion::Label("id".to_string()),
590+
CompletionAssertion::Label("name".to_string()),
591+
],
592+
setup,
593+
)
594+
.await;
595+
}
576596
}

crates/pgt_completions/src/relevance/filtering.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::context::{CompletionContext, WrappingClause};
1+
use crate::context::{CompletionContext, WrappingClause, WrappingNode};
22

33
use super::CompletionRelevanceData;
44

@@ -73,6 +73,17 @@ impl CompletionFilter<'_> {
7373
WrappingClause::Select
7474
| WrappingClause::Where
7575
| WrappingClause::ColumnDefinitions => false,
76+
77+
WrappingClause::Insert => {
78+
ctx.wrapping_node_kind
79+
.as_ref()
80+
.is_some_and(|n| n != &WrappingNode::List)
81+
&& ctx.node_under_cursor.is_some_and(|n| {
82+
n.prev_sibling()
83+
.is_some_and(|sib| sib.kind() == "keyword_into")
84+
})
85+
}
86+
7687
_ => true,
7788
},
7889
CompletionRelevanceData::Column(_) => {
@@ -88,6 +99,11 @@ impl CompletionFilter<'_> {
8899
// we are in a JOIN, but definitely not after an ON
89100
WrappingClause::Join { on_node: None } => false,
90101

102+
WrappingClause::Insert => ctx
103+
.wrapping_node_kind
104+
.as_ref()
105+
.is_some_and(|n| n == &WrappingNode::List),
106+
91107
_ => true,
92108
}
93109
}
@@ -107,6 +123,16 @@ impl CompletionFilter<'_> {
107123
| WrappingClause::Update
108124
| WrappingClause::Delete => true,
109125

126+
WrappingClause::Insert => {
127+
ctx.wrapping_node_kind
128+
.as_ref()
129+
.is_some_and(|n| n != &WrappingNode::List)
130+
&& ctx.node_under_cursor.is_some_and(|n| {
131+
n.prev_sibling()
132+
.is_some_and(|sib| sib.kind() == "keyword_into")
133+
})
134+
}
135+
110136
WrappingClause::ColumnDefinitions => false,
111137
},
112138
}

crates/pgt_completions/src/sanitization.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ where
2525
|| cursor_prepared_to_write_token_after_last_node(params.tree, params.position)
2626
|| cursor_before_semicolon(params.tree, params.position)
2727
|| cursor_on_a_dot(&params.text, params.position)
28+
|| cursor_between_parentheses(&params.text, params.position)
2829
{
2930
SanitizedCompletionParams::with_adjusted_sql(params)
3031
} else {
@@ -200,13 +201,19 @@ fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool
200201
.unwrap_or(false)
201202
}
202203

204+
fn cursor_between_parentheses(sql: &str, position: TextSize) -> bool {
205+
let position: usize = position.into();
206+
let mut chars = sql.chars();
207+
chars.nth(position - 1).is_some_and(|c| c == '(') && chars.next().is_some_and(|c| c == ')')
208+
}
209+
203210
#[cfg(test)]
204211
mod tests {
205212
use pgt_text_size::TextSize;
206213

207214
use crate::sanitization::{
208-
cursor_before_semicolon, cursor_inbetween_nodes, cursor_on_a_dot,
209-
cursor_prepared_to_write_token_after_last_node,
215+
cursor_before_semicolon, cursor_between_parentheses, cursor_inbetween_nodes,
216+
cursor_on_a_dot, cursor_prepared_to_write_token_after_last_node,
210217
};
211218

212219
#[test]
@@ -317,4 +324,18 @@ mod tests {
317324
assert!(cursor_before_semicolon(&tree, TextSize::new(16)));
318325
assert!(cursor_before_semicolon(&tree, TextSize::new(17)));
319326
}
327+
328+
#[test]
329+
fn between_parentheses() {
330+
let input = "insert into instruments ()";
331+
332+
// insert into (|) <- right in the parentheses
333+
assert!(cursor_between_parentheses(input, TextSize::new(25)));
334+
335+
// insert into ()| <- too late
336+
assert!(!cursor_between_parentheses(input, TextSize::new(26)));
337+
338+
// insert into |() <- too early
339+
assert!(!cursor_between_parentheses(input, TextSize::new(24)));
340+
}
320341
}

0 commit comments

Comments
 (0)