Skip to content

[WIP] feats(completions): complete insert, drop/alter table, blurb #400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 143 additions & 50 deletions crates/pgt_completions/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::collections::{HashMap, HashSet};
use std::{
cmp,
collections::{HashMap, HashSet},
};

use pgt_schema_cache::SchemaCache;
use pgt_treesitter_queries::{
Expand All @@ -8,7 +11,7 @@ use pgt_treesitter_queries::{

use crate::sanitization::SanitizedCompletionParams;

#[derive(Debug, PartialEq, Eq, Hash)]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub enum WrappingClause<'a> {
Select,
Where,
Expand All @@ -18,6 +21,10 @@ pub enum WrappingClause<'a> {
},
Update,
Delete,
ColumnDefinitions,
Insert,
AlterTable,
DropTable,
}

#[derive(PartialEq, Eq, Debug)]
Expand Down Expand Up @@ -45,6 +52,7 @@ pub enum WrappingNode {
Relation,
BinaryExpression,
Assignment,
List,
}

impl TryFrom<&str> for WrappingNode {
Expand All @@ -55,6 +63,7 @@ impl TryFrom<&str> for WrappingNode {
"relation" => Ok(Self::Relation),
"assignment" => Ok(Self::Assignment),
"binary_expression" => Ok(Self::BinaryExpression),
"list" => Ok(Self::List),
_ => {
let message = format!("Unimplemented Relation: {}", value);

Expand All @@ -76,6 +85,7 @@ impl TryFrom<String> for WrappingNode {
}
}

#[derive(Debug)]
pub(crate) struct CompletionContext<'a> {
pub node_under_cursor: Option<tree_sitter::Node<'a>>,

Expand Down Expand Up @@ -110,9 +120,6 @@ pub(crate) struct CompletionContext<'a> {
pub is_invocation: bool,
pub wrapping_statement_range: Option<tree_sitter::Range>,

/// Some incomplete statements can't be correctly parsed by TreeSitter.
pub is_in_error_node: bool,

pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
pub mentioned_table_aliases: HashMap<String, String>,
pub mentioned_columns: HashMap<Option<WrappingClause<'a>>, HashSet<MentionedColumn>>,
Expand All @@ -134,12 +141,19 @@ impl<'a> CompletionContext<'a> {
mentioned_relations: HashMap::new(),
mentioned_table_aliases: HashMap::new(),
mentioned_columns: HashMap::new(),
is_in_error_node: false,
};

ctx.gather_tree_context();
ctx.gather_info_from_ts_queries();

// if cfg!(test) {
// println!("{:?}", ctx.position);
// println!("{:?}", ctx.text);
// println!("{:?}", ctx.wrapping_clause_type);
// println!("{:?}", ctx.wrapping_node_kind);
// println!("{:?}", ctx.before_cursor_matches_kind(&["keyword_table"]));
// }

ctx
}

Expand Down Expand Up @@ -231,10 +245,20 @@ impl<'a> CompletionContext<'a> {
* `select * from use {}` becomes `select * from use{}`.
*/
let current_node = cursor.node();
while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 {
self.position -= 1;

let mut chars = self.text.chars();

if chars
.nth(self.position)
.is_some_and(|c| !c.is_ascii_whitespace() && c != ';')
{
self.position = cmp::min(self.position + 1, self.text.len());
} else {
self.position = cmp::min(self.position, self.text.len());
}

cursor.goto_first_child_for_byte(self.position);

self.gather_context_from_node(cursor, current_node);
}

Expand Down Expand Up @@ -266,24 +290,12 @@ impl<'a> CompletionContext<'a> {
}

// try to gather context from the siblings if we're within an error node.
if self.is_in_error_node {
let mut next_sibling = current_node.next_named_sibling();
while let Some(n) = next_sibling {
if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) {
self.wrapping_clause_type = Some(clause_type);
break;
} else {
next_sibling = n.next_named_sibling();
}
if parent_node_kind == "ERROR" {
if let Some(clause_type) = self.get_wrapping_clause_from_siblings(current_node) {
self.wrapping_clause_type = Some(clause_type);
}
let mut prev_sibling = current_node.prev_named_sibling();
while let Some(n) = prev_sibling {
if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) {
self.wrapping_clause_type = Some(clause_type);
break;
} else {
prev_sibling = n.prev_named_sibling();
}
if let Some(wrapping_node) = self.get_wrapping_node_from_siblings(current_node) {
self.wrapping_node_kind = Some(wrapping_node)
}
}

Expand All @@ -303,19 +315,16 @@ impl<'a> CompletionContext<'a> {
}
}

"where" | "update" | "select" | "delete" | "from" | "join" => {
"where" | "update" | "select" | "delete" | "from" | "join" | "column_definitions"
| "drop_table" | "alter_table" => {
self.wrapping_clause_type =
self.get_wrapping_clause_from_current_node(current_node, &mut cursor);
}

"relation" | "binary_expression" | "assignment" => {
"relation" | "binary_expression" | "assignment" | "list" => {
self.wrapping_node_kind = current_node_kind.try_into().ok();
}

"ERROR" => {
self.is_in_error_node = true;
}

_ => {}
}

Expand All @@ -329,31 +338,99 @@ impl<'a> CompletionContext<'a> {
self.gather_context_from_node(cursor, current_node);
}

fn get_wrapping_clause_from_keyword_node(
fn get_first_sibling(&self, node: tree_sitter::Node<'a>) -> tree_sitter::Node<'a> {
let mut first_sibling = node;
while let Some(n) = first_sibling.prev_sibling() {
first_sibling = n;
}
first_sibling
}

fn get_wrapping_node_from_siblings(&self, node: tree_sitter::Node<'a>) -> Option<WrappingNode> {
self.wrapping_clause_type
.as_ref()
.and_then(|clause| match clause {
WrappingClause::Insert => {
if node.prev_sibling().is_some_and(|n| n.kind() == "(")
|| node.next_sibling().is_some_and(|n| n.kind() == ")")
{
Some(WrappingNode::List)
} else {
None
}
}
_ => None,
})
}

fn get_wrapping_clause_from_siblings(
&self,
node: tree_sitter::Node<'a>,
) -> Option<WrappingClause<'a>> {
if node.kind().starts_with("keyword_") {
if let Some(txt) = self.get_ts_node_content(node).and_then(|txt| match txt {
NodeText::Original(txt) => Some(txt),
NodeText::Replaced => None,
}) {
match txt {
"where" => return Some(WrappingClause::Where),
"update" => return Some(WrappingClause::Update),
"select" => return Some(WrappingClause::Select),
"delete" => return Some(WrappingClause::Delete),
"from" => return Some(WrappingClause::From),
"join" => {
// TODO: not sure if we can infer it here.
return Some(WrappingClause::Join { on_node: None });
let clause_combinations: Vec<(WrappingClause, &[&'static str])> = vec![
(WrappingClause::Where, &["where"]),
(WrappingClause::Update, &["update"]),
(WrappingClause::Select, &["select"]),
(WrappingClause::Delete, &["delete"]),
(WrappingClause::Insert, &["insert", "into"]),
(WrappingClause::From, &["from"]),
(WrappingClause::Join { on_node: None }, &["join"]),
(WrappingClause::AlterTable, &["alter", "table"]),
(
WrappingClause::AlterTable,
&["alter", "table", "if", "exists"],
),
(WrappingClause::DropTable, &["drop", "table"]),
(
WrappingClause::DropTable,
&["drop", "table", "if", "exists"],
),
];

let first_sibling = self.get_first_sibling(node);

/*
* For each clause, we'll iterate from first_sibling to the next ones,
* either until the end or until we land on the node under the cursor.
* We'll score the `WrappingClause` by how many tokens it matches in order.
*/
let mut clauses_with_score: Vec<(WrappingClause, usize)> = clause_combinations
.into_iter()
.map(|(clause, tokens)| {
let mut idx = 0;

let mut sibling = Some(first_sibling);
while let Some(sib) = sibling {
if sib.end_byte() >= node.end_byte() || idx >= tokens.len() {
break;
}

if let Some(sibling_content) =
self.get_ts_node_content(sib).and_then(|txt| match txt {
NodeText::Original(txt) => Some(txt),
NodeText::Replaced => None,
})
{
if sibling_content == tokens[idx] {
idx += 1;
}
} else {
break;
}
_ => {}

sibling = sib.next_sibling();
}
};
}

None
(clause, idx)
})
.collect();

clauses_with_score.sort_by(|(_, score_a), (_, score_b)| score_b.cmp(score_a));
clauses_with_score
.iter()
.filter(|(_, score)| *score > 0)
.next()
.map(|c| c.0.clone())
}

fn get_wrapping_clause_from_current_node(
Expand All @@ -367,6 +444,10 @@ impl<'a> CompletionContext<'a> {
"select" => Some(WrappingClause::Select),
"delete" => Some(WrappingClause::Delete),
"from" => Some(WrappingClause::From),
"drop_table" => Some(WrappingClause::DropTable),
"alter_table" => Some(WrappingClause::AlterTable),
"column_definitions" => Some(WrappingClause::ColumnDefinitions),
"insert" => Some(WrappingClause::Insert),
"join" => {
// sadly, we need to manually iterate over the children –
// `node.child_by_field_id(..)` does not work as expected
Expand All @@ -383,6 +464,18 @@ impl<'a> CompletionContext<'a> {
_ => None,
}
}

pub(crate) fn before_cursor_matches_kind(&self, kinds: &[&'static str]) -> bool {
self.node_under_cursor.is_some_and(|mut node| {
// move up to the parent until we're at top OR we have a prev sibling
while node.prev_sibling().is_none() && node.parent().is_some() {
node = node.parent().unwrap();
}

node.prev_sibling()
.is_some_and(|sib| kinds.contains(&sib.kind()))
})
}
}

#[cfg(test)]
Expand Down
20 changes: 20 additions & 0 deletions crates/pgt_completions/src/providers/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,4 +573,24 @@ mod tests {
)
.await;
}

#[tokio::test]
async fn suggests_columns_in_insert_clause() {
let setup = r#"
create table instruments (
id bigint primary key generated always as identity,
name text not null
);
"#;

assert_complete_results(
format!("insert into instruments ({})", CURSOR_POS).as_str(),
vec![
CompletionAssertion::Label("id".to_string()),
CompletionAssertion::Label("name".to_string()),
],
setup,
)
.await;
}
}
Loading
Loading