Skip to content

Commit 09084c1

Browse files
cool
1 parent ec3d575 commit 09084c1

File tree

7 files changed

+347
-19
lines changed

7 files changed

+347
-19
lines changed

crates/pgt_completions/src/context/mod.rs

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -236,44 +236,54 @@ impl<'a> CompletionContext<'a> {
236236
executor.add_query_results::<queries::RelationMatch>();
237237
executor.add_query_results::<queries::TableAliasMatch>();
238238
executor.add_query_results::<queries::SelectColumnMatch>();
239+
executor.add_query_results::<queries::InsertColumnMatch>();
239240

240241
for relation_match in executor.get_iter(stmt_range) {
241242
match relation_match {
242243
QueryResult::Relation(r) => {
243244
let schema_name = r.get_schema(sql);
244245
let table_name = r.get_table(sql);
245246

246-
if let Some(c) = self.mentioned_relations.get_mut(&schema_name) {
247-
c.insert(table_name);
248-
} else {
249-
let mut new = HashSet::new();
250-
new.insert(table_name);
251-
self.mentioned_relations.insert(schema_name, new);
252-
}
247+
self.mentioned_relations
248+
.entry(schema_name)
249+
.and_modify(|s| {
250+
s.insert(table_name.clone());
251+
})
252+
.or_insert(HashSet::from([table_name]));
253253
}
254254
QueryResult::TableAliases(table_alias_match) => {
255255
self.mentioned_table_aliases.insert(
256256
table_alias_match.get_alias(sql),
257257
table_alias_match.get_table(sql),
258258
);
259259
}
260+
260261
QueryResult::SelectClauseColumns(c) => {
261262
let mentioned = MentionedColumn {
262263
column: c.get_column(sql),
263264
alias: c.get_alias(sql),
264265
};
265266

266-
if let Some(cols) = self
267-
.mentioned_columns
268-
.get_mut(&Some(WrappingClause::Select))
269-
{
270-
cols.insert(mentioned);
271-
} else {
272-
let mut new = HashSet::new();
273-
new.insert(mentioned);
274-
self.mentioned_columns
275-
.insert(Some(WrappingClause::Select), new);
276-
}
267+
self.mentioned_columns
268+
.entry(Some(WrappingClause::Select))
269+
.and_modify(|s| {
270+
s.insert(mentioned.clone());
271+
})
272+
.or_insert(HashSet::from([mentioned]));
273+
}
274+
275+
QueryResult::InsertClauseColumns(c) => {
276+
let mentioned = MentionedColumn {
277+
column: c.get_column(sql),
278+
alias: None,
279+
};
280+
281+
self.mentioned_columns
282+
.entry(Some(WrappingClause::Insert))
283+
.and_modify(|s| {
284+
s.insert(mentioned.clone());
285+
})
286+
.or_insert(HashSet::from([mentioned]));
277287
}
278288
};
279289
}
@@ -628,6 +638,17 @@ impl<'a> CompletionContext<'a> {
628638
}
629639
}
630640

