diff --git a/Cargo.lock b/Cargo.lock index 6519bb12..67959947 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -44,6 +44,55 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +dependencies = [ + "anstyle", + "windows-sys 0.59.0", +] + [[package]] name = "anyhow" version = "1.0.93" @@ -375,6 +424,46 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "4.5.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.23" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.89", +] + +[[package]] +name = "clap_lex" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" + [[package]] name = "cmake" version = "0.1.52" @@ -384,6 +473,12 @@ dependencies = [ "cc", ] +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -1178,6 +1273,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.10.5" @@ -1841,8 +1942,11 @@ name = "pg_test_utils" version = "0.0.0" dependencies = [ "anyhow", + "clap", "dotenv", "sqlx", + "tree-sitter", + "tree_sitter_sql", "uuid", ] @@ -2750,6 +2854,12 @@ dependencies = [ "unicode-properties", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -3170,6 +3280,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.11.0" diff --git a/Cargo.toml b/Cargo.toml index f4a063ea..afaef7ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ line_index = { path = "./lib/line_index", version = "0.0.0" } tree_sitter_sql = { path = "./lib/tree_sitter_sql", version = "0.0.0" } tree-sitter = "0.20.10" tracing = "0.1.40" +tower-lsp = "0.20.0" sqlx = { version = "0.8.2", features = [ "runtime-async-std", "tls-rustls", "postgres", "json" ] } # postgres specific crates diff --git a/crates/pg_completions/README.md b/crates/pg_completions/README.md new file mode 100644 index 00000000..1d3218ac --- /dev/null +++ b/crates/pg_completions/README.md @@ -0,0 +1,15 @@ +# Auto-Completions + +## What does this crate do? + +The `pg_completions` identifies and ranks autocompletion items that can be displayed in your code editor. +Its main export is the `complete` function. The function takes a PostgreSQL statement, a cursor position, and a datastructure representing the underlying databases schema. It returns a list of completion items. + +Postgres's statement-parsing-engine, `libpg_query`, which is used in other parts of this LSP, is only capable of parsing _complete and valid_ statements. Since autocompletion should work for incomplete statements, we rely heavily on tree-sitter – an incremental parsing library. + +### Working with TreeSitter + +In the `pg_test_utils` crate, there's a binary that parses an SQL file and prints out the matching tree-sitter tree. +This makes writing tree-sitter queries for this crate easy. + +To print a tree, run `cargo run --bin tree_print -- -f `. 
diff --git a/crates/pg_completions/src/builder.rs b/crates/pg_completions/src/builder.rs
index 4075050c..c5a89889 100644
--- a/crates/pg_completions/src/builder.rs
+++ b/crates/pg_completions/src/builder.rs
@@ -1,29 +1,53 @@
-use crate::{CompletionItem, CompletionResult};
+use std::collections::HashSet;
 
-pub struct CompletionBuilder<'a> {
-    pub items: Vec<CompletionItem<'a>>,
+use crate::{item::CompletionItem, CompletionResult};
+
+/// Collects the candidates produced by the completion providers and turns them
+/// into the final, ranked `CompletionResult`.
+pub(crate) struct CompletionBuilder {
+    items: Vec<CompletionItem>,
 }
 
-pub struct CompletionConfig {}
+impl CompletionBuilder {
+    /// Creates a builder that holds no candidates yet.
+    pub fn new() -> Self {
+        CompletionBuilder { items: vec![] }
+    }
 
-impl<'a> From<&'a CompletionConfig> for CompletionBuilder<'a> {
-    fn from(_config: &CompletionConfig) -> Self {
-        Self { items: Vec::new() }
+    /// Registers one more candidate; ordering is decided later, in `finish`.
+    pub fn add_item(&mut self, item: CompletionItem) {
+        self.items.push(item);
     }
-}
 
-impl<'a> CompletionBuilder<'a> {
-    pub fn finish(mut self) -> CompletionResult<'a> {
-        self.items.sort_by(|a, b| {
-            b.preselect
-                .cmp(&a.preselect)
-                .then_with(|| b.score.cmp(&a.score))
-                .then_with(|| a.data.label().cmp(b.data.label()))
-        });
+    /// Consumes the builder and returns the ranked result: highest score first,
+    /// ties broken by label, one entry per label, at most `crate::LIMIT` entries.
+    /// The best entry is flagged as preselected when it clearly outranks the rest.
+    pub fn finish(mut self) -> CompletionResult {
+        self.items
+            .sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.label.cmp(&b.label)));
 
-        self.items.dedup_by(|a, b| a.data.label() == b.data.label());
+        // `Vec::dedup_by` only drops *adjacent* duplicates, but after sorting by
+        // score two items with the same label can be separated by other items.
+        // Remember the labels already seen so the best-scored one always wins.
+        let mut seen_labels = HashSet::new();
+        self.items.retain(|item| seen_labels.insert(item.label.clone()));
         self.items.truncate(crate::LIMIT);
-        let Self { items, .. } = self;
-        CompletionResult { items }
+
+        let should_preselect_first_item = self.should_preselect_first_item();
+        if let Some(first) = self.items.first_mut() {
+            first.preselected = Some(should_preselect_first_item);
+        }
+
+        CompletionResult { items: self.items }
     }
