Skip to content

Commit 8fc8082

Browse files
authored
Extend Visitor trait for Value type (apache#1725)
1 parent 3ace97c commit 8fc8082

File tree

2 files changed

+98
-8
lines changed

2 files changed

+98
-8
lines changed

src/ast/value.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ use sqlparser_derive::{Visit, VisitMut};
3333
/// Primitive SQL values such as number and string
3434
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
3535
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
36-
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
36+
#[cfg_attr(
37+
feature = "visitor",
38+
derive(Visit, VisitMut),
39+
visit(with = "visit_value")
40+
)]
41+
3742
pub enum Value {
3843
/// Numeric literal
3944
#[cfg(not(feature = "bigdecimal"))]

src/ast/visitor.rs

+92-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
1919
20-
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
20+
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value};
2121
use core::ops::ControlFlow;
2222

2323
/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
@@ -233,6 +233,16 @@ pub trait Visitor {
233233
fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
234234
ControlFlow::Continue(())
235235
}
236+
237+
/// Invoked for any Value that appear in the AST before visiting children
238+
fn pre_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
239+
ControlFlow::Continue(())
240+
}
241+
242+
/// Invoked for any Value that appear in the AST after visiting children
243+
fn post_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
244+
ControlFlow::Continue(())
245+
}
236246
}
237247

238248
/// A visitor that can be used to mutate an AST tree.
@@ -337,6 +347,16 @@ pub trait VisitorMut {
337347
fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
338348
ControlFlow::Continue(())
339349
}
350+
351+
/// Invoked for any value that appear in the AST before visiting children
352+
fn pre_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
353+
ControlFlow::Continue(())
354+
}
355+
356+
/// Invoked for any statements that appear in the AST after visiting children
357+
fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
358+
ControlFlow::Continue(())
359+
}
340360
}
341361

342362
struct RelationVisitor<F>(F);
@@ -647,6 +667,7 @@ where
647667
#[cfg(test)]
648668
mod tests {
649669
use super::*;
670+
use crate::ast::Statement;
650671
use crate::dialect::GenericDialect;
651672
use crate::parser::Parser;
652673
use crate::tokenizer::Tokenizer;
@@ -720,17 +741,16 @@ mod tests {
720741
}
721742
}
722743

723-
fn do_visit(sql: &str) -> Vec<String> {
744+
fn do_visit<V: Visitor>(sql: &str, visitor: &mut V) -> Statement {
724745
let dialect = GenericDialect {};
725746
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
726747
let s = Parser::new(&dialect)
727748
.with_tokens(tokens)
728749
.parse_statement()
729750
.unwrap();
730751

731-
let mut visitor = TestVisitor::default();
732-
s.visit(&mut visitor);
733-
visitor.visited
752+
s.visit(visitor);
753+
s
734754
}
735755

736756
#[test]
@@ -889,8 +909,9 @@ mod tests {
889909
),
890910
];
891911
for (sql, expected) in tests {
892-
let actual = do_visit(sql);
893-
let actual: Vec<_> = actual.iter().map(|x| x.as_str()).collect();
912+
let mut visitor = TestVisitor::default();
913+
let _ = do_visit(sql, &mut visitor);
914+
let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
894915
assert_eq!(actual, expected)
895916
}
896917
}
@@ -920,3 +941,67 @@ mod tests {
920941
s.visit(&mut visitor);
921942
}
922943
}
944+
945+
#[cfg(test)]
946+
mod visit_mut_tests {
947+
use crate::ast::{Statement, Value, VisitMut, VisitorMut};
948+
use crate::dialect::GenericDialect;
949+
use crate::parser::Parser;
950+
use crate::tokenizer::Tokenizer;
951+
use core::ops::ControlFlow;
952+
953+
#[derive(Default)]
954+
struct MutatorVisitor {
955+
index: u64,
956+
}
957+
958+
impl VisitorMut for MutatorVisitor {
959+
type Break = ();
960+
961+
fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
962+
self.index += 1;
963+
*value = Value::SingleQuotedString(format!("REDACTED_{}", self.index));
964+
ControlFlow::Continue(())
965+
}
966+
967+
fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
968+
ControlFlow::Continue(())
969+
}
970+
}
971+
972+
fn do_visit_mut<V: VisitorMut>(sql: &str, visitor: &mut V) -> Statement {
973+
let dialect = GenericDialect {};
974+
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
975+
let mut s = Parser::new(&dialect)
976+
.with_tokens(tokens)
977+
.parse_statement()
978+
.unwrap();
979+
980+
s.visit(visitor);
981+
s
982+
}
983+
984+
#[test]
985+
fn test_value_redact() {
986+
let tests = vec![
987+
(
988+
concat!(
989+
"SELECT * FROM monthly_sales ",
990+
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
991+
"ORDER BY EMPID"
992+
),
993+
concat!(
994+
"SELECT * FROM monthly_sales ",
995+
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
996+
"ORDER BY EMPID"
997+
),
998+
),
999+
];
1000+
1001+
for (sql, expected) in tests {
1002+
let mut visitor = MutatorVisitor::default();
1003+
let mutated = do_visit_mut(sql, &mut visitor);
1004+
assert_eq!(mutated.to_string(), expected)
1005+
}
1006+
}
1007+
}

0 commit comments

Comments
 (0)