641+
pub(crate) fn parent_matches_one_of_kind(&self, kinds: &[&'static str]) -> bool {
642+
self.node_under_cursor
643+
.as_ref()
644+
.is_some_and(|under_cursor| match under_cursor {
645+
NodeUnderCursor::TsNode(node) => node
646+
.parent()
647+
.is_some_and(|parent| kinds.contains(&parent.kind())),
648+
649+
NodeUnderCursor::CustomNode { .. } => false,
650+
})
651+
}
631652
pub(crate) fn before_cursor_matches_kind(&self, kinds: &[&'static str]) -> bool {
632653
self.node_under_cursor.as_ref().is_some_and(|under_cursor| {
633654
match under_cursor {

crates/pgt_completions/src/providers/columns.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,21 @@ mod tests {
621621
)
622622
.await;
623623

624+
// works with completed statement
625+
assert_complete_results(
626+
format!(
627+
"insert into instruments (name, {}) values ('my_bass');",
628+
CURSOR_POS
629+
)
630+
.as_str(),
631+
vec![
632+
CompletionAssertion::Label("id".to_string()),
633+
CompletionAssertion::Label("z".to_string()),
634+
],
635+
setup,
636+
)
637+
.await;
638+
624639
// no completions in the values list!
625640
assert_no_complete_results(
626641
format!("insert into instruments (id, name) values ({})", CURSOR_POS).as_str(),

crates/pgt_completions/src/providers/tables.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,5 +402,31 @@ mod tests {
402402
setup,
403403
)
404404
.await;
405+
406+
assert_complete_results(
407+
format!("insert into auth.{}", CURSOR_POS).as_str(),
408+
vec![CompletionAssertion::LabelAndKind(
409+
"users".into(),
410+
CompletionItemKind::Table,
411+
)],
412+
setup,
413+
)
414+
.await;
415+
416+
// works with complete statement.
417+
assert_complete_results(
418+
format!(
419+
"insert into {} (name, email) values ('jules', 'a@b.com');",
420+
CURSOR_POS
421+
)
422+
.as_str(),
423+
vec![
424+
CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema),
425+
CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema),
426+
CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table),
427+
],
428+
setup,
429+
)
430+
.await;
405431
}
406432
}

crates/pgt_completions/src/relevance/filtering.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ impl CompletionFilter<'_> {
8282
ctx.wrapping_node_kind
8383
.as_ref()
8484
.is_none_or(|n| n != &WrappingNode::List)
85-
&& ctx.before_cursor_matches_kind(&["keyword_into"])
85+
&& (ctx.before_cursor_matches_kind(&["keyword_into"])
86+
|| (ctx.before_cursor_matches_kind(&["."])
87+
&& ctx.parent_matches_one_of_kind(&["object_reference"])))
8688
}
8789

8890
WrappingClause::DropTable | WrappingClause::AlterTable => ctx
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
use std::sync::LazyLock;
2+
3+
use crate::{Query, QueryResult};
4+
5+
use super::QueryTryFrom;
6+
7+
static TS_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
8+
static QUERY_STR: &str = r#"
9+
(insert
10+
(object_reference)
11+
(list
12+
"("?
13+
(column) @column
14+
","?
15+
")"?
16+
)
17+
)
18+
"#;
19+
tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query")
20+
});
21+
22+
#[derive(Debug)]
23+
pub struct InsertColumnMatch<'a> {
24+
pub(crate) column: tree_sitter::Node<'a>,
25+
}
26+
27+
impl InsertColumnMatch<'_> {
28+
pub fn get_column(&self, sql: &str) -> String {
29+
self.column
30+
.utf8_text(sql.as_bytes())
31+
.expect("Failed to get column from ColumnMatch")
32+
.to_string()
33+
}
34+
}
35+
36+
impl<'a> TryFrom<&'a QueryResult<'a>> for &'a InsertColumnMatch<'a> {
37+
type Error = String;
38+
39+
fn try_from(q: &'a QueryResult<'a>) -> Result<Self, Self::Error> {
40+
match q {
41+
QueryResult::InsertClauseColumns(c) => Ok(c),
42+
43+
#[allow(unreachable_patterns)]
44+
_ => Err("Invalid QueryResult type".into()),
45+
}
46+
}
47+
}
48+
49+
impl<'a> QueryTryFrom<'a> for InsertColumnMatch<'a> {
50+
type Ref = &'a InsertColumnMatch<'a>;
51+
}
52+
53+
impl<'a> Query<'a> for InsertColumnMatch<'a> {
54+
fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec<crate::QueryResult<'a>> {
55+
let mut cursor = tree_sitter::QueryCursor::new();
56+
57+
let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes());
58+
59+
let mut to_return = vec![];
60+
61+
for m in matches {
62+
if m.captures.len() == 1 {
63+
let capture = m.captures[0].node;
64+
to_return.push(QueryResult::InsertClauseColumns(InsertColumnMatch {
65+
column: capture,
66+
}));
67+
}
68+
}
69+
70+
to_return
71+
}
72+
}
73+
#[cfg(test)]
74+
mod tests {
75+
use super::InsertColumnMatch;
76+
use crate::TreeSitterQueriesExecutor;
77+
78+
#[test]
79+
fn finds_all_insert_columns() {
80+
let sql = r#"insert into users (id, email, name) values (1, 'a@b.com', 'Alice');"#;
81+
82+
let mut parser = tree_sitter::Parser::new();
83+
parser.set_language(tree_sitter_sql::language()).unwrap();
84+
85+
let tree = parser.parse(sql, None).unwrap();
86+
87+
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
88+
89+
executor.add_query_results::<InsertColumnMatch>();
90+
91+
let results: Vec<&InsertColumnMatch> = executor
92+
.get_iter(None)
93+
.filter_map(|q| q.try_into().ok())
94+
.collect();
95+
96+
let columns: Vec<String> = results.iter().map(|c| c.get_column(sql)).collect();
97+
98+
assert_eq!(columns, vec!["id", "email", "name"]);
99+
}
100+
101+
#[test]
102+
fn finds_insert_columns_with_whitespace_and_commas() {
103+
let sql = r#"
104+
insert into users (
105+
id,
106+
email,
107+
name
108+
) values (1, 'a@b.com', 'Alice');
109+
"#;
110+
111+
let mut parser = tree_sitter::Parser::new();
112+
parser.set_language(tree_sitter_sql::language()).unwrap();
113+
114+
let tree = parser.parse(sql, None).unwrap();
115+
116+
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
117+
118+
executor.add_query_results::<InsertColumnMatch>();
119+
120+
let results: Vec<&InsertColumnMatch> = executor
121+
.get_iter(None)
122+
.filter_map(|q| q.try_into().ok())
123+
.collect();
124+
125+
let columns: Vec<String> = results.iter().map(|c| c.get_column(sql)).collect();
126+
127+
assert_eq!(columns, vec!["id", "email", "name"]);
128+
}
129+
130+
#[test]
131+
fn returns_empty_for_insert_without_columns() {
132+
let sql = r#"insert into users values (1, 'a@b.com', 'Alice');"#;
133+
134+
let mut parser = tree_sitter::Parser::new();
135+
parser.set_language(tree_sitter_sql::language()).unwrap();
136+
137+
let tree = parser.parse(sql, None).unwrap();
138+
139+
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
140+
141+
executor.add_query_results::<InsertColumnMatch>();
142+
143+
let results: Vec<&InsertColumnMatch> = executor
144+
.get_iter(None)
145+
.filter_map(|q| q.try_into().ok())
146+
.collect();
147+
148+
assert!(results.is_empty());
149+
}
150+
}

crates/pgt_treesitter_queries/src/queries/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
mod insert_columns;
12
mod relations;
23
mod select_columns;
34
mod table_aliases;
45

6+
pub use insert_columns::*;
57
pub use relations::*;
68
pub use select_columns::*;
79
pub use table_aliases::*;
@@ -11,6 +13,7 @@ pub enum QueryResult<'a> {
1113
Relation(RelationMatch<'a>),
1214
TableAliases(TableAliasMatch<'a>),
1315
SelectClauseColumns(SelectColumnMatch<'a>),
16+
InsertClauseColumns(InsertColumnMatch<'a>),
1417
}
1518

1619
impl QueryResult<'_> {
@@ -41,6 +44,11 @@ impl QueryResult<'_> {
4144

4245
start >= range.start_point && end <= range.end_point
4346
}
47+
Self::InsertClauseColumns(cm) => {
48+
let start = cm.column.start_position();
49+
let end = cm.column.end_position();
50+
start >= range.start_point && end <= range.end_point
51+
}
4452
}
4553
}
4654
}

0 commit comments

Comments
 (0)