diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index d034de09..0bb190a9 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -237,6 +237,7 @@ impl<'a> CompletionContext<'a> { executor.add_query_results::(); executor.add_query_results::(); executor.add_query_results::(); + executor.add_query_results::(); for relation_match in executor.get_iter(stmt_range) { match relation_match { @@ -251,6 +252,7 @@ impl<'a> CompletionContext<'a> { }) .or_insert(HashSet::from([table_name])); } + QueryResult::TableAliases(table_alias_match) => { self.mentioned_table_aliases.insert( table_alias_match.get_alias(sql), @@ -272,6 +274,20 @@ impl<'a> CompletionContext<'a> { .or_insert(HashSet::from([mentioned])); } + QueryResult::WhereClauseColumns(c) => { + let mentioned = MentionedColumn { + column: c.get_column(sql), + alias: c.get_alias(sql), + }; + + self.mentioned_columns + .entry(Some(WrappingClause::Where)) + .and_modify(|s| { + s.insert(mentioned.clone()); + }) + .or_insert(HashSet::from([mentioned])); + } + QueryResult::InsertClauseColumns(c) => { let mentioned = MentionedColumn { column: c.get_column(sql), @@ -359,8 +375,9 @@ impl<'a> CompletionContext<'a> { let parent_node_kind = parent_node.kind(); let current_node_kind = current_node.kind(); - // prevent infinite recursion – this can happen if we only have a PROGRAM node - if current_node_kind == parent_node_kind { + // prevent infinite recursion – this can happen with ERROR nodes + if current_node_kind == parent_node_kind && ["ERROR", "program"].contains(&parent_node_kind) + { self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); return; } diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 148504b9..a040bab1 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -23,7 +23,12 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio }; // autocomplete with the alias in a join clause if we find one - if matches!(ctx.wrapping_clause_type, Some(WrappingClause::Join { .. })) { + if matches!( + ctx.wrapping_clause_type, + Some(WrappingClause::Join { .. }) + | Some(WrappingClause::Where) + | Some(WrappingClause::Select) + ) { item.completion_text = find_matching_alias_for_table(ctx, col.table_name.as_str()) .and_then(|alias| { get_completion_text_with_schema_or_alias(ctx, col.name.as_str(), alias.as_str()) @@ -36,6 +41,8 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio #[cfg(test)] mod tests { + use std::vec; + use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ @@ -643,4 +650,81 @@ mod tests { ) .await; } + + #[tokio::test] + async fn suggests_columns_in_where_clause() { + let setup = r#" + create table instruments ( + id bigint primary key generated always as identity, + name text not null, + z text, + created_at timestamp with time zone default now() + ); + + create table others ( + a text, + b text, + c text + ); + "#; + + assert_complete_results( + format!("select name from instruments where {} ", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + CompletionAssertion::Label("z".into()), + ], + setup, + ) + .await; + + assert_complete_results( + format!( + "select name from instruments where z = 'something' and created_at > {}", + CURSOR_POS + ) + .as_str(), + // simply do not complete columns + schemas; functions etc. are ok + vec![ + CompletionAssertion::KindNotExists(CompletionItemKind::Column), + CompletionAssertion::KindNotExists(CompletionItemKind::Schema), + ], + setup, + ) + .await; + + // prefers not mentioned columns + assert_complete_results( + format!( + "select name from instruments where id = 'something' and {}", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("name".into()), + CompletionAssertion::Label("z".into()), + ], + setup, + ) + .await; + + // // uses aliases + assert_complete_results( + format!( + "select name from instruments i join others o on i.z = o.a where i.{}", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index cb6d2cf6..5323e2bc 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -119,6 +119,13 @@ impl CompletionFilter<'_> { .as_ref() .is_some_and(|n| n == &WrappingNode::List), + // only autocomplete left side of binary expression + WrappingClause::Where => { + ctx.before_cursor_matches_kind(&["keyword_and", "keyword_where"]) + || (ctx.before_cursor_matches_kind(&["."]) + && ctx.parent_matches_one_of_kind(&["field"])) + } + _ => true, } } @@ -133,12 +140,15 @@ impl CompletionFilter<'_> { CompletionRelevanceData::Schema(_) => match clause { WrappingClause::Select - | WrappingClause::Where | WrappingClause::From | WrappingClause::Join { .. } | WrappingClause::Update | WrappingClause::Delete => true, + WrappingClause::Where => { + ctx.before_cursor_matches_kind(&["keyword_and", "keyword_where"]) + } + WrappingClause::DropTable | WrappingClause::AlterTable => ctx .before_cursor_matches_kind(&[ "keyword_exists", diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 8b1cae6e..40dea7e6 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -193,11 +193,6 @@ fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool return false; } - // not okay to be on the semi. - if byte == leaf_node.start_byte() { - return false; - } - leaf_node .prev_named_sibling() .map(|n| n.end_byte() < byte) @@ -355,19 +350,17 @@ mod tests { // select * from| ; <-- still touches the from assert!(!cursor_before_semicolon(&tree, TextSize::new(13))); - // not okay to be ON the semi. - // select * from |; - assert!(!cursor_before_semicolon(&tree, TextSize::new(18))); - // anything is fine here - // select * from | ; - // select * from | ; - // select * from | ; - // select * from |; + // select * from | ; + // select * from | ; + // select * from | ; + // select * from | ; + // select * from |; assert!(cursor_before_semicolon(&tree, TextSize::new(14))); assert!(cursor_before_semicolon(&tree, TextSize::new(15))); assert!(cursor_before_semicolon(&tree, TextSize::new(16))); assert!(cursor_before_semicolon(&tree, TextSize::new(17))); + assert!(cursor_before_semicolon(&tree, TextSize::new(18))); } #[test] diff --git a/crates/pgt_treesitter_queries/src/queries/mod.rs b/crates/pgt_treesitter_queries/src/queries/mod.rs index 2d957872..b9f39aed 100644 --- a/crates/pgt_treesitter_queries/src/queries/mod.rs +++ b/crates/pgt_treesitter_queries/src/queries/mod.rs @@ -3,12 +3,14 @@ mod parameters; mod relations; mod select_columns; mod table_aliases; +mod where_columns; pub use insert_columns::*; pub use parameters::*; pub use relations::*; pub use select_columns::*; pub use table_aliases::*; +pub use where_columns::*; #[derive(Debug)] pub enum QueryResult<'a> { @@ -17,6 +19,7 @@ pub enum QueryResult<'a> { TableAliases(TableAliasMatch<'a>), SelectClauseColumns(SelectColumnMatch<'a>), InsertClauseColumns(InsertColumnMatch<'a>), + WhereClauseColumns(WhereColumnMatch<'a>), } impl QueryResult<'_> { @@ -53,6 +56,16 @@ impl QueryResult<'_> { start >= range.start_point && end <= range.end_point } + Self::WhereClauseColumns(cm) => { + let start = match cm.alias { + Some(n) => n.start_position(), + None => cm.column.start_position(), + }; + + let end = cm.column.end_position(); + + start >= range.start_point && end <= range.end_point + } Self::InsertClauseColumns(cm) => { let start = cm.column.start_position(); let end = cm.column.end_position(); diff --git a/crates/pgt_treesitter_queries/src/queries/where_columns.rs b/crates/pgt_treesitter_queries/src/queries/where_columns.rs new file mode 100644 index 00000000..8e19590d --- /dev/null +++ b/crates/pgt_treesitter_queries/src/queries/where_columns.rs @@ -0,0 +1,96 @@ +use std::sync::LazyLock; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &str = r#" + (where + (binary_expression + (binary_expression + (field + (object_reference)? @alias + "."? + (identifier) @column + ) + ) + ) + ) +"#; + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") +}); + +#[derive(Debug)] +pub struct WhereColumnMatch<'a> { + pub(crate) alias: Option>, + pub(crate) column: tree_sitter::Node<'a>, +} + +impl WhereColumnMatch<'_> { + pub fn get_alias(&self, sql: &str) -> Option { + let str = self + .alias + .as_ref()? + .utf8_text(sql.as_bytes()) + .expect("Failed to get alias from ColumnMatch"); + + Some(str.to_string()) + } + + pub fn get_column(&self, sql: &str) -> String { + self.column + .utf8_text(sql.as_bytes()) + .expect("Failed to get column from ColumnMatch") + .to_string() + } +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a WhereColumnMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::WhereClauseColumns(c) => Ok(c), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for WhereColumnMatch<'a> { + type Ref = &'a WhereColumnMatch<'a>; +} + +impl<'a> Query<'a> for WhereColumnMatch<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); + + let mut to_return = vec![]; + + for m in matches { + if m.captures.len() == 1 { + let capture = m.captures[0].node; + to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch { + alias: None, + column: capture, + })); + } + + if m.captures.len() == 2 { + let alias = m.captures[0].node; + let column = m.captures[1].node; + + to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch { + alias: Some(alias), + column, + })); + } + } + + to_return + } +}