+
+    /// Returns `true` when the top-ranked item is the only candidate or beats
+    /// the runner-up by more than 10 points; `false` when there are no items.
+    fn should_preselect_first_item(&self) -> bool {
+        match self.items.as_slice() {
+            [] => false,
+            [_] => true,
+            [first, second, ..] => (first.score - second.score) > 10,
+        }
+    }
 }
diff --git a/crates/pg_completions/src/complete.rs b/crates/pg_completions/src/complete.rs
new file mode 100644
index 00000000..58c08897
--- /dev/null
+++ b/crates/pg_completions/src/complete.rs
@@ -0,0 +1,226 @@
+use text_size::TextSize;
+
+use crate::{
+    builder::CompletionBuilder, context::CompletionContext, item::CompletionItem,
+    providers::complete_tables,
+};
+
+/// Upper bound on the number of items a single completion request returns.
+pub const LIMIT: usize = 50;
+
+/// Everything `complete` needs to know about one completion request.
+#[derive(Debug)]
+pub struct CompletionParams<'a> {
+    /// Cursor position, used as a byte offset into `text`.
+    pub position: TextSize,
+    pub schema: &'a pg_schema_cache::SchemaCache,
+    pub text: &'a str,
+    /// tree-sitter parse of `text`, if one is available.
+    pub tree: Option<&'a tree_sitter::Tree>,
+}
+
+/// The ranked completion items for one request, best match first.
+#[derive(Debug, Default)]
+pub struct CompletionResult {
+    pub items: Vec<CompletionItem>,
+}
+
+impl IntoIterator for CompletionResult {
+    type Item = CompletionItem;
+    type IntoIter = <Vec<Self::Item> as IntoIterator>::IntoIter;
+    fn into_iter(self) -> Self::IntoIter {
+        self.items.into_iter()
+    }
+}
+
+/// Returns the ranked completion items for the statement and cursor position
+/// described by `params`. Only table names are suggested so far.
+pub fn complete(params: CompletionParams) -> CompletionResult {
+    let ctx = CompletionContext::new(&params);
+
+    let mut builder = CompletionBuilder::new();
+
+    complete_tables(&ctx, &mut builder);
+
+    builder.finish()
+}
+
+#[cfg(test)]
+mod tests {
+    use pg_schema_cache::SchemaCache;
+    use pg_test_utils::test_database::*;
+
+    use sqlx::Executor;
+
+    use crate::{complete, CompletionParams};
+
+    #[tokio::test]
+    async fn autocompletes_simple_table() {
+        let test_db = get_new_test_db().await;
+
+        let setup = r#"
+            create table users (
+
id serial primary key, + name text, + password text + ); + "#; + + test_db + .execute(setup) + .await + .expect("Failed to execute setup query"); + + let input = "select * from u"; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let tree = parser.parse(input, None).unwrap(); + let schema_cache = SchemaCache::load(&test_db).await; + + let p = CompletionParams { + position: ((input.len() - 1) as u32).into(), + schema: &schema_cache, + text: input, + tree: Some(&tree), + }; + + let result = complete(p); + + assert!(result.items.len() > 0); + + let best_match = &result.items[0]; + + assert_eq!( + best_match.label, "users", + "Does not return the expected table to autocomplete: {}", + best_match.label + ) + } + + #[tokio::test] + async fn autocompletes_table_alphanumerically() { + let test_db = get_new_test_db().await; + + let setup = r#" + create table addresses ( + id serial primary key + ); + + create table users ( + id serial primary key + ); + + create table emails ( + id serial primary key + ); + "#; + + test_db + .execute(setup) + .await + .expect("Failed to execute setup query"); + + let schema_cache = SchemaCache::load(&test_db).await; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let test_cases = vec![ + ("select * from us", "users"), + ("select * from em", "emails"), + ("select * from ", "addresses"), + ]; + + for (input, expected_label) in test_cases { + let tree = parser.parse(input, None).unwrap(); + + let p = CompletionParams { + position: ((input.len() - 1) as u32).into(), + schema: &schema_cache, + text: input, + tree: Some(&tree), + }; + + let result = complete(p); + + assert!(result.items.len() > 0); + + let best_match = &result.items[0]; + + assert_eq!( + best_match.label, expected_label, + "Does not return the expected table to autocomplete: {}", + 
best_match.label
+            )
+        }
+    }
+
+    #[tokio::test]
+    async fn autocompletes_table_with_schema() {
+        let test_db = get_new_test_db().await;
+
+        let setup = r#"
+            create schema customer_support;
+            create schema private;
+
+            create table private.user_z (
+                id serial primary key,
+                name text,
+                password text
+            );
+
+            create table customer_support.user_y (
+                id serial primary key,
+                request text,
+                send_at timestamp with time zone
+            );
+        "#;
+
+        test_db
+            .execute(setup)
+            .await
+            .expect("Failed to execute setup query");
+
+        let schema_cache = SchemaCache::load(&test_db).await;
+
+        let mut parser = tree_sitter::Parser::new();
+        parser
+            .set_language(tree_sitter_sql::language())
+            .expect("Error loading sql language");
+
+        let test_cases = vec![
+            ("select * from u", "user_y"), // user_y is preferred alphanumerically
+            ("select * from private.u", "user_z"),
+            ("select * from customer_support.u", "user_y"),
+        ];
+
+        for (input, expected_label) in test_cases {
+            let tree = parser.parse(input, None).unwrap();
+
+            let p = CompletionParams {
+                position: ((input.len() - 1) as u32).into(),
+                schema: &schema_cache,
+                text: input,
+                tree: Some(&tree),
+            };
+
+            let result = complete(p);
+
+            assert!(result.items.len() > 0);
+
+            let best_match = &result.items[0];
+
+            assert_eq!(
+                best_match.label, expected_label,
+                "Does not return the expected table to autocomplete: {}",
+                best_match.label
+            )
+        }
+    }
+}
diff --git a/crates/pg_completions/src/context.rs b/crates/pg_completions/src/context.rs
new file mode 100644
index 00000000..92cfc559
--- /dev/null
+++ b/crates/pg_completions/src/context.rs
@@ -0,0 +1,248 @@
+use pg_schema_cache::SchemaCache;
+
+use crate::CompletionParams;
+
+/// What is known about the cursor's surroundings for one completion request.
+pub(crate) struct CompletionContext<'a> {
+    /// Deepest tree-sitter node reached while descending towards `position`.
+    pub ts_node: Option<tree_sitter::Node<'a>>,
+    pub tree: Option<&'a tree_sitter::Tree>,
+    pub text: &'a str,
+    pub schema_cache: &'a SchemaCache,
+    pub position: usize,
+
+    /// Schema qualifier of the object reference at the cursor, e.g. `private` in `private.u`.
+    pub schema_name: Option<String>,
+    /// Kind of the clause the cursor sits in, e.g. `select`, `from` or `where`.
+    pub wrapping_clause_type: Option<String>,
+    /// Whether the path to the cursor passes through an `invocation` node.
+    pub is_invocation: bool,
+}
+
+impl<'a> CompletionContext<'a> {
+    /// Builds the context for `params`; the tree-derived fields stay at their
+    /// defaults when no tree was supplied.
+    pub fn new(params: &'a CompletionParams) -> Self {
+        let mut ctx = Self {
+            tree: params.tree,
+            text: params.text,
+            schema_cache: params.schema,
+            position: usize::from(params.position),
+
+            ts_node: None,
+            schema_name: None,
+            wrapping_clause_type: None,
+            is_invocation: false,
+        };
+
+        ctx.gather_tree_context();
+
+        ctx
+    }
+
+    /// Returns the source text covered by `ts_node`, or `None` if that slice
+    /// cannot be read as UTF-8.
+    pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> {
+        ts_node.utf8_text(self.text.as_bytes()).ok()
+    }
+
+    /// Walks from the root towards `position` and records what it finds.
+    fn gather_tree_context(&mut self) {
+        let Some(tree) = self.tree else {
+            return;
+        };
+
+        let mut cursor = tree.root_node().walk();
+
+        // go to the statement node that matches the position
+        let current_node_kind = cursor.node().kind();
+
+        cursor.goto_first_child_for_byte(self.position);
+
+        self.gather_context_from_node(cursor, current_node_kind);
+    }
+
+    /// Descends one child at a time towards `position`, updating the context
+    /// from every node on the way, and stores the last node reached in `ts_node`.
+    fn gather_context_from_node(
+        &mut self,
+        mut cursor: tree_sitter::TreeCursor<'a>,
+        previous_node_kind: &str,
+    ) {
+        let mut previous_node_kind = previous_node_kind;
+
+        loop {
+            let current_node = cursor.node();
+            let current_node_kind = current_node.kind();
+
+            match previous_node_kind {
+                "statement" => self.wrapping_clause_type = Some(current_node_kind.to_string()),
+                "invocation" => self.is_invocation = true,
+
+                _ => {}
+            }
+
+            match current_node_kind {
+                "object_reference" => {
+                    let txt = self.get_ts_node_content(current_node);
+                    if let Some(txt) = txt {
+                        let parts: Vec<&str> = txt.split('.').collect();
+                        if parts.len() == 2 {
+                            self.schema_name = Some(parts[0].to_string());
+                        }
+                    }
+                }
+
+                // in Treesitter, the Where clause is nested inside other clauses
+                "where" => {
+                    self.wrapping_clause_type = Some("where".to_string());
+                }
+
+                _ => {}
+            }
+
+            // Stop at a leaf, and also when no child reaches `position`: in that
+            // case the cursor stays where it is, so continuing would revisit this
+            // node forever.
+            if current_node.child_count() == 0
+                || cursor.goto_first_child_for_byte(self.position).is_none()
+            {
+                self.ts_node = Some(current_node);
+                return;
+            }
+
+            previous_node_kind = current_node_kind;
+        }
+    }
+}
+
+#[cfg(test)]
+mod
tests {
+    use crate::context::CompletionContext;
+
+    fn get_tree(input: &str) -> tree_sitter::Tree {
+        let mut parser = tree_sitter::Parser::new();
+        parser
+            .set_language(tree_sitter_sql::language())
+            .expect("Couldn't set language");
+
+        parser.parse(input, None).expect("Unable to parse tree")
+    }
+
+    static CURSOR_POS: &str = "XXX";
+
+    #[test]
+    fn identifies_clauses() {
+        let test_cases = vec![
+            (format!("Select {}* from users;", CURSOR_POS), "select"),
+            (format!("Select * from u{};", CURSOR_POS), "from"),
+            (
+                format!("Select {}* from users where n = 1;", CURSOR_POS),
+                "select",
+            ),
+            (
+                format!("Select * from users where {}n = 1;", CURSOR_POS),
+                "where",
+            ),
+            (
+                format!("update users set u{} = 1 where n = 2;", CURSOR_POS),
+                "update",
+            ),
+            (
+                format!("update users set u = 1 where n{} = 2;", CURSOR_POS),
+                "where",
+            ),
+            (format!("delete{} from users;", CURSOR_POS), "delete"),
+            (format!("delete from {}users;", CURSOR_POS), "from"),
+            (
+                format!("select name, age, location from public.u{}sers", CURSOR_POS),
+                "from",
+            ),
+        ];
+
+        for (text, expected_clause) in test_cases {
+            let position = text.find(CURSOR_POS).unwrap();
+            let text = text.replace(CURSOR_POS, "");
+
+            let tree = get_tree(text.as_str());
+            let params = crate::CompletionParams {
+                position: (position as u32).into(),
+                text: text.as_str(),
+                tree: Some(&tree),
+                schema: &pg_schema_cache::SchemaCache::new(),
+            };
+
+            let ctx = CompletionContext::new(&params);
+
+            assert_eq!(ctx.wrapping_clause_type, Some(expected_clause.to_string()));
+        }
+    }
+
+    #[test]
+    fn identifies_schema() {
+        let test_cases = vec![
+            (
+                format!("Select * from private.u{}", CURSOR_POS),
+                Some("private"),
+            ),
+            (
+                format!("Select * from private.u{}sers()", CURSOR_POS),
+                Some("private"),
+            ),
+            (format!("Select * from u{}sers", CURSOR_POS), None),
+            (format!("Select * from u{}sers()", CURSOR_POS), None),
+        ];
+
+        for (text, expected_schema) in test_cases {
+            let position = text.find(CURSOR_POS).unwrap();
+
let text = text.replace(CURSOR_POS, "");
+
+            let tree = get_tree(text.as_str());
+            let params = crate::CompletionParams {
+                position: (position as u32).into(),
+                text: text.as_str(),
+                tree: Some(&tree),
+                schema: &pg_schema_cache::SchemaCache::new(),
+            };
+
+            let ctx = CompletionContext::new(&params);
+
+            assert_eq!(ctx.schema_name, expected_schema.map(|f| f.to_string()));
+        }
+    }
+
+    #[test]
+    fn identifies_invocation() {
+        let test_cases = vec![
+            (format!("Select * from u{}sers", CURSOR_POS), false),
+            (format!("Select * from u{}sers()", CURSOR_POS), true),
+            (format!("Select cool{};", CURSOR_POS), false),
+            (format!("Select cool{}();", CURSOR_POS), true),
+            (
+                format!("Select upp{}ercase as title from users;", CURSOR_POS),
+                false,
+            ),
+            (
+                format!("Select upp{}ercase(name) as title from users;", CURSOR_POS),
+                true,
+            ),
+        ];
+
+        for (text, is_invocation) in test_cases {
+            let position = text.find(CURSOR_POS).unwrap();
+            let text = text.replace(CURSOR_POS, "");
+
+            let tree = get_tree(text.as_str());
+            let params = crate::CompletionParams {
+                position: (position as u32).into(),
+                text: text.as_str(),
+                tree: Some(&tree),
+                schema: &pg_schema_cache::SchemaCache::new(),
+            };
+
+            let ctx = CompletionContext::new(&params);
+
+            assert_eq!(ctx.is_invocation, is_invocation);
+        }
+    }
+}
diff --git a/crates/pg_completions/src/item.rs b/crates/pg_completions/src/item.rs
new file mode 100644
index 00000000..7a015e72
--- /dev/null
+++ b/crates/pg_completions/src/item.rs
@@ -0,0 +1,18 @@
+/// The kind of database object a completion item refers to.
+#[derive(Debug)]
+pub enum CompletionItemKind {
+    Table,
+}
+
+/// A single completion suggestion.
+#[derive(Debug)]
+pub struct CompletionItem {
+    /// Text identifying the suggestion, e.g. a table name.
+    pub label: String,
+    /// Relevance used for ranking; higher is better.
+    pub(crate) score: i32,
+    pub description: String,
+    /// Set on the top-ranked item only: whether it should be selected by default.
+    pub preselected: Option<bool>,
+    pub kind: CompletionItemKind,
+}
diff --git a/crates/pg_completions/src/lib.rs b/crates/pg_completions/src/lib.rs
index 382e1fc4..c31e9337 100644
--- a/crates/pg_completions/src/lib.rs
+++ b/crates/pg_completions/src/lib.rs
@@ -1,165 +1,9 @@
 mod builder;
