diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 23a6fcae..7ae5ab27 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -1,7 +1,9 @@ +use std::{ + cmp, + collections::{HashMap, HashSet}, +}; mod policy_parser; -use std::collections::{HashMap, HashSet}; - use pgt_schema_cache::SchemaCache; use pgt_text_size::TextRange; use pgt_treesitter_queries::{ @@ -15,7 +17,7 @@ use crate::{ sanitization::SanitizedCompletionParams, }; -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] pub enum WrappingClause<'a> { Select, Where, @@ -25,6 +27,10 @@ pub enum WrappingClause<'a> { }, Update, Delete, + ColumnDefinitions, + Insert, + AlterTable, + DropTable, PolicyName, ToRoleAssignment, } @@ -48,6 +54,7 @@ pub enum WrappingNode { Relation, BinaryExpression, Assignment, + List, } #[derive(Debug)] @@ -97,6 +104,7 @@ impl TryFrom<&str> for WrappingNode { "relation" => Ok(Self::Relation), "assignment" => Ok(Self::Assignment), "binary_expression" => Ok(Self::BinaryExpression), + "list" => Ok(Self::List), _ => { let message = format!("Unimplemented Relation: {}", value); @@ -118,6 +126,7 @@ impl TryFrom for WrappingNode { } } +#[derive(Debug)] pub(crate) struct CompletionContext<'a> { pub node_under_cursor: Option>, @@ -152,9 +161,6 @@ pub(crate) struct CompletionContext<'a> { pub is_invocation: bool, pub wrapping_statement_range: Option, - /// Some incomplete statements can't be correctly parsed by TreeSitter. - pub is_in_error_node: bool, - pub mentioned_relations: HashMap, HashSet>, pub mentioned_table_aliases: HashMap, pub mentioned_columns: HashMap>, HashSet>, @@ -176,7 +182,6 @@ impl<'a> CompletionContext<'a> { mentioned_relations: HashMap::new(), mentioned_table_aliases: HashMap::new(), mentioned_columns: HashMap::new(), - is_in_error_node: false, }; // policy handling is important to Supabase, but they are a PostgreSQL specific extension, @@ -189,6 +194,14 @@ impl<'a> CompletionContext<'a> { ctx.gather_info_from_ts_queries(); } + // if cfg!(test) { + // println!("{:?}", ctx.position); + // println!("{:?}", ctx.text); + // println!("{:?}", ctx.wrapping_clause_type); + // println!("{:?}", ctx.wrapping_node_kind); + // println!("{:?}", ctx.before_cursor_matches_kind(&["keyword_table"])); + // } + ctx } @@ -316,10 +329,20 @@ impl<'a> CompletionContext<'a> { * `select * from use {}` becomes `select * from use{}`. */ let current_node = cursor.node(); - while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 { - self.position -= 1; + + let mut chars = self.text.chars(); + + if chars + .nth(self.position) + .is_some_and(|c| !c.is_ascii_whitespace() && c != ';') + { + self.position = cmp::min(self.position + 1, self.text.len()); + } else { + self.position = cmp::min(self.position, self.text.len()); } + cursor.goto_first_child_for_byte(self.position); + self.gather_context_from_node(cursor, current_node); } @@ -351,24 +374,12 @@ impl<'a> CompletionContext<'a> { } // try to gather context from the siblings if we're within an error node. - if self.is_in_error_node { - let mut next_sibling = current_node.next_named_sibling(); - while let Some(n) = next_sibling { - if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { - self.wrapping_clause_type = Some(clause_type); - break; - } else { - next_sibling = n.next_named_sibling(); - } + if parent_node_kind == "ERROR" { + if let Some(clause_type) = self.get_wrapping_clause_from_siblings(current_node) { + self.wrapping_clause_type = Some(clause_type); } - let mut prev_sibling = current_node.prev_named_sibling(); - while let Some(n) = prev_sibling { - if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { - self.wrapping_clause_type = Some(clause_type); - break; - } else { - prev_sibling = n.prev_named_sibling(); - } + if let Some(wrapping_node) = self.get_wrapping_node_from_siblings(current_node) { + self.wrapping_node_kind = Some(wrapping_node) } } @@ -388,19 +399,16 @@ impl<'a> CompletionContext<'a> { } } - "where" | "update" | "select" | "delete" | "from" | "join" => { + "where" | "update" | "select" | "delete" | "from" | "join" | "column_definitions" + | "drop_table" | "alter_table" => { self.wrapping_clause_type = self.get_wrapping_clause_from_current_node(current_node, &mut cursor); } - "relation" | "binary_expression" | "assignment" => { + "relation" | "binary_expression" | "assignment" | "list" => { self.wrapping_node_kind = current_node_kind.try_into().ok(); } - "ERROR" => { - self.is_in_error_node = true; - } - _ => {} } @@ -414,31 +422,99 @@ impl<'a> CompletionContext<'a> { self.gather_context_from_node(cursor, current_node); } - fn get_wrapping_clause_from_keyword_node( + fn get_first_sibling(&self, node: tree_sitter::Node<'a>) -> tree_sitter::Node<'a> { + let mut first_sibling = node; + while let Some(n) = first_sibling.prev_sibling() { + first_sibling = n; + } + first_sibling + } + + fn get_wrapping_node_from_siblings(&self, node: tree_sitter::Node<'a>) -> Option { + self.wrapping_clause_type + .as_ref() + .and_then(|clause| match clause { + WrappingClause::Insert => { + if node.prev_sibling().is_some_and(|n| n.kind() == "(") + || node.next_sibling().is_some_and(|n| n.kind() == ")") + { + Some(WrappingNode::List) + } else { + None + } + } + _ => None, + }) + } + + fn get_wrapping_clause_from_siblings( &self, node: tree_sitter::Node<'a>, ) -> Option> { - if node.kind().starts_with("keyword_") { - if let Some(txt) = self.get_ts_node_content(&node).and_then(|txt| match txt { - NodeText::Original(txt) => Some(txt), - NodeText::Replaced => None, - }) { - match txt.as_str() { - "where" => return Some(WrappingClause::Where), - "update" => return Some(WrappingClause::Update), - "select" => return Some(WrappingClause::Select), - "delete" => return Some(WrappingClause::Delete), - "from" => return Some(WrappingClause::From), - "join" => { - // TODO: not sure if we can infer it here. - return Some(WrappingClause::Join { on_node: None }); + let clause_combinations: Vec<(WrappingClause, &[&'static str])> = vec![ + (WrappingClause::Where, &["where"]), + (WrappingClause::Update, &["update"]), + (WrappingClause::Select, &["select"]), + (WrappingClause::Delete, &["delete"]), + (WrappingClause::Insert, &["insert", "into"]), + (WrappingClause::From, &["from"]), + (WrappingClause::Join { on_node: None }, &["join"]), + (WrappingClause::AlterTable, &["alter", "table"]), + ( + WrappingClause::AlterTable, + &["alter", "table", "if", "exists"], + ), + (WrappingClause::DropTable, &["drop", "table"]), + ( + WrappingClause::DropTable, + &["drop", "table", "if", "exists"], + ), + ]; + + let first_sibling = self.get_first_sibling(node); + + /* + * For each clause, we'll iterate from first_sibling to the next ones, + * either until the end or until we land on the node under the cursor. + * We'll score the `WrappingClause` by how many tokens it matches in order. + */ + let mut clauses_with_score: Vec<(WrappingClause, usize)> = clause_combinations + .into_iter() + .map(|(clause, tokens)| { + let mut idx = 0; + + let mut sibling = Some(first_sibling); + while let Some(sib) = sibling { + if sib.end_byte() >= node.end_byte() || idx >= tokens.len() { + break; } - _ => {} + + if let Some(sibling_content) = + self.get_ts_node_content(&sib).and_then(|txt| match txt { + NodeText::Original(txt) => Some(txt), + NodeText::Replaced => None, + }) + { + if sibling_content == tokens[idx] { + idx += 1; + } + } else { + break; + } + + sibling = sib.next_sibling(); } - }; - } - None + (clause, idx) + }) + .collect(); + + clauses_with_score.sort_by(|(_, score_a), (_, score_b)| score_b.cmp(score_a)); + clauses_with_score + .iter() + .filter(|(_, score)| *score > 0) + .next() + .map(|c| c.0.clone()) } fn get_wrapping_clause_from_current_node( @@ -452,6 +528,10 @@ impl<'a> CompletionContext<'a> { "select" => Some(WrappingClause::Select), "delete" => Some(WrappingClause::Delete), "from" => Some(WrappingClause::From), + "drop_table" => Some(WrappingClause::DropTable), + "alter_table" => Some(WrappingClause::AlterTable), + "column_definitions" => Some(WrappingClause::ColumnDefinitions), + "insert" => Some(WrappingClause::Insert), "join" => { // sadly, we need to manually iterate over the children – // `node.child_by_field_id(..)` does not work as expected @@ -468,6 +548,27 @@ impl<'a> CompletionContext<'a> { _ => None, } } + + pub(crate) fn before_cursor_matches_kind(&self, kinds: &[&'static str]) -> bool { + self.node_under_cursor.as_ref().is_some_and(|under_cursor| { + match under_cursor { + NodeUnderCursor::TsNode(node) => { + let mut current = node.clone(); + + // move up to the parent until we're at top OR we have a prev sibling + while current.prev_sibling().is_none() && current.parent().is_some() { + current = current.parent().unwrap(); + } + + current + .prev_sibling() + .is_some_and(|sib| kinds.contains(&sib.kind())) + } + + NodeUnderCursor::CustomNode { .. } => false, + } + }) + } } #[cfg(test)] diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 8109ba83..9dc7bfa9 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -573,4 +573,24 @@ mod tests { ) .await; } + + #[tokio::test] + async fn suggests_columns_in_insert_clause() { + let setup = r#" + create table instruments ( + id bigint primary key generated always as identity, + name text not null + ); + "#; + + assert_complete_results( + format!("insert into instruments ({})", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("id".to_string()), + CompletionAssertion::Label("name".to_string()), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 57195da7..217db91f 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -310,4 +310,73 @@ mod tests { ) .await; } + + #[tokio::test] + async fn suggests_tables_in_alter_and_drop_statements() { + let setup = r#" + create schema auth; + + create table auth.users ( + uid serial primary key, + name text not null, + email text unique not null + ); + + create table auth.posts ( + pid serial primary key, + user_id int not null references auth.users(uid), + title text not null, + content text, + created_at timestamp default now() + ); + "#; + + assert_complete_results( + format!("alter table {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + + assert_complete_results( + format!("alter table if exists {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + + assert_complete_results( + format!("drop table {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + + assert_complete_results( + format!("drop table if exists {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 3b148336..725c175d 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -1,4 +1,4 @@ -use crate::context::{CompletionContext, NodeUnderCursor, WrappingClause}; +use crate::context::{CompletionContext, NodeUnderCursor, WrappingClause, WrappingNode}; use super::CompletionRelevanceData; @@ -24,6 +24,10 @@ impl CompletionFilter<'_> { } fn completable_context(&self, ctx: &CompletionContext) -> Option<()> { + if ctx.wrapping_node_kind.is_none() && ctx.wrapping_clause_type.is_none() { + return None; + } + let current_node_kind = ctx .node_under_cursor .as_ref() @@ -65,55 +69,99 @@ impl CompletionFilter<'_> { } fn check_clause(&self, ctx: &CompletionContext) -> Option<()> { - let clause = ctx.wrapping_clause_type.as_ref(); - - let in_clause = |compare: WrappingClause| clause.is_some_and(|c| c == &compare); - - match self.data { - CompletionRelevanceData::Table(_) => { - if in_clause(WrappingClause::Select) - || in_clause(WrappingClause::Where) - || in_clause(WrappingClause::PolicyName) - { - return None; - }; - } - CompletionRelevanceData::Column(_) => { - if in_clause(WrappingClause::From) || in_clause(WrappingClause::PolicyName) { - return None; - } - - // We can complete columns in JOIN cluases, but only if we are after the - // ON node in the "ON u.id = posts.user_id" part. - let in_join_clause_before_on_node = clause.is_some_and(|c| match c { - // we are in a JOIN, but definitely not after an ON - WrappingClause::Join { on_node: None } => true, - - WrappingClause::Join { on_node: Some(on) } => ctx - .node_under_cursor - .as_ref() - .is_some_and(|n| n.end_byte() < on.start_byte()), - - _ => false, - }); - - if in_join_clause_before_on_node { - return None; - } - } - CompletionRelevanceData::Policy(_) => { - if clause.is_none_or(|c| c != &WrappingClause::PolicyName) { - return None; - } - } - _ => { - if in_clause(WrappingClause::PolicyName) { - return None; + ctx.wrapping_clause_type + .as_ref() + .map(|clause| { + match self.data { + CompletionRelevanceData::Table(_) => match clause { + WrappingClause::Select + | WrappingClause::Where + | WrappingClause::ColumnDefinitions => false, + + WrappingClause::Insert => { + ctx.wrapping_node_kind + .as_ref() + .is_some_and(|n| n != &WrappingNode::List) + && ctx.before_cursor_matches_kind(&["keyword_into"]) + } + + WrappingClause::DropTable | WrappingClause::AlterTable => ctx + .before_cursor_matches_kind(&[ + "keyword_exists", + "keyword_only", + "keyword_table", + ]), + + _ => true, + }, + + CompletionRelevanceData::Column(_) => { + match clause { + WrappingClause::From + | WrappingClause::ColumnDefinitions + | WrappingClause::AlterTable + | WrappingClause::DropTable => false, + + // We can complete columns in JOIN cluases, but only if we are after the + // ON node in the "ON u.id = posts.user_id" part. + WrappingClause::Join { on_node: Some(on) } => ctx + .node_under_cursor + .as_ref() + .is_some_and(|cn| cn.start_byte() >= on.end_byte()), + + // we are in a JOIN, but definitely not after an ON + WrappingClause::Join { on_node: None } => false, + + WrappingClause::Insert => ctx + .wrapping_node_kind + .as_ref() + .is_some_and(|n| n == &WrappingNode::List), + + _ => true, + } + } + + CompletionRelevanceData::Function(_) => match clause { + WrappingClause::From + | WrappingClause::Select + | WrappingClause::Where + | WrappingClause::Join { .. } => true, + + _ => false, + }, + + CompletionRelevanceData::Schema(_) => match clause { + WrappingClause::Select + | WrappingClause::Where + | WrappingClause::From + | WrappingClause::Join { .. } + | WrappingClause::Update + | WrappingClause::Delete => true, + + WrappingClause::DropTable | WrappingClause::AlterTable => ctx + .before_cursor_matches_kind(&[ + "keyword_exists", + "keyword_only", + "keyword_table", + ]), + + WrappingClause::Insert => { + ctx.wrapping_node_kind + .as_ref() + .is_some_and(|n| n != &WrappingNode::List) + && ctx.before_cursor_matches_kind(&["keyword_into"]) + } + + _ => false, + }, + + CompletionRelevanceData::Policy(_) => match clause { + WrappingClause::PolicyName => true, + _ => false, + }, } - } - } - - Some(()) + }) + .and_then(|is_ok| if is_ok { Some(()) } else { None }) } fn check_invocation(&self, ctx: &CompletionContext) -> Option<()> { @@ -188,4 +236,15 @@ mod tests { ) .await; } + + #[tokio::test] + async fn completion_after_create_table() { + assert_no_complete_results(format!("create table {}", CURSOR_POS).as_str(), "").await; + } + + #[tokio::test] + async fn completion_in_column_definitions() { + let query = format!(r#"create table instruments ( {} )"#, CURSOR_POS); + assert_no_complete_results(query.as_str(), "").await; + } } diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 6aa75a16..75876887 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -53,6 +53,7 @@ where || cursor_prepared_to_write_token_after_last_node(¶ms.text, params.position) || cursor_before_semicolon(params.tree, params.position) || cursor_on_a_dot(¶ms.text, params.position) + || cursor_between_parentheses(¶ms.text, params.position) { SanitizedCompletionParams::with_adjusted_sql(params) } else { @@ -203,13 +204,19 @@ fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool .unwrap_or(false) } +fn cursor_between_parentheses(sql: &str, position: TextSize) -> bool { + let position: usize = position.into(); + let mut chars = sql.chars(); + chars.nth(position - 1).is_some_and(|c| c == '(') && chars.next().is_some_and(|c| c == ')') +} + #[cfg(test)] mod tests { use pgt_text_size::TextSize; use crate::sanitization::{ - cursor_before_semicolon, cursor_inbetween_nodes, cursor_on_a_dot, - cursor_prepared_to_write_token_after_last_node, + cursor_before_semicolon, cursor_between_parentheses, cursor_inbetween_nodes, + cursor_on_a_dot, cursor_prepared_to_write_token_after_last_node, }; #[test] @@ -306,4 +313,18 @@ mod tests { assert!(cursor_before_semicolon(&tree, TextSize::new(16))); assert!(cursor_before_semicolon(&tree, TextSize::new(17))); } + + #[test] + fn between_parentheses() { + let input = "insert into instruments ()"; + + // insert into (|) <- right in the parentheses + assert!(cursor_between_parentheses(input, TextSize::new(25))); + + // insert into ()| <- too late + assert!(!cursor_between_parentheses(input, TextSize::new(26))); + + // insert into |() <- too early + assert!(!cursor_between_parentheses(input, TextSize::new(24))); + } } diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index 937c11af..f3d5c2bf 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -244,6 +244,8 @@ pub(crate) async fn assert_complete_results( pub(crate) async fn assert_no_complete_results(query: &str, setup: &str) { let (tree, cache) = get_test_deps(setup, query.into()).await; let params = get_test_params(&tree, &cache, query.into()); + println!("{:#?}", params.position); + println!("{:#?}", params.text); let items = complete(params); assert_eq!(items.len(), 0)