Merged
Changes from 9 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
# will have compiled files and executables
/target/
/sqlparser_bench/target/
/derive/target/
Contributor

Also note to self

I tested using

cargo test --all --features=visitor

We need to add this feature to the CI tests

Contributor

Oh, that explains the poor coverage I guess


# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here http://doc.crates.io/guide.html#cargotoml-vs-cargolock
4 changes: 3 additions & 1 deletion Cargo.toml
@@ -5,7 +5,7 @@ version = "0.28.0"
authors = ["Andy Grove <andygrove73@gmail.com>"]
homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
documentation = "https://docs.rs/sqlparser/"
keywords = [ "ansi", "sql", "lexer", "parser" ]
keywords = ["ansi", "sql", "lexer", "parser"]
repository = "https://github.com/sqlparser-rs/sqlparser-rs"
license = "Apache-2.0"
include = [
@@ -23,6 +23,7 @@ default = ["std"]
std = []
# Enable JSON output in the `cli` example:
json_example = ["serde_json", "serde"]
visitor = ["sqlparser_derive"]

[dependencies]
bigdecimal = { version = "0.3", features = ["serde"], optional = true }
@@ -32,6 +33,7 @@ serde = { version = "1.0", features = ["derive"], optional = true }
# of dev-dependencies because of
# https://github.com/rust-lang/cargo/issues/1596
serde_json = { version = "1.0", optional = true }
sqlparser_derive = { version = "0.1", path = "derive", optional = true }
Contributor

Note to myself: if we merge this PR I think we should document this feature flag -- it seems documentation is missing for the existing feature flags

Contributor @alamb, Dec 28, 2022

Docs in #774 #776


[dev-dependencies]
simple_logger = "4.0"
23 changes: 23 additions & 0 deletions derive/Cargo.toml
@@ -0,0 +1,23 @@
[package]
name = "sqlparser_derive"
description = "proc macro for sqlparser"
version = "0.1.0"
authors = ["Andy Grove <andygrove73@gmail.com>"]
Contributor

🤔 maybe this should be something different. I will think about a more appropriate author -- maybe sqlparser-rs contributors 🤔

Contributor

TBH I'm not sure either. In theory the owner/creator should be the author, but since this is an expansion, I'm not sure whether to list both of you or just @tustvold

Contributor

I think it is ok as is

Contributor

I changed it to "sqlparser authors" in #775

homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
documentation = "https://docs.rs/sqlparser/"
keywords = ["ansi", "sql", "lexer", "parser"]
repository = "https://github.com/sqlparser-rs/sqlparser-rs"
license = "Apache-2.0"
include = [
    "src/**/*.rs",
    "Cargo.toml",
]
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = "1.0"
proc-macro2 = "1.0"
quote = "1.0"
79 changes: 79 additions & 0 deletions derive/README.md
@@ -0,0 +1,79 @@
# SQL Parser Derive Macro

## Visit

This crate contains a procedural macro that can automatically derive implementations of the `Visit` trait. For example:
Contributor

Note to self -- this should have a doc link back to the sqlparser crate

Contributor

in #779


```rust
#[derive(Visit)]
struct Foo {
    boolean: bool,
    bar: Bar,
}

#[derive(Visit)]
enum Bar {
    A(),
    B(String, bool),
    C { named: i32 },
}
```

This will generate code akin to:

```rust
impl Visit for Foo {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        self.boolean.visit(visitor)?;
        self.bar.visit(visitor)?;
        ControlFlow::Continue(())
    }
}

impl Visit for Bar {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        match self {
            Self::A() => {}
            Self::B(_1, _2) => {
                _1.visit(visitor)?;
                _2.visit(visitor)?;
            }
            Self::C { named } => {
                named.visit(visitor)?;
            }
        }
        ControlFlow::Continue(())
    }
}
```
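To see how impls of this shape behave at runtime, here is a minimal self-contained sketch. It assumes simplified stand-ins for `Visit` and `Visitor` (the real traits live in `sqlparser::ast` and differ in detail), hand-writes the impls for `Foo` and `Bar`, and adds a hypothetical `visit_bool` hook plus a hypothetical `first_true` helper to show how `?` on `ControlFlow` stops the traversal early.

```rust
use std::ops::ControlFlow;

// Simplified stand-ins for the sqlparser traits; the real definitions differ.
trait Visitor {
    type Break;
    // Hypothetical hook, invented for this sketch.
    fn visit_bool(&mut self, value: bool) -> ControlFlow<Self::Break>;
}

trait Visit {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
}

// Leaf impls for the field types used below.
impl Visit for bool {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        visitor.visit_bool(*self)
    }
}

impl Visit for String {
    fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
        ControlFlow::Continue(())
    }
}

impl Visit for i32 {
    fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
        ControlFlow::Continue(())
    }
}

struct Foo {
    boolean: bool,
    bar: Bar,
}

#[allow(dead_code)]
enum Bar {
    A(),
    B(String, bool),
    C { named: i32 },
}

// Hand-written equivalents of what the derive is described as generating.
impl Visit for Foo {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        self.boolean.visit(visitor)?;
        self.bar.visit(visitor)?;
        ControlFlow::Continue(())
    }
}

impl Visit for Bar {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        match self {
            Self::A() => {}
            Self::B(_1, _2) => {
                _1.visit(visitor)?;
                _2.visit(visitor)?;
            }
            Self::C { named } => {
                named.visit(visitor)?;
            }
        }
        ControlFlow::Continue(())
    }
}

// Counts the booleans seen so far and breaks on the first `true`.
struct FirstTrue {
    seen: usize,
}

impl Visitor for FirstTrue {
    type Break = usize;
    fn visit_bool(&mut self, value: bool) -> ControlFlow<usize> {
        self.seen += 1;
        if value {
            ControlFlow::Break(self.seen)
        } else {
            ControlFlow::Continue(())
        }
    }
}

// Returns the 1-based position of the first `true` boolean, if any.
fn first_true(foo: &Foo) -> Option<usize> {
    let mut visitor = FirstTrue { seen: 0 };
    match foo.visit(&mut visitor) {
        ControlFlow::Break(position) => Some(position),
        ControlFlow::Continue(()) => None,
    }
}
```

Because every child call ends in `?`, a `Break` returned from any nested field propagates straight out of the outermost `visit` call without visiting the remaining fields.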

Additionally, certain types may wish to call corresponding methods on the visitor before and after recursing:

```rust
#[derive(Visit)]
#[visit(with = "visit_expr")]
enum Expr {
    A(),
    B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool),
}
```

This will generate code akin to:

```rust
impl Visit for Expr {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        visitor.pre_visit_expr(self)?;
        match self {
            Self::A() => {}
            Self::B(_1, _2, _3) => {
                _1.visit(visitor)?;
                visitor.pre_visit_relation(_2)?;
                _2.visit(visitor)?;
                visitor.post_visit_relation(_2)?;
                _3.visit(visitor)?;
            }
        }
        visitor.post_visit_expr(self)?;
        ControlFlow::Continue(())
    }
}
```
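The hook mechanism can be tried out with a small self-contained sketch. It assumes `pre_`/`post_` hook names, following the `format_ident!("pre_{}", ...)` and `format_ident!("post_{}", ...)` calls in this PR's `derive/src/lib.rs`. The `Visitor` and `Visit` traits, the `ObjectName` stand-in, the `Trace` visitor and the `trace` helper are all simplified or hypothetical, not the sqlparser definitions.

```rust
use std::ops::ControlFlow;

// Stand-in for sqlparser's `ObjectName`; simplified for this sketch.
struct ObjectName(Vec<String>);

// Simplified stand-in for `Visitor`. Default bodies let a visitor
// override only the hooks it cares about.
trait Visitor {
    type Break;
    fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }
    fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }
    fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }
    fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }
}

trait Visit {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
}

// Leaf types have nothing to recurse into.
macro_rules! leaf_visit {
    ($($t:ty),*) => {$(
        impl Visit for $t {
            fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
                ControlFlow::Continue(())
            }
        }
    )*};
}
leaf_visit!(String, bool, ObjectName);

#[allow(dead_code)]
enum Expr {
    A(),
    B(String, ObjectName, bool),
}

// Hand-written equivalent of the expansion sketched above.
impl Visit for Expr {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        visitor.pre_visit_expr(self)?;
        match self {
            Self::A() => {}
            Self::B(_1, _2, _3) => {
                _1.visit(visitor)?;
                visitor.pre_visit_relation(_2)?;
                _2.visit(visitor)?;
                visitor.post_visit_relation(_2)?;
                _3.visit(visitor)?;
            }
        }
        visitor.post_visit_expr(self)?;
        ControlFlow::Continue(())
    }
}

// Records the order in which hooks fire.
struct Trace(Vec<String>);

impl Visitor for Trace {
    type Break = ();
    fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<()> {
        self.0.push("pre_expr".to_string());
        ControlFlow::Continue(())
    }
    fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<()> {
        self.0.push("post_expr".to_string());
        ControlFlow::Continue(())
    }
    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
        self.0.push(format!("relation {}", relation.0.join(".")));
        ControlFlow::Continue(())
    }
}

// Runs the visitor over `expr` and returns the recorded events.
fn trace(expr: &Expr) -> Vec<String> {
    let mut visitor = Trace(Vec::new());
    let _ = expr.visit(&mut visitor);
    visitor.0
}
```

Giving the hooks default bodies is a design choice of this sketch: a visitor that only cares about relations can then implement a single method.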
184 changes: 184 additions & 0 deletions derive/src/lib.rs
@@ -0,0 +1,184 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{
    parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
    Ident, Index, Lit, Meta, MetaNameValue, NestedMeta,
};

/// Implementation of `#[derive(Visit)]`
#[proc_macro_derive(Visit, attributes(visit))]
pub fn derive_visit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    // Parse the input tokens into a syntax tree.
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;

    let attributes = Attributes::parse(&input.attrs);
    // Add a bound `T: Visit` to every type parameter T.
    let generics = add_trait_bounds(input.generics);
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let (pre_visit, post_visit) = attributes.visit(quote!(self));
    let children = visit_children(&input.data);

    let expanded = quote! {
        // The generated impl.
        impl #impl_generics sqlparser::ast::Visit for #name #ty_generics #where_clause {
            fn visit<V: sqlparser::ast::Visitor>(&self, visitor: &mut V) -> ::std::ops::ControlFlow<V::Break> {
                #pre_visit
                #children
                #post_visit
                ::std::ops::ControlFlow::Continue(())
            }
        }
    };

    proc_macro::TokenStream::from(expanded)
}

/// Parses attributes that can be provided to this macro
///
/// `#[visit(with = "visit_expr")]`
#[derive(Default)]
struct Attributes {
    /// Content for the `with` attribute
    with: Option<Ident>,
}

impl Attributes {
    fn parse(attrs: &[Attribute]) -> Self {
        let mut out = Self::default();
        for attr in attrs.iter().filter(|a| a.path.is_ident("visit")) {
            let meta = attr.parse_meta().expect("visit attribute");
            match meta {
                Meta::List(l) => {
                    for nested in &l.nested {
                        match nested {
                            NestedMeta::Meta(Meta::NameValue(v)) => out.parse_name_value(v),
                            _ => panic!("Expected #[visit(key = \"value\")]"),
                        }
                    }
                }
                _ => panic!("Expected #[visit(...)]"),
            }
        }
        out
    }

    /// Updates self with a name value attribute
    fn parse_name_value(&mut self, v: &MetaNameValue) {
        if v.path.is_ident("with") {
            match &v.lit {
                Lit::Str(s) => self.with = Some(format_ident!("{}", s.value(), span = s.span())),
                _ => panic!("Expected a string value, got {}", v.lit.to_token_stream()),
            }
            return;
        }
        panic!("Unrecognised kv attribute {}", v.path.to_token_stream())
    }

    /// Returns the pre and post visit token streams
    fn visit(&self, s: TokenStream) -> (Option<TokenStream>, Option<TokenStream>) {
        let pre_visit = self.with.as_ref().map(|m| {
            let m = format_ident!("pre_{}", m);
            quote!(visitor.#m(#s)?;)
        });
        let post_visit = self.with.as_ref().map(|m| {
            let m = format_ident!("post_{}", m);
            quote!(visitor.#m(#s)?;)
        });
        (pre_visit, post_visit)
    }
}

// Add a bound `T: Visit` to every type parameter T.
fn add_trait_bounds(mut generics: Generics) -> Generics {
    for param in &mut generics.params {
        if let GenericParam::Type(ref mut type_param) = *param {
            type_param.bounds.push(parse_quote!(sqlparser::ast::Visit));
        }
    }
    generics
}

// Generate the body of the visit implementation for the given type
fn visit_children(data: &Data) -> TokenStream {
    match data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(fields) => {
                let recurse = fields.named.iter().map(|f| {
                    let name = &f.ident;
                    let attributes = Attributes::parse(&f.attrs);
                    let (pre_visit, post_visit) = attributes.visit(quote!(&self.#name));
                    quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#name, visitor)?; #post_visit)
                });
                quote! {
                    #(#recurse)*
                }
            }
            Fields::Unnamed(fields) => {
                let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
                    let index = Index::from(i);
                    let attributes = Attributes::parse(&f.attrs);
                    let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
                    quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#index, visitor)?; #post_visit)
                });
                quote! {
                    #(#recurse)*
                }
            }
            Fields::Unit => {
                quote!()
            }
        },
        Data::Enum(data) => {
            let statements = data.variants.iter().map(|v| {
                let name = &v.ident;
                match &v.fields {
                    Fields::Named(fields) => {
                        let names = fields.named.iter().map(|f| &f.ident);
                        let visit = fields.named.iter().map(|f| {
                            let name = &f.ident;
                            let attributes = Attributes::parse(&f.attrs);
                            let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
                            quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
                        });

                        quote!(
                            Self::#name { #(#names),* } => {
                                #(#visit)*
                            }
                        )
                    }
                    Fields::Unnamed(fields) => {
                        let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span()));
                        let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
                            let name = format_ident!("_{}", i);
                            let attributes = Attributes::parse(&f.attrs);
                            let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
                            quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
                        });

                        quote! {
                            Self::#name ( #(#names),*) => {
                                #(#visit)*
                            }
                        }
                    }
                    Fields::Unit => {
                        quote! {
                            Self::#name => {}
                        }
                    }
                }
            });

            quote! {
                match self {
                    #(#statements),*
                }
            }
        }
        Data::Union(_) => unimplemented!(),
    }
}
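As a rough illustration of what `add_trait_bounds` contributes, the sketch below hand-writes the impl that the derive would plausibly emit for a generic type. `Wrapper<T>`, the `visit_u64` hook, the `Sum` visitor and the `sum` helper are hypothetical names invented here, and `Visit`/`Visitor` are simplified stand-ins for the sqlparser traits.

```rust
use std::ops::ControlFlow;

// Simplified stand-ins for the sqlparser traits.
trait Visitor {
    type Break;
    // Hypothetical hook, invented for this sketch.
    fn visit_u64(&mut self, value: u64) -> ControlFlow<Self::Break>;
}

trait Visit {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
}

impl Visit for u64 {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        visitor.visit_u64(*self)
    }
}

// Hypothetical generic type that would carry `#[derive(Visit)]`.
struct Wrapper<T> {
    first: T,
    second: T,
}

// Shape of the expected expansion: the `T: Visit` bound is the part
// `add_trait_bounds` adds. Without it, the calls on the `T` fields
// would not type-check.
impl<T: Visit> Visit for Wrapper<T> {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        Visit::visit(&self.first, visitor)?;
        Visit::visit(&self.second, visitor)?;
        ControlFlow::Continue(())
    }
}

// Adds up every `u64` reached during the traversal.
struct Sum(u64);

impl Visitor for Sum {
    type Break = ();
    fn visit_u64(&mut self, value: u64) -> ControlFlow<()> {
        self.0 += value;
        ControlFlow::Continue(())
    }
}

fn sum<T: Visit>(value: &T) -> u64 {
    let mut visitor = Sum(0);
    let _ = value.visit(&mut visitor);
    visitor.0
}
```

Since the bound is on `T` itself, nesting works without extra effort: `Wrapper<Wrapper<u64>>` is visitable because `Wrapper<u64>` is.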
8 changes: 8 additions & 0 deletions src/ast/data_type.rs
@@ -17,13 +17,17 @@ use core::fmt;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;

use crate::ast::ObjectName;

use super::value::escape_single_quote_string;

/// SQL data types
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
pub enum DataType {
    /// Fixed-length character type e.g. CHARACTER(10)
    Character(Option<CharacterLength>),
@@ -337,6 +341,7 @@ fn format_datetime_precision_and_tz(
/// guarantee compatibility with the input query we must maintain its exact information.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
Contributor

I am a big fan of this method as a way to minimize maintenance burden on new contributions

I picked a random struct and removed the `#[cfg_attr(feature = "visitor", derive(Visit))]` to see what would happen if it was forgotten

I think the error is fairly clear 👍

   Compiling sqlparser v0.28.0 (/Users/alamb/Software/sqlparser-rs)
error[E0277]: the trait bound `ast::Password: visitor::Visit` is not satisfied
    --> src/ast/mod.rs:1262:9
     |
1262 |         password: Option<Password>,
     |         ^^^^^^^^ the trait `visitor::Visit` is not implemented for `ast::Password`
     |
     = help: the following other types implement trait `visitor::Visit`:
               Box<T>
               CommentObject
               CreateTableBuilder
               Keyword
               Option<T>
               String
               Vec<T>
               ast::Action
             and 120 others
note: required for `Option<ast::Password>` to implement `visitor::Visit`
    --> src/ast/visitor.rs:21:16
     |
21   | impl<T: Visit> Visit for Option<T> {
     |                ^^^^^     ^^^^^^^^^
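The error follows from how the container impls are bounded: `Option<T>` is only visitable when `T` is. A minimal self-contained sketch of that pattern, assuming simplified stand-ins for `Visit` and `Visitor` and a hypothetical `Holder` struct, `visit_leaf` hook and `count_leaves` helper (none of these are the sqlparser definitions):

```rust
use std::ops::ControlFlow;

// Simplified stand-ins for the sqlparser traits.
trait Visitor {
    type Break;
    // Hypothetical hook, invented for this sketch.
    fn visit_leaf(&mut self, name: &str) -> ControlFlow<Self::Break>;
}

trait Visit {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
}

// Container impls in the style of the `impl<T: Visit> Visit for Option<T>`
// quoted in the compiler note: each one requires `T: Visit`.
impl<T: Visit> Visit for Option<T> {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        if let Some(inner) = self {
            inner.visit(visitor)?;
        }
        ControlFlow::Continue(())
    }
}

impl<T: Visit> Visit for Vec<T> {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        for item in self {
            item.visit(visitor)?;
        }
        ControlFlow::Continue(())
    }
}

impl<T: Visit> Visit for Box<T> {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        (**self).visit(visitor)
    }
}

struct Password;

// Deleting this impl makes `Option<Password>` and `Vec<Box<Password>>`
// fail the `T: Visit` bound, which is the E0277 shown above.
impl Visit for Password {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        visitor.visit_leaf("Password")
    }
}

// Hypothetical parent type holding the field from the error message.
struct Holder {
    password: Option<Password>,
    history: Vec<Box<Password>>,
}

impl Visit for Holder {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        self.password.visit(visitor)?;
        self.history.visit(visitor)?;
        ControlFlow::Continue(())
    }
}

// Counts how many leaves the traversal reaches.
struct CountLeaves(usize);

impl Visitor for CountLeaves {
    type Break = ();
    fn visit_leaf(&mut self, _name: &str) -> ControlFlow<()> {
        self.0 += 1;
        ControlFlow::Continue(())
    }
}

fn count_leaves<T: Visit>(value: &T) -> usize {
    let mut visitor = CountLeaves(0);
    let _ = value.visit(&mut visitor);
    visitor.0
}
```

The upshot is that a forgotten derive on one type surfaces at compile time at every field that contains it, rather than silently skipping part of the tree.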

Contributor Author

It's also making it easier to keep up with the merge conflicts 😆

pub enum TimezoneInfo {
    /// No information about time zone. E.g., TIMESTAMP
    None,
@@ -384,6 +389,7 @@ impl fmt::Display for TimezoneInfo {
/// [standard]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#exact-numeric-type
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
pub enum ExactNumberInfo {
    /// No additional information e.g. `DECIMAL`
    None,
@@ -414,6 +420,7 @@ impl fmt::Display for ExactNumberInfo {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#character-length
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
pub struct CharacterLength {
    /// Default (if VARYING) or maximum (if not VARYING) length
    pub length: u64,
@@ -436,6 +443,7 @@ impl fmt::Display for CharacterLength {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#char-length-units
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
pub enum CharLengthUnits {
    /// CHARACTERS unit
    Characters,