Skip to content

Commit 6c54519

Browse files
authored
support create function definition with $$ (#755)
* support create function definition using '2700775' * fix warn
1 parent d420001 commit 6c54519

File tree

5 files changed

+91
-10
lines changed

5 files changed

+91
-10
lines changed

src/ast/mod.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3777,6 +3777,23 @@ impl fmt::Display for FunctionBehavior {
37773777
}
37783778
}
37793779

3780+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
3781+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
3782+
pub enum FunctionDefinition {
3783+
SingleQuotedDef(String),
3784+
DoubleDollarDef(String),
3785+
}
3786+
3787+
impl fmt::Display for FunctionDefinition {
3788+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3789+
match self {
3790+
FunctionDefinition::SingleQuotedDef(s) => write!(f, "'{s}'")?,
3791+
FunctionDefinition::DoubleDollarDef(s) => write!(f, "$${s}$$")?,
3792+
}
3793+
Ok(())
3794+
}
3795+
}
3796+
37803797
/// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html
37813798
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
37823799
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -3788,7 +3805,7 @@ pub struct CreateFunctionBody {
37883805
/// AS 'definition'
37893806
///
37903807
/// Note that Hive's `AS class_name` is also parsed here.
3791-
pub as_: Option<String>,
3808+
pub as_: Option<FunctionDefinition>,
37923809
/// RETURN expression
37933810
pub return_: Option<Expr>,
37943811
/// USING ... (Hive only)
@@ -3804,7 +3821,7 @@ impl fmt::Display for CreateFunctionBody {
38043821
write!(f, " {behavior}")?;
38053822
}
38063823
if let Some(definition) = &self.as_ {
3807-
write!(f, " AS '{definition}'")?;
3824+
write!(f, " AS {definition}")?;
38083825
}
38093826
if let Some(expr) = &self.return_ {
38103827
write!(f, " RETURN {expr}")?;

src/parser.rs

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,7 +2310,7 @@ impl<'a> Parser<'a> {
23102310
if dialect_of!(self is HiveDialect) {
23112311
let name = self.parse_object_name()?;
23122312
self.expect_keyword(Keyword::AS)?;
2313-
let class_name = self.parse_literal_string()?;
2313+
let class_name = self.parse_function_definition()?;
23142314
let params = CreateFunctionBody {
23152315
as_: Some(class_name),
23162316
using: self.parse_optional_create_function_using()?,
@@ -2400,7 +2400,7 @@ impl<'a> Parser<'a> {
24002400
}
24012401
if self.parse_keyword(Keyword::AS) {
24022402
ensure_not_set(&body.as_, "AS")?;
2403-
body.as_ = Some(self.parse_literal_string()?);
2403+
body.as_ = Some(self.parse_function_definition()?);
24042404
} else if self.parse_keyword(Keyword::LANGUAGE) {
24052405
ensure_not_set(&body.language, "LANGUAGE")?;
24062406
body.language = Some(self.parse_identifier()?);
@@ -3883,6 +3883,33 @@ impl<'a> Parser<'a> {
38833883
}
38843884
}
38853885

3886+
pub fn parse_function_definition(&mut self) -> Result<FunctionDefinition, ParserError> {
3887+
let peek_token = self.peek_token();
3888+
match peek_token.token {
3889+
Token::DoubleDollarQuoting if dialect_of!(self is PostgreSqlDialect) => {
3890+
self.next_token();
3891+
let mut func_desc = String::new();
3892+
loop {
3893+
if let Some(next_token) = self.next_token_no_skip() {
3894+
match &next_token.token {
3895+
Token::DoubleDollarQuoting => break,
3896+
Token::EOF => {
3897+
return self.expected(
3898+
"literal string",
3899+
TokenWithLocation::wrap(Token::EOF),
3900+
);
3901+
}
3902+
token => func_desc.push_str(token.to_string().as_str()),
3903+
}
3904+
}
3905+
}
3906+
Ok(FunctionDefinition::DoubleDollarDef(func_desc))
3907+
}
3908+
_ => Ok(FunctionDefinition::SingleQuotedDef(
3909+
self.parse_literal_string()?,
3910+
)),
3911+
}
3912+
}
38863913
/// Parse a literal string
38873914
pub fn parse_literal_string(&mut self) -> Result<String, ParserError> {
38883915
let next_token = self.next_token();

src/tokenizer.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ pub enum Token {
145145
PGCubeRoot,
146146
/// `?` or `$` , a prepared statement arg placeholder
147147
Placeholder(String),
148+
/// `$$`, used for PostgreSQL create function definition
149+
DoubleDollarQuoting,
148150
/// ->, used as a operator to extract json field in PostgreSQL
149151
Arrow,
150152
/// ->>, used as a operator to extract json field as text in PostgreSQL
@@ -215,6 +217,7 @@ impl fmt::Display for Token {
215217
Token::LongArrow => write!(f, "->>"),
216218
Token::HashArrow => write!(f, "#>"),
217219
Token::HashLongArrow => write!(f, "#>>"),
220+
Token::DoubleDollarQuoting => write!(f, "$$"),
218221
}
219222
}
220223
}
@@ -770,8 +773,14 @@ impl<'a> Tokenizer<'a> {
770773
}
771774
'$' => {
772775
chars.next();
773-
let s = peeking_take_while(chars, |ch| ch.is_alphanumeric() || ch == '_');
774-
Ok(Some(Token::Placeholder(String::from("$") + &s)))
776+
match chars.peek() {
777+
Some('$') => self.consume_and_return(chars, Token::DoubleDollarQuoting),
778+
_ => {
779+
let s =
780+
peeking_take_while(chars, |ch| ch.is_alphanumeric() || ch == '_');
781+
Ok(Some(Token::Placeholder(String::from("$") + &s)))
782+
}
783+
}
775784
}
776785
//whitespace check (including unicode chars) should be last as it covers some of the chars above
777786
ch if ch.is_whitespace() => {

tests/sqlparser_hive.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
//! is also tested (on the inputs it can handle).
1717
1818
use sqlparser::ast::{
19-
CreateFunctionBody, CreateFunctionUsing, Expr, Function, Ident, ObjectName, SelectItem,
20-
Statement, TableFactor, UnaryOperator, Value,
19+
CreateFunctionBody, CreateFunctionUsing, Expr, Function, FunctionDefinition, Ident, ObjectName,
20+
SelectItem, Statement, TableFactor, UnaryOperator, Value,
2121
};
2222
use sqlparser::dialect::{GenericDialect, HiveDialect};
2323
use sqlparser::parser::ParserError;
@@ -252,7 +252,9 @@ fn parse_create_function() {
252252
assert_eq!(
253253
params,
254254
CreateFunctionBody {
255-
as_: Some("org.random.class.Name".to_string()),
255+
as_: Some(FunctionDefinition::SingleQuotedDef(
256+
"org.random.class.Name".to_string()
257+
)),
256258
using: Some(CreateFunctionUsing::Jar(
257259
"hdfs://somewhere.com:8020/very/far".to_string()
258260
)),

tests/sqlparser_postgres.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2257,7 +2257,9 @@ fn parse_create_function() {
22572257
params: CreateFunctionBody {
22582258
language: Some("SQL".into()),
22592259
behavior: Some(FunctionBehavior::Immutable),
2260-
as_: Some("select $1 + $2;".into()),
2260+
as_: Some(FunctionDefinition::SingleQuotedDef(
2261+
"select $1 + $2;".into()
2262+
)),
22612263
..Default::default()
22622264
},
22632265
}
@@ -2292,4 +2294,28 @@ fn parse_create_function() {
22922294
},
22932295
}
22942296
);
2297+
2298+
let sql = r#"CREATE OR REPLACE FUNCTION increment(i INTEGER) RETURNS INTEGER LANGUAGE plpgsql AS $$ BEGIN RETURN i + 1; END; $$"#;
2299+
assert_eq!(
2300+
pg().verified_stmt(sql),
2301+
Statement::CreateFunction {
2302+
or_replace: true,
2303+
temporary: false,
2304+
name: ObjectName(vec![Ident::new("increment")]),
2305+
args: Some(vec![CreateFunctionArg::with_name(
2306+
"i",
2307+
DataType::Integer(None)
2308+
)]),
2309+
return_type: Some(DataType::Integer(None)),
2310+
params: CreateFunctionBody {
2311+
language: Some("plpgsql".into()),
2312+
behavior: None,
2313+
return_: None,
2314+
as_: Some(FunctionDefinition::DoubleDollarDef(
2315+
" BEGIN RETURN i + 1; END; ".into()
2316+
)),
2317+
using: None
2318+
},
2319+
}
2320+
);
22952321
}

0 commit comments

Comments
 (0)