Skip to content

Commit 2ada420

Browse files
feat(completions): lower priority of already mentioned columns in SELECT (#399)
1 parent 26e82c7 commit 2ada420

File tree

7 files changed

+531
-16
lines changed

7 files changed

+531
-16
lines changed

crates/pgt_completions/src/context.rs

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use pgt_treesitter_queries::{
88

99
use crate::sanitization::SanitizedCompletionParams;
1010

11-
#[derive(Debug, PartialEq, Eq)]
11+
#[derive(Debug, PartialEq, Eq, Hash)]
1212
pub enum WrappingClause<'a> {
1313
Select,
1414
Where,
@@ -26,6 +26,12 @@ pub(crate) enum NodeText<'a> {
2626
Original(&'a str),
2727
}
2828

29+
#[derive(PartialEq, Eq, Hash, Debug)]
30+
pub(crate) struct MentionedColumn {
31+
pub(crate) column: String,
32+
pub(crate) alias: Option<String>,
33+
}
34+
2935
/// We can map a few nodes, such as the "update" node, to actual SQL clauses.
3036
/// That gives us a lot of insight for completions.
3137
/// Other nodes, such as the "relation" node, gives us less but still
@@ -108,8 +114,8 @@ pub(crate) struct CompletionContext<'a> {
108114
pub is_in_error_node: bool,
109115

110116
pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
111-
112117
pub mentioned_table_aliases: HashMap<String, String>,
118+
pub mentioned_columns: HashMap<Option<WrappingClause<'a>>, HashSet<MentionedColumn>>,
113119
}
114120

115121
impl<'a> CompletionContext<'a> {
@@ -127,6 +133,7 @@ impl<'a> CompletionContext<'a> {
127133
is_invocation: false,
128134
mentioned_relations: HashMap::new(),
129135
mentioned_table_aliases: HashMap::new(),
136+
mentioned_columns: HashMap::new(),
130137
is_in_error_node: false,
131138
};
132139

@@ -144,33 +151,46 @@ impl<'a> CompletionContext<'a> {
144151

145152
executor.add_query_results::<queries::RelationMatch>();
146153
executor.add_query_results::<queries::TableAliasMatch>();
154+
executor.add_query_results::<queries::SelectColumnMatch>();
147155

148156
for relation_match in executor.get_iter(stmt_range) {
149157
match relation_match {
150158
QueryResult::Relation(r) => {
151159
let schema_name = r.get_schema(sql);
152160
let table_name = r.get_table(sql);
153161

154-
let current = self.mentioned_relations.get_mut(&schema_name);
155-
156-
match current {
157-
Some(c) => {
158-
c.insert(table_name);
159-
}
160-
None => {
161-
let mut new = HashSet::new();
162-
new.insert(table_name);
163-
self.mentioned_relations.insert(schema_name, new);
164-
}
165-
};
162+
if let Some(c) = self.mentioned_relations.get_mut(&schema_name) {
163+
c.insert(table_name);
164+
} else {
165+
let mut new = HashSet::new();
166+
new.insert(table_name);
167+
self.mentioned_relations.insert(schema_name, new);
168+
}
166169
}
167-
168170
QueryResult::TableAliases(table_alias_match) => {
169171
self.mentioned_table_aliases.insert(
170172
table_alias_match.get_alias(sql),
171173
table_alias_match.get_table(sql),
172174
);
173175
}
176+
QueryResult::SelectClauseColumns(c) => {
177+
let mentioned = MentionedColumn {
178+
column: c.get_column(sql),
179+
alias: c.get_alias(sql),
180+
};
181+
182+
if let Some(cols) = self
183+
.mentioned_columns
184+
.get_mut(&Some(WrappingClause::Select))
185+
{
186+
cols.insert(mentioned);
187+
} else {
188+
let mut new = HashSet::new();
189+
new.insert(mentioned);
190+
self.mentioned_columns
191+
.insert(Some(WrappingClause::Select), new);
192+
}
193+
}
174194
};
175195
}
176196
}

crates/pgt_completions/src/providers/columns.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,4 +484,93 @@ mod tests {
484484
)
485485
.await;
486486
}
487+
488+
#[tokio::test]
489+
async fn prefers_not_mentioned_columns() {
490+
let setup = r#"
491+
create schema auth;
492+
493+
create table public.one (
494+
id serial primary key,
495+
a text,
496+
b text,
497+
z text
498+
);
499+
500+
create table public.two (
501+
id serial primary key,
502+
c text,
503+
d text,
504+
e text
505+
);
506+
"#;
507+
508+
assert_complete_results(
509+
format!(
510+
"select {} from public.one o join public.two on o.id = t.id;",
511+
CURSOR_POS
512+
)
513+
.as_str(),
514+
vec![
515+
CompletionAssertion::Label("a".to_string()),
516+
CompletionAssertion::Label("b".to_string()),
517+
CompletionAssertion::Label("c".to_string()),
518+
CompletionAssertion::Label("d".to_string()),
519+
CompletionAssertion::Label("e".to_string()),
520+
],
521+
setup,
522+
)
523+
.await;
524+
525+
// "a" is already mentioned, so it jumps down
526+
assert_complete_results(
527+
format!(
528+
"select a, {} from public.one o join public.two on o.id = t.id;",
529+
CURSOR_POS
530+
)
531+
.as_str(),
532+
vec![
533+
CompletionAssertion::Label("b".to_string()),
534+
CompletionAssertion::Label("c".to_string()),
535+
CompletionAssertion::Label("d".to_string()),
536+
CompletionAssertion::Label("e".to_string()),
537+
CompletionAssertion::Label("id".to_string()),
538+
CompletionAssertion::Label("z".to_string()),
539+
CompletionAssertion::Label("a".to_string()),
540+
],
541+
setup,
542+
)
543+
.await;
544+
545+
// "id" of table one is mentioned, but table two isn't –
546+
// its priority stays up
547+
assert_complete_results(
548+
format!(
549+
"select o.id, a, b, c, d, e, {} from public.one o join public.two on o.id = t.id;",
550+
CURSOR_POS
551+
)
552+
.as_str(),
553+
vec![
554+
CompletionAssertion::LabelAndDesc(
555+
"id".to_string(),
556+
"Table: public.two".to_string(),
557+
),
558+
CompletionAssertion::Label("z".to_string()),
559+
],
560+
setup,
561+
)
562+
.await;
563+
564+
// "id" is ambiguous, so both "id" columns are lowered in priority
565+
assert_complete_results(
566+
format!(
567+
"select id, a, b, c, d, e, {} from public.one o join public.two on o.id = t.id;",
568+
CURSOR_POS
569+
)
570+
.as_str(),
571+
vec![CompletionAssertion::Label("z".to_string())],
572+
setup,
573+
)
574+
.await;
575+
}
487576
}
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
use crate::{
2+
CompletionItemKind,
3+
builder::{CompletionBuilder, PossibleCompletionItem},
4+
context::CompletionContext,
5+
relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore},
6+
};
7+
8+
use super::helper::get_completion_text_with_schema_or_alias;
9+
10+
pub fn complete_functions<'a>(ctx: &'a CompletionContext, builder: &mut CompletionBuilder<'a>) {
11+
let available_functions = &ctx.schema_cache.functions;
12+
13+
for func in available_functions {
14+
let relevance = CompletionRelevanceData::Function(func);
15+
16+
let item = PossibleCompletionItem {
17+
label: func.name.clone(),
18+
score: CompletionScore::from(relevance.clone()),
19+
filter: CompletionFilter::from(relevance),
20+
description: format!("Schema: {}", func.schema),
21+
kind: CompletionItemKind::Function,
22+
completion_text: get_completion_text_with_schema_or_alias(
23+
ctx,
24+
&func.name,
25+
&func.schema,
26+
),
27+
};
28+
29+
builder.add_item(item);
30+
}
31+
}
32+
33+
#[cfg(test)]
34+
mod tests {
35+
use crate::{
36+
CompletionItem, CompletionItemKind, complete,
37+
test_helper::{CURSOR_POS, get_test_deps, get_test_params},
38+
};
39+
40+
#[tokio::test]
41+
async fn completes_fn() {
42+
let setup = r#"
43+
create or replace function cool()
44+
returns trigger
45+
language plpgsql
46+
security invoker
47+
as $$
48+
begin
49+
raise exception 'dont matter';
50+
end;
51+
$$;
52+
"#;
53+
54+
let query = format!("select coo{}", CURSOR_POS);
55+
56+
let (tree, cache) = get_test_deps(setup, query.as_str().into()).await;
57+
let params = get_test_params(&tree, &cache, query.as_str().into());
58+
let results = complete(params);
59+
60+
let CompletionItem { label, .. } = results
61+
.into_iter()
62+
.next()
63+
.expect("Should return at least one completion item");
64+
65+
assert_eq!(label, "cool");
66+
}
67+
68+
#[tokio::test]
69+
async fn prefers_fn_if_invocation() {
70+
let setup = r#"
71+
create table coos (
72+
id serial primary key,
73+
name text
74+
);
75+
76+
create or replace function cool()
77+
returns trigger
78+
language plpgsql
79+
security invoker
80+
as $$
81+
begin
82+
raise exception 'dont matter';
83+
end;
84+
$$;
85+
"#;
86+
87+
let query = format!(r#"select * from coo{}()"#, CURSOR_POS);
88+
89+
let (tree, cache) = get_test_deps(setup, query.as_str().into()).await;
90+
let params = get_test_params(&tree, &cache, query.as_str().into());
91+
let results = complete(params);
92+
93+
let CompletionItem { label, kind, .. } = results
94+
.into_iter()
95+
.next()
96+
.expect("Should return at least one completion item");
97+
98+
assert_eq!(label, "cool");
99+
assert_eq!(kind, CompletionItemKind::Function);
100+
}
101+
102+
#[tokio::test]
103+
async fn prefers_fn_in_select_clause() {
104+
let setup = r#"
105+
create table coos (
106+
id serial primary key,
107+
name text
108+
);
109+
110+
create or replace function cool()
111+
returns trigger
112+
language plpgsql
113+
security invoker
114+
as $$
115+
begin
116+
raise exception 'dont matter';
117+
end;
118+
$$;
119+
"#;
120+
121+
let query = format!(r#"select coo{}"#, CURSOR_POS);
122+
123+
let (tree, cache) = get_test_deps(setup, query.as_str().into()).await;
124+
let params = get_test_params(&tree, &cache, query.as_str().into());
125+
let results = complete(params);
126+
127+
let CompletionItem { label, kind, .. } = results
128+
.into_iter()
129+
.next()
130+
.expect("Should return at least one completion item");
131+
132+
assert_eq!(label, "cool");
133+
assert_eq!(kind, CompletionItemKind::Function);
134+
}
135+
136+
#[tokio::test]
137+
async fn prefers_function_in_from_clause_if_invocation() {
138+
let setup = r#"
139+
create table coos (
140+
id serial primary key,
141+
name text
142+
);
143+
144+
create or replace function cool()
145+
returns trigger
146+
language plpgsql
147+
security invoker
148+
as $$
149+
begin
150+
raise exception 'dont matter';
151+
end;
152+
$$;
153+
"#;
154+
155+
let query = format!(r#"select * from coo{}()"#, CURSOR_POS);
156+
157+
let (tree, cache) = get_test_deps(setup, query.as_str().into()).await;
158+
let params = get_test_params(&tree, &cache, query.as_str().into());
159+
let results = complete(params);
160+
161+
let CompletionItem { label, kind, .. } = results
162+
.into_iter()
163+
.next()
164+
.expect("Should return at least one completion item");
165+
166+
assert_eq!(label, "cool");
167+
assert_eq!(kind, CompletionItemKind::Function);
168+
}
169+
}

0 commit comments

Comments
 (0)