Skip to content

Commit 7f61a05

Browse files
add query
1 parent 26e82c7 commit 7f61a05

File tree

4 files changed

+356
-1
lines changed

4 files changed

+356
-1
lines changed

crates/pgt_completions/src/context.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,14 @@ impl<'a> CompletionContext<'a> {
164164
}
165165
};
166166
}
167-
168167
QueryResult::TableAliases(table_alias_match) => {
169168
self.mentioned_table_aliases.insert(
170169
table_alias_match.get_alias(sql),
171170
table_alias_match.get_table(sql),
172171
);
173172
}
173+
174+
QueryResult::Column(_) => todo!(),
174175
};
175176
}
176177
}
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
use crate::{
2+
CompletionItemKind,
3+
builder::{CompletionBuilder, PossibleCompletionItem},
4+
context::CompletionContext,
5+
relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore},
6+
};
7+
8+
use super::helper::get_completion_text_with_schema_or_alias;
9+
10+
pub fn complete_functions<'a>(ctx: &'a CompletionContext, builder: &mut CompletionBuilder<'a>) {
11+
let available_functions = &ctx.schema_cache.functions;
12+
13+
for func in available_functions {
14+
let relevance = CompletionRelevanceData::Function(func);
15+
16+
let item = PossibleCompletionItem {
17+
label: func.name.clone(),
18+
score: CompletionScore::from(relevance.clone()),
19+
filter: CompletionFilter::from(relevance),
20+
description: format!("Schema: {}", func.schema),
21+
kind: CompletionItemKind::Function,
22+
completion_text: get_completion_text_with_schema_or_alias(
23+
ctx,
24+
&func.name,
25+
&func.schema,
26+
),
27+
};
28+
29+
builder.add_item(item);
30+
}
31+
}
32+
33+
#[cfg(test)]
34+
mod tests {
35+
use crate::{
36+
CompletionItem, CompletionItemKind, complete,
37+
test_helper::{CURSOR_POS, get_test_deps, get_test_params},
38+
};
39+
40+
#[tokio::test]
41+
async fn completes_fn() {
42+
let setup = r#"
43+
create or replace function cool()
44+
returns trigger
45+
language plpgsql
46+
security invoker
47+
as $$
48+
begin
49+
raise exception 'dont matter';
50+
end;
51+
$$;
52+
"#;
53+
54+
let query = format!("select coo{}", CURSOR_POS);
55+
56+
let (tree, cache) = get_test_deps(setup, query.as_str().into()).await;
57+
let params = get_test_params(&tree, &cache, query.as_str().into());
58+
let results = complete(params);
59+
60+
let CompletionItem { label, .. } = results
61+
.into_iter()
62+
.next()
63+
.expect("Should return at least one completion item");
64+
65+
assert_eq!(label, "cool");
66+
}
67+
68+
#[tokio::test]
69+
async fn prefers_fn_if_invocation() {
70+
let setup = r#"
71+
create table coos (
72+
id serial primary key,
73+
name text
74+
);
75+
76+
create or replace function cool()
77+
returns trigger
78+
language plpgsql
79+
security invoker
80+
as $$
81+
begin
82+
raise exception 'dont matter';
83+
end;
84+
$$;
85+
"#;
86+
87+
let query = format!(r#"select * from coo{}()"#, CURSOR_POS);
88+
89+
let (tree, cache) = get_test_deps(setup, query.as_str().into()).await;
90+
let params = get_test_params(&tree, &cache, query.as_str().into());
91+
let results = complete(params);
92+
93+
let CompletionItem { label, kind, .. } = results
94+
.into_iter()
95+
.next()
96+
.expect("Should return at least one completion item");
97+
98+
assert_eq!(label, "cool");
99+
assert_eq!(kind, CompletionItemKind::Function);
100+
}
101+
102+
#[tokio::test]
103+
async fn prefers_fn_in_select_clause() {
104+
let setup = r#"
105+
create table coos (
106+
id serial primary key,
107+
name text
108+
);
109+
110+
create or replace function cool()
111+
returns trigger
112+
language plpgsql
113+
security invoker
114+
as $$
115+
begin
116+
raise exception 'dont matter';
117+
end;
118+
$$;
119+
"#;
120+
121+
let query = format!(r#"select coo{}"#, CURSOR_POS);
122+
123+
let (tree, cache) = get_test_deps(setup, query.as_str().into()).await;
124+
let params = get_test_params(&tree, &cache, query.as_str().into());
125+
let results = complete(params);
126+
127+
let CompletionItem { label, kind, .. } = results
128+
.into_iter()
129+
.next()
130+
.expect("Should return at least one completion item");
131+
132+
assert_eq!(label, "cool");
133+
assert_eq!(kind, CompletionItemKind::Function);
134+
}
135+
136+
#[tokio::test]
137+
async fn prefers_function_in_from_clause_if_invocation() {
138+
let setup = r#"
139+
create table coos (
140+
id serial primary key,
141+
name text
142+
);
143+
144+
create or replace function cool()
145+
returns trigger
146+
language plpgsql
147+
security invoker
148+
as $$
149+
begin
150+
raise exception 'dont matter';
151+
end;
152+
$$;
153+
"#;
154+
155+
let query = format!(r#"select * from coo{}()"#, CURSOR_POS);
156+
157+
let (tree, cache) = get_test_deps(setup, query.as_str().into()).await;
158+
let params = get_test_params(&tree, &cache, query.as_str().into());
159+
let results = complete(params);
160+
161+
let CompletionItem { label, kind, .. } = results
162+
.into_iter()
163+
.next()
164+
.expect("Should return at least one completion item");
165+
166+
assert_eq!(label, "cool");
167+
assert_eq!(kind, CompletionItemKind::Function);
168+
}
169+
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
use std::sync::LazyLock;
2+
3+
use crate::{Query, QueryResult};
4+
5+
use super::QueryTryFrom;
6+
7+
static TS_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
8+
static QUERY_STR: &str = r#"
9+
(select_expression
10+
(term
11+
(field
12+
(object_reference)? @alias
13+
"."?
14+
(identifier) @column
15+
)
16+
)
17+
","?
18+
)
19+
"#;
20+
tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query")
21+
});
22+
23+
#[derive(Debug)]
24+
pub struct ColumnMatch<'a> {
25+
pub(crate) alias: Option<tree_sitter::Node<'a>>,
26+
pub(crate) column: tree_sitter::Node<'a>,
27+
}
28+
29+
impl ColumnMatch<'_> {
30+
pub fn get_alias(&self, sql: &str) -> Option<String> {
31+
let str = self
32+
.alias
33+
.as_ref()?
34+
.utf8_text(sql.as_bytes())
35+
.expect("Failed to get alias from ColumnMatch");
36+
37+
Some(str.to_string())
38+
}
39+
40+
pub fn get_column(&self, sql: &str) -> String {
41+
self.column
42+
.utf8_text(sql.as_bytes())
43+
.expect("Failed to get column from ColumnMatch")
44+
.to_string()
45+
}
46+
}
47+
48+
impl<'a> TryFrom<&'a QueryResult<'a>> for &'a ColumnMatch<'a> {
49+
type Error = String;
50+
51+
fn try_from(q: &'a QueryResult<'a>) -> Result<Self, Self::Error> {
52+
match q {
53+
QueryResult::Column(c) => Ok(c),
54+
55+
#[allow(unreachable_patterns)]
56+
_ => Err("Invalid QueryResult type".into()),
57+
}
58+
}
59+
}
60+
61+
impl<'a> QueryTryFrom<'a> for ColumnMatch<'a> {
62+
type Ref = &'a ColumnMatch<'a>;
63+
}
64+
65+
impl<'a> Query<'a> for ColumnMatch<'a> {
66+
fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec<crate::QueryResult<'a>> {
67+
let mut cursor = tree_sitter::QueryCursor::new();
68+
69+
let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes());
70+
71+
let mut to_return = vec![];
72+
73+
for m in matches {
74+
if m.captures.len() == 1 {
75+
let capture = m.captures[0].node;
76+
to_return.push(QueryResult::Column(ColumnMatch {
77+
alias: None,
78+
column: capture,
79+
}));
80+
}
81+
82+
if m.captures.len() == 2 {
83+
let alias = m.captures[0].node;
84+
let column = m.captures[1].node;
85+
86+
to_return.push(QueryResult::Column(ColumnMatch {
87+
alias: Some(alias),
88+
column,
89+
}));
90+
}
91+
}
92+
93+
to_return
94+
}
95+
}
96+
97+
#[cfg(test)]
98+
mod tests {
99+
use crate::TreeSitterQueriesExecutor;
100+
101+
use super::ColumnMatch;
102+
103+
#[test]
104+
fn finds_all_columns() {
105+
let sql = r#"select aud, id, email from auth.users;"#;
106+
107+
let mut parser = tree_sitter::Parser::new();
108+
parser.set_language(tree_sitter_sql::language()).unwrap();
109+
110+
let tree = parser.parse(sql, None).unwrap();
111+
112+
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
113+
114+
executor.add_query_results::<ColumnMatch>();
115+
116+
let results: Vec<&ColumnMatch> = executor
117+
.get_iter(None)
118+
.filter_map(|q| q.try_into().ok())
119+
.collect();
120+
121+
assert_eq!(results[0].get_alias(sql), None);
122+
assert_eq!(results[0].get_column(sql), "aud");
123+
124+
assert_eq!(results[1].get_alias(sql), None);
125+
assert_eq!(results[1].get_column(sql), "id");
126+
127+
assert_eq!(results[2].get_alias(sql), None);
128+
assert_eq!(results[2].get_column(sql), "email");
129+
}
130+
131+
#[test]
132+
fn finds_columns_with_aliases() {
133+
let sql = r#"
134+
select
135+
u.id,
136+
u.email,
137+
cs.user_settings,
138+
cs.client_id
139+
from
140+
auth.users u
141+
join public.client_settings cs
142+
on u.id = cs.user_id;
143+
144+
"#;
145+
146+
let mut parser = tree_sitter::Parser::new();
147+
parser.set_language(tree_sitter_sql::language()).unwrap();
148+
149+
let tree = parser.parse(sql, None).unwrap();
150+
151+
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
152+
153+
executor.add_query_results::<ColumnMatch>();
154+
155+
let results: Vec<&ColumnMatch> = executor
156+
.get_iter(None)
157+
.filter_map(|q| q.try_into().ok())
158+
.collect();
159+
160+
assert_eq!(results[0].get_alias(sql), Some("u".into()));
161+
assert_eq!(results[0].get_column(sql), "id");
162+
163+
assert_eq!(results[1].get_alias(sql), Some("u".into()));
164+
assert_eq!(results[1].get_column(sql), "email");
165+
166+
assert_eq!(results[2].get_alias(sql), Some("cs".into()));
167+
assert_eq!(results[2].get_column(sql), "user_settings");
168+
169+
assert_eq!(results[3].get_alias(sql), Some("cs".into()));
170+
assert_eq!(results[3].get_column(sql), "client_id");
171+
}
172+
}

crates/pgt_treesitter_queries/src/queries/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
mod columns;
12
mod relations;
23
mod table_aliases;
34

5+
pub use columns::*;
46
pub use relations::*;
57
pub use table_aliases::*;
68

79
#[derive(Debug)]
810
pub enum QueryResult<'a> {
911
Relation(RelationMatch<'a>),
1012
TableAliases(TableAliasMatch<'a>),
13+
Column(ColumnMatch<'a>),
1114
}
1215

1316
impl QueryResult<'_> {
@@ -28,6 +31,16 @@ impl QueryResult<'_> {
2831
let end = m.alias.end_position();
2932
start >= range.start_point && end <= range.end_point
3033
}
34+
Self::Column(cm) => {
35+
let start = match cm.alias {
36+
Some(n) => n.start_position(),
37+
None => cm.column.start_position(),
38+
};
39+
40+
let end = cm.column.end_position();
41+
42+
start >= range.start_point && end <= range.end_point
43+
}
3144
}
3245
}
3346
}

0 commit comments

Comments
 (0)