Skip to content

feat(completions): improve completions in WHERE clauses #403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions crates/pgt_completions/src/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ impl<'a> CompletionContext<'a> {
executor.add_query_results::<queries::TableAliasMatch>();
executor.add_query_results::<queries::SelectColumnMatch>();
executor.add_query_results::<queries::InsertColumnMatch>();
executor.add_query_results::<queries::WhereColumnMatch>();

for relation_match in executor.get_iter(stmt_range) {
match relation_match {
Expand All @@ -251,6 +252,7 @@ impl<'a> CompletionContext<'a> {
})
.or_insert(HashSet::from([table_name]));
}

QueryResult::TableAliases(table_alias_match) => {
self.mentioned_table_aliases.insert(
table_alias_match.get_alias(sql),
Expand All @@ -272,6 +274,20 @@ impl<'a> CompletionContext<'a> {
.or_insert(HashSet::from([mentioned]));
}

QueryResult::WhereClauseColumns(c) => {
let mentioned = MentionedColumn {
column: c.get_column(sql),
alias: c.get_alias(sql),
};

self.mentioned_columns
.entry(Some(WrappingClause::Where))
.and_modify(|s| {
s.insert(mentioned.clone());
})
.or_insert(HashSet::from([mentioned]));
}

QueryResult::InsertClauseColumns(c) => {
let mentioned = MentionedColumn {
column: c.get_column(sql),
Expand Down Expand Up @@ -359,8 +375,9 @@ impl<'a> CompletionContext<'a> {
let parent_node_kind = parent_node.kind();
let current_node_kind = current_node.kind();

// prevent infinite recursion – this can happen if we only have a PROGRAM node
if current_node_kind == parent_node_kind {
// prevent infinite recursion – this can happen with ERROR nodes
if current_node_kind == parent_node_kind && ["ERROR", "program"].contains(&parent_node_kind)
{
self.node_under_cursor = Some(NodeUnderCursor::from(current_node));
return;
}
Expand Down
86 changes: 85 additions & 1 deletion crates/pgt_completions/src/providers/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio
};

// autocomplete with the alias in a join clause if we find one
if matches!(ctx.wrapping_clause_type, Some(WrappingClause::Join { .. })) {
if matches!(
ctx.wrapping_clause_type,
Some(WrappingClause::Join { .. })
| Some(WrappingClause::Where)
| Some(WrappingClause::Select)
) {
item.completion_text = find_matching_alias_for_table(ctx, col.table_name.as_str())
.and_then(|alias| {
get_completion_text_with_schema_or_alias(ctx, col.name.as_str(), alias.as_str())
Expand All @@ -36,6 +41,8 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio

#[cfg(test)]
mod tests {
use std::vec;

use crate::{
CompletionItem, CompletionItemKind, complete,
test_helper::{
Expand Down Expand Up @@ -643,4 +650,81 @@ mod tests {
)
.await;
}

#[tokio::test]
async fn suggests_columns_in_where_clause() {
let setup = r#"
create table instruments (
id bigint primary key generated always as identity,
name text not null,
z text,
created_at timestamp with time zone default now()
);

create table others (
a text,
b text,
c text
);
"#;

assert_complete_results(
format!("select name from instruments where {} ", CURSOR_POS).as_str(),
vec![
CompletionAssertion::Label("created_at".into()),
CompletionAssertion::Label("id".into()),
CompletionAssertion::Label("name".into()),
CompletionAssertion::Label("z".into()),
],
setup,
)
.await;

assert_complete_results(
format!(
"select name from instruments where z = 'something' and created_at > {}",
CURSOR_POS
)
.as_str(),
// simply do not complete columns + schemas; functions etc. are ok
vec![
CompletionAssertion::KindNotExists(CompletionItemKind::Column),
CompletionAssertion::KindNotExists(CompletionItemKind::Schema),
],
setup,
)
.await;

// prefers not mentioned columns
assert_complete_results(
format!(
"select name from instruments where id = 'something' and {}",
CURSOR_POS
)
.as_str(),
vec![
CompletionAssertion::Label("created_at".into()),
CompletionAssertion::Label("name".into()),
CompletionAssertion::Label("z".into()),
],
setup,
)
.await;

// // uses aliases
assert_complete_results(
format!(
"select name from instruments i join others o on i.z = o.a where i.{}",
CURSOR_POS
)
.as_str(),
vec![
CompletionAssertion::Label("created_at".into()),
CompletionAssertion::Label("id".into()),
CompletionAssertion::Label("name".into()),
],
setup,
)
.await;
}
}
12 changes: 11 additions & 1 deletion crates/pgt_completions/src/relevance/filtering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ impl CompletionFilter<'_> {
.as_ref()
.is_some_and(|n| n == &WrappingNode::List),

// only autocomplete left side of binary expression
WrappingClause::Where => {
ctx.before_cursor_matches_kind(&["keyword_and", "keyword_where"])
|| (ctx.before_cursor_matches_kind(&["."])
&& ctx.parent_matches_one_of_kind(&["field"]))
}

_ => true,
}
}
Expand All @@ -133,12 +140,15 @@ impl CompletionFilter<'_> {

CompletionRelevanceData::Schema(_) => match clause {
WrappingClause::Select
| WrappingClause::Where
| WrappingClause::From
| WrappingClause::Join { .. }
| WrappingClause::Update
| WrappingClause::Delete => true,

WrappingClause::Where => {
ctx.before_cursor_matches_kind(&["keyword_and", "keyword_where"])
}

WrappingClause::DropTable | WrappingClause::AlterTable => ctx
.before_cursor_matches_kind(&[
"keyword_exists",
Expand Down
19 changes: 6 additions & 13 deletions crates/pgt_completions/src/sanitization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,6 @@ fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool
return false;
}

// not okay to be on the semi.
if byte == leaf_node.start_byte() {
return false;
}

leaf_node
.prev_named_sibling()
.map(|n| n.end_byte() < byte)
Expand Down Expand Up @@ -355,19 +350,17 @@ mod tests {
// select * from| ; <-- still touches the from
assert!(!cursor_before_semicolon(&tree, TextSize::new(13)));

// not okay to be ON the semi.
// select * from |;
assert!(!cursor_before_semicolon(&tree, TextSize::new(18)));

// anything is fine here
// select * from | ;
// select * from | ;
// select * from | ;
// select * from |;
// select * from | ;
// select * from | ;
// select * from | ;
// select * from | ;
// select * from |;
assert!(cursor_before_semicolon(&tree, TextSize::new(14)));
assert!(cursor_before_semicolon(&tree, TextSize::new(15)));
assert!(cursor_before_semicolon(&tree, TextSize::new(16)));
assert!(cursor_before_semicolon(&tree, TextSize::new(17)));
assert!(cursor_before_semicolon(&tree, TextSize::new(18)));
}

#[test]
Expand Down
13 changes: 13 additions & 0 deletions crates/pgt_treesitter_queries/src/queries/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ mod parameters;
mod relations;
mod select_columns;
mod table_aliases;
mod where_columns;

pub use insert_columns::*;
pub use parameters::*;
pub use relations::*;
pub use select_columns::*;
pub use table_aliases::*;
pub use where_columns::*;

#[derive(Debug)]
pub enum QueryResult<'a> {
Expand All @@ -17,6 +19,7 @@ pub enum QueryResult<'a> {
TableAliases(TableAliasMatch<'a>),
SelectClauseColumns(SelectColumnMatch<'a>),
InsertClauseColumns(InsertColumnMatch<'a>),
WhereClauseColumns(WhereColumnMatch<'a>),
}

impl QueryResult<'_> {
Expand Down Expand Up @@ -53,6 +56,16 @@ impl QueryResult<'_> {

start >= range.start_point && end <= range.end_point
}
Self::WhereClauseColumns(cm) => {
let start = match cm.alias {
Some(n) => n.start_position(),
None => cm.column.start_position(),
};

let end = cm.column.end_position();

start >= range.start_point && end <= range.end_point
}
Self::InsertClauseColumns(cm) => {
let start = cm.column.start_position();
let end = cm.column.end_position();
Expand Down
96 changes: 96 additions & 0 deletions crates/pgt_treesitter_queries/src/queries/where_columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::sync::LazyLock;

use crate::{Query, QueryResult};

use super::QueryTryFrom;

static TS_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
static QUERY_STR: &str = r#"
(where
(binary_expression
(binary_expression
(field
(object_reference)? @alias
"."?
(identifier) @column
)
)
)
)
"#;
tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query")
});

#[derive(Debug)]
pub struct WhereColumnMatch<'a> {
pub(crate) alias: Option<tree_sitter::Node<'a>>,
pub(crate) column: tree_sitter::Node<'a>,
}

impl WhereColumnMatch<'_> {
pub fn get_alias(&self, sql: &str) -> Option<String> {
let str = self
.alias
.as_ref()?
.utf8_text(sql.as_bytes())
.expect("Failed to get alias from ColumnMatch");

Some(str.to_string())
}

pub fn get_column(&self, sql: &str) -> String {
self.column
.utf8_text(sql.as_bytes())
.expect("Failed to get column from ColumnMatch")
.to_string()
}
}

impl<'a> TryFrom<&'a QueryResult<'a>> for &'a WhereColumnMatch<'a> {
type Error = String;

fn try_from(q: &'a QueryResult<'a>) -> Result<Self, Self::Error> {
match q {
QueryResult::WhereClauseColumns(c) => Ok(c),

#[allow(unreachable_patterns)]
_ => Err("Invalid QueryResult type".into()),
}
}
}

impl<'a> QueryTryFrom<'a> for WhereColumnMatch<'a> {
type Ref = &'a WhereColumnMatch<'a>;
}

impl<'a> Query<'a> for WhereColumnMatch<'a> {
fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec<crate::QueryResult<'a>> {
let mut cursor = tree_sitter::QueryCursor::new();

let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes());

let mut to_return = vec![];

for m in matches {
if m.captures.len() == 1 {
let capture = m.captures[0].node;
to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch {
alias: None,
column: capture,
}));
}

if m.captures.len() == 2 {
let alias = m.captures[0].node;
let column = m.captures[1].node;

to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch {
alias: Some(alias),
column,
}));
}
}

to_return
}
}