Skip to content

Commit 2f06c49

Browse files
authored
Merge pull request #38 from supabase/feat/deep-cst-attempt-2
deep cst attempt 2
2 parents 2edf9eb + 197f351 commit 2f06c49

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+2313
-936
lines changed
Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
use pg_query_proto_parser::{FieldType, Node, ProtoParser};
2+
use proc_macro2::{Ident, TokenStream};
3+
use quote::{format_ident, quote};
4+
5+
pub fn get_child_token_range_mod(_item: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
6+
let parser = ProtoParser::new("./libpg_query/protobuf/pg_query.proto");
7+
8+
let proto_file = parser.parse();
9+
10+
let node_identifiers = node_identifiers(&proto_file.nodes);
11+
let node_handlers = node_handlers(&proto_file.nodes);
12+
13+
quote! {
14+
use log::{debug};
15+
use pg_query::{protobuf::ScanToken, protobuf::Token, NodeEnum, protobuf::SortByDir};
16+
use cstree::text::{TextRange, TextSize};
17+
18+
#[derive(Debug)]
19+
struct TokenProperty {
20+
value: Option<String>,
21+
token: Option<Token>,
22+
}
23+
24+
impl From<i32> for TokenProperty {
25+
fn from(value: i32) -> TokenProperty {
26+
TokenProperty {
27+
value: Some(value.to_string()),
28+
token: None,
29+
}
30+
}
31+
}
32+
33+
impl From<u32> for TokenProperty {
34+
fn from(value: u32) -> TokenProperty {
35+
TokenProperty {
36+
value: Some(value.to_string()),
37+
token: None,
38+
}
39+
}
40+
}
41+
42+
43+
impl From<i64> for TokenProperty {
44+
fn from(value: i64) -> TokenProperty {
45+
TokenProperty {
46+
value: Some(value.to_string()),
47+
token: None,
48+
}
49+
}
50+
}
51+
52+
impl From<u64> for TokenProperty {
53+
fn from(value: u64) -> TokenProperty {
54+
TokenProperty {
55+
value: Some(value.to_string()),
56+
token: None,
57+
}
58+
}
59+
}
60+
61+
impl From<f64> for TokenProperty {
62+
fn from(value: f64) -> TokenProperty {
63+
TokenProperty {
64+
value: Some(value.to_string()),
65+
token: None,
66+
}
67+
}
68+
}
69+
70+
impl From<bool> for TokenProperty {
71+
fn from(value: bool) -> TokenProperty {
72+
TokenProperty {
73+
value: Some(value.to_string()),
74+
token: None,
75+
}
76+
}
77+
}
78+
79+
impl From<String> for TokenProperty {
80+
fn from(value: String) -> TokenProperty {
81+
assert!(value.len() > 0, "String property value has length 0");
82+
TokenProperty {
83+
value: Some(value.to_lowercase()),
84+
token: None,
85+
}
86+
}
87+
}
88+
89+
90+
impl From<&pg_query::protobuf::Integer> for TokenProperty {
91+
fn from(node: &pg_query::protobuf::Integer) -> TokenProperty {
92+
TokenProperty {
93+
value: Some(node.ival.to_string()),
94+
token: Some(Token::Iconst)
95+
}
96+
}
97+
}
98+
99+
impl From<&pg_query::protobuf::Boolean> for TokenProperty {
100+
fn from(node: &pg_query::protobuf::Boolean) -> TokenProperty {
101+
TokenProperty {
102+
value: Some(node.boolval.to_string()),
103+
token: match node.boolval {
104+
true => Some(Token::TrueP),
105+
false => Some(Token::FalseP),
106+
}
107+
}
108+
}
109+
}
110+
111+
impl From<Token> for TokenProperty {
112+
fn from(token: Token) -> TokenProperty {
113+
TokenProperty {
114+
value: None,
115+
token: Some(token),
116+
}
117+
}
118+
}
119+
120+
fn get_token_text(token: &ScanToken ,text: &str) -> String {
121+
let start = usize::try_from(token.start).unwrap();
122+
let end = usize::try_from(token.end).unwrap();
123+
text.chars()
124+
.skip(start)
125+
.take(end - start)
126+
.collect::<String>()
127+
.to_lowercase()
128+
}
129+
130+
131+
/// list of aliases from https://www.postgresql.org/docs/current/datatype.html
132+
const ALIASES: [&[&str]; 2]= [
133+
&["integer", "int", "int4"],
134+
&["real", "float4"],
135+
];
136+
137+
/// returns a list of aliases for a string. primarily used for data types.
138+
fn aliases(text: &str) -> Vec<&str> {
139+
for alias in ALIASES {
140+
if alias.contains(&text) {
141+
return alias.to_vec();
142+
}
143+
}
144+
return vec![text];
145+
}
146+
147+
#[derive(Debug)]
148+
pub enum ChildTokenRangeResult {
149+
TooManyTokens,
150+
NoTokens,
151+
/// indices are the .start of all child tokens used to estimate the range
152+
ChildTokenRange { used_token_indices: Vec<i32>, range: TextRange },
153+
}
154+
155+
pub fn get_child_token_range(node: &NodeEnum, tokens: Vec<&ScanToken>, text: &str, nearest_parent_location: Option<u32>) -> ChildTokenRangeResult {
156+
let mut child_tokens: Vec<&ScanToken> = Vec::new();
157+
158+
// if true, we found more than one valid token for at least one property of the node
159+
let mut has_too_many_tokens: bool = false;
160+
161+
let mut get_token = |property: TokenProperty| {
162+
let possible_tokens = tokens
163+
.iter()
164+
.filter_map(|t| {
165+
if property.token.is_some() {
166+
// if a token is set, we can safely ignore all tokens that are not of the same type
167+
if t.token() != property.token.unwrap() {
168+
return None;
169+
}
170+
}
171+
172+
// make a string comparison of the text of the token and the property value
173+
if property.value.is_some() {
174+
let mut token_text = get_token_text(t, text);
175+
// if token is Sconst, remove leading and trailing quotes
176+
if t.token() == Token::Sconst {
177+
let string_delimiter: &[char; 2] = &['\'', '$'];
178+
token_text = token_text.trim_start_matches(string_delimiter).trim_end_matches(string_delimiter).to_string();
179+
}
180+
181+
if !aliases(property.value.as_ref().unwrap()).contains(&token_text.as_str()) {
182+
return None;
183+
}
184+
}
185+
186+
Some(t)
187+
})
188+
.collect::<Vec<&&ScanToken>>();
189+
190+
if possible_tokens.len() == 0 {
191+
debug!(
192+
"No matching token found for property {:#?} of node {:#?} in {:#?} with tokens {:#?}",
193+
property, node, text, tokens
194+
);
195+
return;
196+
}
197+
198+
if possible_tokens.len() == 1 {
199+
debug!(
200+
"Found token {:#?} for property {:#?} of node {:#?}",
201+
possible_tokens[0], property, node
202+
);
203+
child_tokens.push(possible_tokens[0]);
204+
return;
205+
}
206+
207+
if nearest_parent_location.is_none() {
208+
debug!("Found {:#?} for property {:#?} and no nearest_parent_location set", possible_tokens, property);
209+
has_too_many_tokens = true;
210+
return;
211+
}
212+
213+
let token = possible_tokens
214+
.iter().map(|t| ((nearest_parent_location.unwrap() as i32 - t.start), t))
215+
.min_by_key(|(d, _)| d.to_owned())
216+
.map(|(_, t)| t);
217+
218+
debug!("Selected {:#?} as token closest from parent {:#?} as location {:#?}", token.unwrap(), node, nearest_parent_location);
219+
220+
child_tokens.push(token.unwrap());
221+
};
222+
223+
match node {
224+
#(NodeEnum::#node_identifiers(n) => {#node_handlers}),*,
225+
};
226+
227+
228+
if has_too_many_tokens == true {
229+
ChildTokenRangeResult::TooManyTokens
230+
} else if child_tokens.len() == 0 {
231+
ChildTokenRangeResult::NoTokens
232+
} else {
233+
ChildTokenRangeResult::ChildTokenRange {
234+
used_token_indices: child_tokens.iter().map(|t| t.start).collect(),
235+
range: TextRange::new(
236+
TextSize::from(child_tokens.iter().min_by_key(|t| t.start).unwrap().start as u32),
237+
TextSize::from(child_tokens.iter().max_by_key(|t| t.end).unwrap().end as u32),
238+
)
239+
}
240+
}
241+
}
242+
}
243+
}
244+
245+
fn node_identifiers(nodes: &[Node]) -> Vec<Ident> {
246+
nodes
247+
.iter()
248+
.map(|node| format_ident!("{}", &node.name))
249+
.collect()
250+
}
251+
252+
fn node_handlers(nodes: &[Node]) -> Vec<TokenStream> {
253+
nodes
254+
.iter()
255+
.map(|node| {
256+
let string_property_handlers = string_property_handlers(&node);
257+
let custom_handlers = custom_handlers(&node);
258+
quote! {
259+
#custom_handlers
260+
#(#string_property_handlers)*
261+
}
262+
})
263+
.collect()
264+
}
265+
266+
fn custom_handlers(node: &Node) -> TokenStream {
267+
match node.name.as_str() {
268+
"SelectStmt" => quote! {
269+
get_token(TokenProperty::from(Token::Select));
270+
if n.distinct_clause.len() > 0 {
271+
get_token(TokenProperty::from(Token::Distinct));
272+
}
273+
},
274+
"Integer" => quote! {
275+
get_token(TokenProperty::from(n));
276+
},
277+
"WindowDef" => quote! {
278+
if n.partition_clause.len() > 0 {
279+
get_token(TokenProperty::from(Token::Window));
280+
} else {
281+
get_token(TokenProperty::from(Token::Over));
282+
}
283+
},
284+
"Boolean" => quote! {
285+
get_token(TokenProperty::from(n));
286+
},
287+
"AStar" => quote! {
288+
get_token(TokenProperty::from(Token::Ascii42));
289+
},
290+
"FuncCall" => quote! {
291+
if n.agg_filter.is_some() {
292+
get_token(TokenProperty::from(Token::Filter));
293+
}
294+
},
295+
"SqlvalueFunction" => quote! {
296+
match n.op {
297+
// 1 SvfopCurrentDate
298+
// 2 SvfopCurrentTime
299+
// 3 SvfopCurrentTimeN
300+
// 4 SvfopCurrentTimestamp
301+
// 5 SvfopCurrentTimestampN
302+
// 6 SvfopLocaltime
303+
// 7 SvfopLocaltimeN
304+
// 8 SvfopLocaltimestamp
305+
// 9 SvfopLocaltimestampN
306+
// 10 SvfopCurrentRole
307+
10 => get_token(TokenProperty::from(Token::CurrentRole)),
308+
// 11 SvfopCurrentUser
309+
11 => get_token(TokenProperty::from(Token::CurrentUser)),
310+
// 12 SvfopUser
311+
// 13 SvfopSessionUser
312+
// 14 SvfopCurrentCatalog
313+
// 15 SvfopCurrentSchema
314+
_ => panic!("Unknown SqlvalueFunction {:#?}", n.op),
315+
}
316+
},
317+
"SortBy" => quote! {
318+
get_token(TokenProperty::from(Token::Order));
319+
match n.sortby_dir {
320+
2 => get_token(TokenProperty::from(Token::Asc)),
321+
3 => get_token(TokenProperty::from(Token::Desc)),
322+
_ => {}
323+
}
324+
},
325+
"AConst" => quote! {
326+
if n.isnull {
327+
get_token(TokenProperty::from(Token::NullP));
328+
}
329+
},
330+
_ => quote! {},
331+
}
332+
}
333+
334+
fn string_property_handlers(node: &Node) -> Vec<TokenStream> {
335+
node.fields
336+
.iter()
337+
.filter_map(|field| {
338+
if field.repeated {
339+
return None;
340+
}
341+
let field_name = format_ident!("{}", field.name.as_str());
342+
match field.field_type {
343+
// just handle string values for now
344+
FieldType::String => Some(quote! {
345+
// most string values are never None, but an empty string
346+
if n.#field_name.len() > 0 {
347+
get_token(TokenProperty::from(n.#field_name.to_owned()));
348+
}
349+
}),
350+
_ => None,
351+
}
352+
})
353+
.collect()
354+
}

crates/codegen/src/get_location.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,36 @@ pub fn get_location_mod(_item: proc_macro2::TokenStream) -> proc_macro2::TokenSt
1414
quote! {
1515
use pg_query::NodeEnum;
1616

17-
// Returns the location of a node
18-
pub fn get_location(node: &NodeEnum) -> Option<i32> {
17+
/// Returns the location of a node
18+
pub fn get_location(node: &NodeEnum) -> Option<u32> {
19+
let loc = get_location_internal(node);
20+
if loc.is_some() {
21+
u32::try_from(loc.unwrap()).ok()
22+
} else {
23+
None
24+
}
25+
}
26+
27+
fn get_location_internal(node: &NodeEnum) -> Option<i32> {
1928
let location = match node {
20-
// for some nodes, the location of the node itself is after their childrens location.
29+
// for some nodes, the location of the node itself is after their children location.
2130
// we implement the logic for those nodes manually.
2231
// if you add one, make sure to add its name to `manual_node_names()`.
2332
NodeEnum::BoolExpr(n) => {
2433
let a = n.args.iter().min_by(|a, b| {
25-
let loc_a = get_location(&a.node.as_ref().unwrap());
26-
let loc_b = get_location(&b.node.as_ref().unwrap());
34+
let loc_a = get_location_internal(&a.node.as_ref().unwrap());
35+
let loc_b = get_location_internal(&b.node.as_ref().unwrap());
2736
loc_a.cmp(&loc_b)
2837
});
29-
get_location(&a.unwrap().node.as_ref().unwrap())
38+
get_location_internal(&a.unwrap().node.as_ref().unwrap())
3039
},
31-
NodeEnum::AExpr(n) => get_location(&n.lexpr.as_ref().unwrap().node.as_ref().unwrap()),
40+
NodeEnum::AExpr(n) => get_location_internal(&n.lexpr.as_ref().unwrap().node.as_ref().unwrap()),
3241
#(NodeEnum::#node_identifiers(n) => #location_idents),*
3342
};
3443
if location.is_some() && location.unwrap() < 0 {
3544
None
3645
} else {
37-
location
46+
location
3847
}
3948
}
4049
}

0 commit comments

Comments
 (0)