Skip to content

Commit

Permalink
feat: version working with different types of dereferencing
Browse files Browse the repository at this point in the history
  • Loading branch information
MingweiSamuel committed Sep 13, 2024
1 parent a572ba6 commit 070dee1
Show file tree
Hide file tree
Showing 17 changed files with 541 additions and 174 deletions.
23 changes: 23 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ repository = "https://github.com/hydro-project/matchbox"
proc-macro = true

[dependencies]
proc-macro-error2 = "2.0.0"
proc-macro2 = "1.0.80"
quote = "1.0.37"
syn = { version = "2.0.0", features = ["fold", "full", "extra-traits"] }
syn = { version = "2.0.7", features = ["fold", "full", "extra-traits"] }

[dev-dependencies]
insta = "1.40.0"
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
`Deref` patterns in `match` for stable Rust. Now you can match through `Rc`, `String`, etc.

`matchbox::match_deref!{...}` is a procedural macro, which allows you to use deref patterns right now in stable Rust.
`matchbox::matchbox!{...}` is a procedural macro, which allows you to use deref patterns right now in stable Rust.

For example:
```rust,no_run
Expand All @@ -15,7 +15,7 @@ enum Value {
use Value::*;
let v: &Value = todo!();
matchbox::match_deref!{
matchbox::matchbox!{
match v {
Nil => todo!(),
Cons(Deref @ Symbol(Deref @ "quote"), Deref @ Cons(x, Deref @ Nil)) => todo!(),
Expand All @@ -34,7 +34,7 @@ The macro calls `Deref::deref` internally. Keep in mind that `Deref::deref` take

Consider this code:
```rust,ignore
matchbox::match_deref!{
matchbox::matchbox!{
match v {
Symbol(Deref @ x) => {
// some_code_here
Expand Down
181 changes: 135 additions & 46 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,138 @@
#![warn(missing_docs)]
#![doc = include_str!("../README.md")]

use std::str::FromStr;

use syn::spanned::Spanned;

mod test;

struct PatSingle(syn::Pat);
impl syn::parse::Parse for PatSingle {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let pat = syn::Pat::parse_single(input)?;
Ok(Self(pat))
}
}

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
enum Type {
Owned,
Stamp,
Deref,
}
impl Type {
fn add_ref(self) -> Self {
match self {
Self::Owned => Self::Stamp,
Self::Stamp => Self::Deref,
Self::Deref => Self::Deref,
}
}
fn as_op(self, span: proc_macro2::Span) -> proc_macro2::TokenStream {
match self {
Self::Owned => quote::quote_spanned! {span=> * },
Self::Stamp => quote::quote_spanned! {span=> &* },
Self::Deref => quote::quote_spanned! {span=> &** },
}
}
}
impl FromStr for Type {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"owned" => Ok(Self::Owned),
"stamp" => Ok(Self::Stamp),
"deref" => Ok(Self::Deref),
other => Err(other.to_owned()),
}
}
}

#[derive(Default)]
struct MyFold {
binds: Vec<(syn::Ident, syn::Pat)>,
binds: Vec<(syn::Ident, syn::Pat, Type, proc_macro2::Span)>,
counter: i32,
diagnostics: Vec<proc_macro_error2::Diagnostic>,
}
impl MyFold {
fn handle(&mut self, subpat: syn::Pat, typ: Type, span: proc_macro2::Span) -> syn::PatIdent {
let id = syn::Ident::new(
&format!("a{}", self.counter),
subpat.span().resolved_at(proc_macro2::Span::mixed_site()),
);
self.counter += 1;
self.binds.push((id.clone(), subpat, typ, span));
syn::PatIdent {
attrs: vec![],
by_ref: None,
mutability: None,
ident: id,
subpat: None,
}
}
}

impl syn::fold::Fold for MyFold {
fn fold_pat(&mut self, i: syn::Pat) -> syn::Pat {
if let syn::Pat::Macro(expr_macro) = i {
let span = expr_macro.mac.path.span();
if let Some(typ @ ("deref" | "owned" | "stamp")) = expr_macro
.mac
.path
.get_ident()
.map(ToString::to_string)
.as_deref()
{
match syn::parse2::<PatSingle>(expr_macro.mac.tokens) {
Ok(PatSingle(subpat)) => {
let typ = typ.parse().unwrap();
let pat_ident = self.handle(subpat, typ, span);
syn::Pat::Ident(pat_ident)
}
Err(err) => {
self.diagnostics.push(err.into());
syn::parse_quote_spanned!(span=> _error) // Error placeholder pattern.
}
}
} else {
syn::Pat::Macro(syn::fold::fold_expr_macro(self, expr_macro))
}
} else {
syn::fold::fold_pat(self, i)
}
}

fn fold_pat_ident(&mut self, i: syn::PatIdent) -> syn::PatIdent {
if i.by_ref.is_some() || i.mutability.is_some() || i.ident != "Deref" {
syn::fold::fold_pat_ident(self, i)
} else if let Some(subpat) = i.subpat {
let id = syn::Ident::new(
&format!("a{}", self.counter),
proc_macro2::Span::mixed_site(),
);
self.counter += 1;
self.binds.push((id.clone(), *subpat.1));
syn::PatIdent {
attrs: vec![],
by_ref: None,
mutability: None,
ident: id,
subpat: None,
}
} else if let Some((_at, subpat)) = i.subpat {
self.handle(*subpat, Type::Deref, i.ident.span())
} else {
syn::fold::fold_pat_ident(self, i)
}
}
}

fn tower(binds: &[(syn::Ident, syn::Pat)], yes: syn::Expr, no: &syn::Expr) -> syn::Expr {
fn tower(
binds: &[(syn::Ident, syn::Pat, Type, proc_macro2::Span)],
yes: syn::Expr,
no: &syn::Expr,
add_ref: bool,
) -> syn::Expr {
if binds.is_empty() {
yes
} else {
let id = &binds[0].0;
let pat = &binds[0].1;
let rec = tower(&binds[1..], yes, no);
syn::parse_quote! {
if let #pat = ::core::ops::Deref::deref(#id) {
let (ref id, ref pat, mut typ, span) = binds[0];
let rec = tower(&binds[1..], yes, no, add_ref);
if add_ref {
typ = typ.add_ref();
}
let op = typ.as_op(span);
syn::parse_quote_spanned! {span=>
if let #pat = #op #id {
#rec
} else {
#no
Expand All @@ -49,48 +141,45 @@ fn tower(binds: &[(syn::Ident, syn::Pat)], yes: syn::Expr, no: &syn::Expr) -> sy
}
}

fn do_match_deref(mut m: syn::ExprMatch) -> syn::ExprMatch {
fn matchbox_impl(mut m: syn::ExprMatch) -> syn::ExprMatch {
let mut new_arms = vec![];
for mut arm in m.arms {
use syn::fold::Fold;
let mut my_fold = MyFold {
binds: vec![],
counter: 0,
};

let span = arm.pat.span();
let mut my_fold = MyFold::default();
arm.pat = my_fold.fold_pat(arm.pat);
{
// recurse only after top layer is complete.
let mut i = 0;
while i < my_fold.binds.len() {
let a = std::mem::replace(
&mut my_fold.binds[i].1,
syn::Pat::Verbatim(Default::default()),
syn::Pat::Verbatim(quote::quote_spanned!(span=> )), // Temp placeholder
);
my_fold.binds[i].1 = my_fold.fold_pat(a);
i += 1;
}
}
if !my_fold.binds.is_empty() {
if let Some((if_token, src_guard)) = arm.guard {
let t = tower(&my_fold.binds, *src_guard, &syn::parse_quote! { false });
arm.guard = Some((
if_token,
Box::new(syn::parse_quote! { { #[allow(unused_variables)] #t } }),
));
let (yes, no) = if let Some((_if_token, src_guard)) = arm.guard {
(*src_guard, syn::parse_quote_spanned! {span=> false })
} else {
let t = tower(
&my_fold.binds,
syn::parse_quote! { true },
&syn::parse_quote! { false },
);
arm.guard = Some((
Default::default(),
Box::new(syn::parse_quote! { { #[allow(unused_variables)] #t } }),
));
}
(
syn::parse_quote_spanned! {span=> true },
syn::parse_quote_spanned! {span=> false },
)
};
let t = tower(&my_fold.binds, yes, &no, true);
arm.guard = Some((
syn::Token![if](span),
Box::new(syn::parse_quote_spanned! {span=> { #[allow(unused_variables)] #t } }),
));
*arm.body = tower(
&my_fold.binds,
*arm.body,
&syn::parse_quote! { panic!("Two invocations of Deref::deref returned different outputs on same inputs") },
&syn::parse_quote_spanned! {span=> panic!("Two invocations of Deref::deref returned different outputs on same inputs") },
false,
);
}
new_arms.push(arm);
Expand All @@ -101,7 +190,7 @@ fn do_match_deref(mut m: syn::ExprMatch) -> syn::ExprMatch {

/// See [crate].
#[proc_macro]
pub fn match_deref(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
let a = do_match_deref(syn::parse_macro_input!(tokens as syn::ExprMatch));
pub fn matchbox(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
let a = matchbox_impl(syn::parse_macro_input!(tokens as syn::ExprMatch));
quote::quote! { #a }.into()
}
22 changes: 11 additions & 11 deletions src/snapshots/matchbox__test__basic-2.snap
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
---
source: src/test.rs
expression: "snapshot!\n{ match () { Deref @ (Deref @ x,) => (), Deref @ (Deref @ x,) => () } }"
expression: "snapshot!\n{ match () { deref!((deref!(x),)) => {} Deref @ (Deref @ x,) => {} } }"
---
fn main() {
match () {
a0 if {
#[allow(unused_variables)]
if let (a1,) = ::core::ops::Deref::deref(a0) {
if let x = ::core::ops::Deref::deref(a1) { true } else { false }
if let (a1,) = &**a0 {
if let x = &**a1 { true } else { false }
} else {
false
}
} => {
if let (a1,) = ::core::ops::Deref::deref(a0) {
if let x = ::core::ops::Deref::deref(a1) {
()
if let (a1,) = &**a0 {
if let x = &**a1 {
{}
} else {
panic!(
"Two invocations of Deref::deref returned different outputs on same inputs",
Expand All @@ -28,15 +28,15 @@ fn main() {
}
a0 if {
#[allow(unused_variables)]
if let (a1,) = ::core::ops::Deref::deref(a0) {
if let x = ::core::ops::Deref::deref(a1) { true } else { false }
if let (a1,) = &**a0 {
if let x = &**a1 { true } else { false }
} else {
false
}
} => {
if let (a1,) = ::core::ops::Deref::deref(a0) {
if let x = ::core::ops::Deref::deref(a1) {
()
if let (a1,) = &**a0 {
if let x = &**a1 {
{}
} else {
panic!(
"Two invocations of Deref::deref returned different outputs on same inputs",
Expand Down
Loading

0 comments on commit 070dee1

Please sign in to comment.