+mod complete;
+mod context;
+mod item; mod providers; +mod relevance; -pub use providers::CompletionProviderParams; -use text_size::{TextRange, TextSize}; - -pub const LIMIT: usize = 50; - -#[derive(Debug)] -pub struct CompletionParams<'a> { - pub position: TextSize, - pub schema: &'a pg_schema_cache::SchemaCache, - pub text: &'a str, - pub tree: Option<&'a tree_sitter::Tree>, -} - -#[derive(Debug, Default)] -pub struct CompletionResult<'a> { - pub items: Vec>, -} - -#[derive(Debug, PartialEq, Eq)] -pub struct CompletionItem<'a> { - pub score: i32, - pub range: TextRange, - pub preselect: bool, - pub data: CompletionItemData<'a>, -} - -#[derive(Debug, PartialEq, Eq)] -pub enum CompletionItemData<'a> { - Table(&'a pg_schema_cache::Table), -} - -impl<'a> CompletionItemData<'a> { - pub fn label(&self) -> &'a str { - match self { - CompletionItemData::Table(t) => t.name.as_str(), - } - } -} - -impl<'a> CompletionItem<'a> { - pub fn new_simple(score: i32, range: TextRange, data: CompletionItemData<'a>) -> Self { - Self { - score, - range, - preselect: false, - data, - } - } -} - -pub fn complete<'a>(params: &'a CompletionParams<'a>) -> CompletionResult<'a> { - let mut builder = builder::CompletionBuilder::from(&builder::CompletionConfig {}); - - let params = CompletionProviderParams::from(params); - - providers::complete_tables(params, &mut builder); - - builder.finish() -} - -#[cfg(test)] -mod tests { - use pg_schema_cache::SchemaCache; - use pg_test_utils::test_database::*; - - use sqlx::Executor; - - use crate::{complete, CompletionParams}; - - #[tokio::test] - async fn test_complete() { - let pool = get_new_test_db().await; - - let input = "select id from c;"; - - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let tree = parser.parse(input, None).unwrap(); - - let schema_cache = SchemaCache::load(&pool).await; - - let p = CompletionParams { - position: 15.into(), - schema: &schema_cache, - text: 
input, - tree: Some(&tree), - }; - - let result = complete(&p); - - assert!(result.items.len() > 0); - } - - #[tokio::test] - async fn test_complete_two() { - let pool = get_new_test_db().await; - - let input = "select id, name, test1231234123, unknown from co;"; - - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let tree = parser.parse(input, None).unwrap(); - let schema_cache = SchemaCache::load(&pool).await; - - let p = CompletionParams { - position: 47.into(), - schema: &schema_cache, - text: input, - tree: Some(&tree), - }; - - let result = complete(&p); - - assert!(result.items.len() > 0); - } - - #[tokio::test] - async fn test_complete_three() { - let test_db = get_new_test_db().await; - - let setup = r#" - create table users ( - id serial primary key, - name text, - password text - ); - "#; - - test_db - .execute(setup) - .await - .expect("Failed to execute setup query"); - - let input = "select * from u"; - - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let tree = parser.parse(input, None).unwrap(); - let schema_cache = SchemaCache::load(&test_db).await; - - let p = CompletionParams { - position: ((input.len() - 1) as u32).into(), - schema: &schema_cache, - text: input, - tree: Some(&tree), - }; - - let result = complete(&p); - - // TODO: actually assert that we get good autocompletion suggestions - assert!(result.items.len() > 0); - } -} +pub use complete::*; +pub use item::*; diff --git a/crates/pg_completions/src/providers.rs b/crates/pg_completions/src/providers.rs deleted file mode 100644 index 355e0ea0..00000000 --- a/crates/pg_completions/src/providers.rs +++ /dev/null @@ -1,47 +0,0 @@ -mod tables; - -pub use tables::complete_tables; - -use crate::CompletionParams; - -#[derive(Debug)] -pub struct CompletionProviderParams<'a> { - pub ts_node: Option>, - pub 
schema: &'a pg_schema_cache::SchemaCache, - pub source: &'a str, -} - -impl<'a> From<&'a CompletionParams<'a>> for CompletionProviderParams<'a> { - fn from(params: &'a CompletionParams) -> Self { - let ts_node = if let Some(tree) = params.tree { - let node = tree.root_node().named_descendant_for_byte_range( - usize::from(params.position), - usize::from(params.position), - ); - - if let Some(mut n) = node { - let node_range = n.range(); - - while let Some(parent) = n.parent() { - if parent.range() != node_range { - break; - } - - n = parent; - } - - Some(n) - } else { - None - } - } else { - None - }; - - Self { - ts_node, - schema: params.schema, - source: params.text, - } - } -} diff --git a/crates/pg_completions/src/providers/mod.rs b/crates/pg_completions/src/providers/mod.rs new file mode 100644 index 00000000..81043e5f --- /dev/null +++ b/crates/pg_completions/src/providers/mod.rs @@ -0,0 +1,3 @@ +mod tables; + +pub use tables::*; diff --git a/crates/pg_completions/src/providers/tables.rs b/crates/pg_completions/src/providers/tables.rs index 74a52c7e..ea78deef 100644 --- a/crates/pg_completions/src/providers/tables.rs +++ b/crates/pg_completions/src/providers/tables.rs @@ -1,31 +1,37 @@ -use text_size::{TextRange, TextSize}; +use pg_schema_cache::Table; -use crate::{builder::CompletionBuilder, CompletionItem, CompletionItemData}; +use crate::{ + builder::CompletionBuilder, + context::CompletionContext, + item::{CompletionItem, CompletionItemKind}, + relevance::CompletionRelevance, +}; -use super::CompletionProviderParams; +pub fn complete_tables(ctx: &CompletionContext, builder: &mut CompletionBuilder) { + let available_tables = &ctx.schema_cache.tables; -// todo unify this in a type resolver crate -pub fn complete_tables<'a>( - params: CompletionProviderParams<'a>, - builder: &mut CompletionBuilder<'a>, -) { - if let Some(ts) = params.ts_node { - let range = TextRange::new( - TextSize::try_from(ts.start_byte()).unwrap(), - 
TextSize::try_from(ts.end_byte()).unwrap(),
-        );
-        match ts.kind() {
-            "relation" => {
-                // todo better search
-                params.schema.tables.iter().for_each(|table| {
-                    builder.items.push(CompletionItem::new_simple(
-                        1,
-                        range,
-                        CompletionItemData::Table(table),
-                    ));
-                });
-            }
-            _ => {}
-        }
+    let completion_items: Vec<CompletionItem> = available_tables
+        .iter()
+        .map(|table| CompletionItem {
+            label: table.name.clone(),
+            score: get_score(ctx, table),
+            description: format!("Schema: {}", table.schema),
+            preselected: None,
+            kind: CompletionItemKind::Table,
+        })
+        .collect();
+
+    for item in completion_items {
+        builder.add_item(item);
     }
 }
+
+fn get_score(ctx: &CompletionContext, table: &Table) -> i32 {
+    let mut relevance = CompletionRelevance::default();
+
+    relevance.check_matches_query_input(ctx, &table.name);
+    relevance.check_matches_schema(ctx, &table.schema);
+    relevance.check_if_catalog(ctx);
+
+    relevance.score()
+}
diff --git a/crates/pg_completions/src/relevance.rs b/crates/pg_completions/src/relevance.rs
new file mode 100644
index 00000000..ddf52ae4
--- /dev/null
+++ b/crates/pg_completions/src/relevance.rs
@@ -0,0 +1,60 @@
+use crate::context::CompletionContext;
+
+/// Accumulates the relevance score of one completion candidate.
+#[derive(Debug, Default)]
+pub(crate) struct CompletionRelevance {
+    score: i32,
+}
+
+impl CompletionRelevance {
+    /// The score accumulated so far; higher means more relevant.
+    pub fn score(&self) -> i32 {
+        self.score
+    }
+
+    /// Rewards candidates whose `name` starts with the text under the cursor:
+    /// five points per byte of that text.
+    pub fn check_matches_query_input(&mut self, ctx: &CompletionContext, name: &str) {
+        // Without a tree there is no node at the cursor (`ts_node` is `None`),
+        // so there is nothing to compare against – skip instead of panicking.
+        let Some(node) = ctx.ts_node else {
+            return;
+        };
+
+        let Some(content) = ctx.get_ts_node_content(node) else {
+            return;
+        };
+
+        if name.starts_with(content) {
+            let len: i32 = content
+                .len()
+                .try_into()
+                .expect("The length of the input exceeds i32 capacity");
+
+            self.score += len * 5;
+        }
+    }
+
+    /// Adds 25 points when the schema typed by the user equals `schema` and
+    /// subtracts 10 when it differs; does nothing if no schema was typed.
+    pub fn check_matches_schema(&mut self, ctx: &CompletionContext, schema: &str) {
+        let Some(name) = ctx.schema_name.as_deref() else {
+            return;
+        };
+
+        if name == schema {
+            self.score += 25;
+        } else {
+            self.score -= 10;
+        }
+    }
+
+    /// Subtracts 5 points unless the user explicitly typed the `pg_catalog` schema.
+    pub fn check_if_catalog(&mut self, ctx: &CompletionContext) {
+        if ctx.schema_name.as_deref() == Some("pg_catalog") {
+            return;
+        }
+
+        self.score -= 5; // unlikely that the user wants schema data
+    }
+}
diff --git a/crates/pg_lsp/Cargo.toml b/crates/pg_lsp/Cargo.toml
index 69560b55..cfe2877c 100644
--- a/crates/pg_lsp/Cargo.toml
+++ b/crates/pg_lsp/Cargo.toml
@@ -23,6 +23,7 @@ text-size = "1.1.1"
 
 line_index.workspace = true
 sqlx.workspace = true
+tower-lsp.workspace = true
 
 pg_hover.workspace = true
 pg_fs.workspace = true
@@ -35,7 +36,6 @@ pg_workspace.workspace = true
 pg_diagnostics.workspace = true
 tokio = { version = "1.40.0", features = ["io-std", "macros", "rt-multi-thread", "sync", "time"] }
 tokio-util = "0.7.12"
-tower-lsp = "0.20.0"
 tracing = "0.1.40"
 tracing-subscriber = "0.3.18"
 
diff --git a/crates/pg_lsp/src/session.rs b/crates/pg_lsp/src/session.rs
index 9d7ef641..215e54b1 100644
--- a/crates/pg_lsp/src/session.rs
+++ b/crates/pg_lsp/src/session.rs
@@ -10,11 +10,14 @@ use pg_workspace::Workspace;
 use text_size::TextSize;
 use tokio::sync::RwLock;
 use tower_lsp::lsp_types::{
-    CodeActionOrCommand, CompletionItem, CompletionItemKind, CompletionList, Hover, HoverContents,
-    InlayHint, InlayHintKind, InlayHintLabel, MarkedString, Position, Range,
+    CodeActionOrCommand, CompletionItem, CompletionList, Hover, HoverContents, InlayHint,
+    InlayHintKind, InlayHintLabel, MarkedString, Position, Range,
 };
 
-use crate::{db_connection::DbConnection, utils::line_index_ext::LineIndexExt};
+use crate::{
+    db_connection::DbConnection,
+    utils::{line_index_ext::LineIndexExt, to_lsp_types::to_completion_kind},
+};
 
 pub struct Session {
     db: RwLock>,
@@ -235,19 +238,24 @@ impl Session {
 
         let schema_cache = ide.schema_cache.read().expect("No Schema Cache");
 
-        let completion_items = pg_completions::complete(&CompletionParams {
+        let completion_items: Vec<CompletionItem> = pg_completions::complete(CompletionParams {
             position: offset - range.start() - TextSize::from(1),
-            text:
stmt.text.as_str(),
-            tree: ide.tree_sitter.tree(&stmt).as_ref().map(|x| x.as_ref()),
+            text: &stmt.text,
+            tree: ide
+                .tree_sitter
+                .tree(&stmt)
+                .as_ref()
+                .map(|t| t.as_ref()),
             schema: &schema_cache,
         })
-        .items
         .into_iter()
-        .map(|i| CompletionItem {
-            // TODO: add more data
-            label: i.data.label().to_string(),
-            label_details: None,
-            kind: Some(CompletionItemKind::CLASS),
+        .map(|item| CompletionItem {
+            label: item.label,
+            label_details: Some(tower_lsp::lsp_types::CompletionItemLabelDetails {
+                description: Some(item.description),
+                detail: None,
+            }),
+            kind: Some(to_completion_kind(item.kind)),
             detail: None,
             documentation: None,
             deprecated: None,
diff --git a/crates/pg_lsp/src/utils.rs b/crates/pg_lsp/src/utils.rs
index bfbd57d1..50e9edd8 100644
--- a/crates/pg_lsp/src/utils.rs
+++ b/crates/pg_lsp/src/utils.rs
@@ -1,4 +1,5 @@
 pub mod line_index_ext;
+pub mod to_lsp_types;
 pub mod to_proto;
 
 use std::path::PathBuf;
diff --git a/crates/pg_lsp/src/utils/to_lsp_types.rs b/crates/pg_lsp/src/utils/to_lsp_types.rs
new file mode 100644
index 00000000..d090386b
--- /dev/null
+++ b/crates/pg_lsp/src/utils/to_lsp_types.rs
@@ -0,0 +1,10 @@
+use tower_lsp::lsp_types;
+
+/// Maps a `pg_completions` item kind to the corresponding LSP completion item kind.
+pub fn to_completion_kind(
+    kind: pg_completions::CompletionItemKind,
+) -> lsp_types::CompletionItemKind {
+    match kind {
+        pg_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS,
+    }
+}
diff --git a/crates/pg_test_utils/Cargo.toml b/crates/pg_test_utils/Cargo.toml
index a61688ee..ce4eb139 100644
--- a/crates/pg_test_utils/Cargo.toml
+++ b/crates/pg_test_utils/Cargo.toml
@@ -1,3 +1,7 @@
+[[bin]]
+name = "tree_print"
+path = "src/bin/tree_print.rs"
+
 [package]
 name = "pg_test_utils"
 version = "0.0.0"
@@ -6,6 +10,10 @@ edition = "2021"
 [dependencies]
 anyhow = "1.0.81"
 uuid = { version = "1.11.0", features = ["v4"] }
+dotenv = "0.15.0"
+clap = { version = "4.5.23", features = ["derive"] }
 
 sqlx.workspace = true
-dotenv = "0.15.0"
+tree-sitter.workspace = true
+tree_sitter_sql.workspace = true
+
diff --git a/crates/pg_test_utils/src/bin/tree_print.rs b/crates/pg_test_utils/src/bin/tree_print.rs
new file mode 100644
index 00000000..8a04365e
--- /dev/null
+++ b/crates/pg_test_utils/src/bin/tree_print.rs
@@ -0,0 +1,52 @@
+use clap::*;
+
+// Command-line arguments. Plain `//` comments on purpose: clap's derive turns
+// `///` doc comments into help text, which would change the `--help` output.
+#[derive(Parser)]
+#[command(
+    name = "tree-printer",
+    about = "Prints the TreeSitter tree of the given file."
+)]
+struct Args {
+    // Path of the file whose contents are parsed as SQL.
+    #[arg(long = "file", short = 'f')]
+    file: String,
+}
+
+/// Parses the given file with the tree-sitter SQL grammar and prints the tree.
+fn main() {
+    let args = Args::parse();
+
+    let query = std::fs::read_to_string(&args.file).expect("Failed to read file.");
+
+    let mut parser = tree_sitter::Parser::new();
+    let lang = tree_sitter_sql::language();
+
+    parser.set_language(lang).expect("Setting Language failed.");
+
+    // The parser only reads the source, so lend it out instead of cloning it.
+    let tree = parser.parse(&query, None).expect("Failed to parse query.");
+
+    print_tree(&tree.root_node(), &query, 0);
+}
+
+/// Prints `node` and all of its descendants, one node per line, indented by
+/// `level`: the node kind, its start and end column, and its source text.
+fn print_tree(node: &tree_sitter::Node, source: &str, level: usize) {
+    let indent = " ".repeat(level);
+    let node_text = node.utf8_text(source.as_bytes()).unwrap_or("NO_NAME");
+
+    println!(
+        "{}{} [{}..{}] '{}'",
+        indent,
+        node.kind(),
+        node.start_position().column,
+        node.end_position().column,
+        node_text
+    );
+
+    let mut cursor = node.walk();
+    for child in node.children(&mut cursor) {
+        print_tree(&child, source, level + 1);
+    }
+}