From 96f6a75bc82a9dabd25f84b917b5358244bd55a4 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Fri, 26 Jul 2024 06:13:30 +0000 Subject: [PATCH] add refutable pattern function macro Signed-off-by: Jason Volk <jason@zemos.net> --- src/macros/mod.rs | 4 +++ src/macros/refutable.rs | 61 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 src/macros/refutable.rs diff --git a/src/macros/mod.rs b/src/macros/mod.rs index b01e5275a..a0e61324a 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -1,6 +1,7 @@ mod admin; mod cargo; mod debug; +mod refutable; mod rustc; mod utils; @@ -19,3 +20,6 @@ pub fn recursion_depth(args: TokenStream, input: TokenStream) -> TokenStream { d #[proc_macro] pub fn rustc_flags_capture(args: TokenStream) -> TokenStream { rustc::flags_capture(args) } + +#[proc_macro_attribute] +pub fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { refutable::refutable(args, input) } diff --git a/src/macros/refutable.rs b/src/macros/refutable.rs new file mode 100644 index 000000000..6a6884e0e --- /dev/null +++ b/src/macros/refutable.rs @@ -0,0 +1,61 @@ +use proc_macro::{Span, TokenStream}; +use quote::{quote, ToTokens}; +use syn::{parse_macro_input, AttributeArgs, FnArg::Typed, Ident, ItemFn, Pat, PatIdent, PatType, Stmt}; + +pub(super) fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { + let _args = parse_macro_input!(args as AttributeArgs); + let mut item = parse_macro_input!(input as ItemFn); + + let inputs = item.sig.inputs.clone(); + let stmt = &mut item.block.stmts; + let sig = &mut item.sig; + for (i, input) in inputs.iter().enumerate() { + let Typed(PatType { + pat, + .. + }) = input + else { + continue; + }; + + let Pat::Struct(ref pat) = **pat else { + continue; + }; + + let variant = &pat.path; + let fields = &pat.fields; + + // new versions of syn can replace this kronecker kludge with get_mut() + for (j, input) in sig.inputs.iter_mut().enumerate() { + if i != j { + continue; + } + + let Typed(PatType { + ref mut pat, + .. + }) = input + else { + continue; + }; + + let name = format!("_args_{i}"); + *pat = Box::new(Pat::Ident(PatIdent { + ident: Ident::new(&name, Span::call_site().into()), + attrs: Vec::new(), + by_ref: None, + mutability: None, + subpat: None, + })); + + let field = fields.iter(); + let refute = quote! { + let #variant { #( #field ),*, .. } = #name else { panic!("incorrect variant passed to function argument {i}"); }; + }; + + stmt.insert(0, syn::parse2::<Stmt>(refute).expect("syntax error")); + } + } + + item.into_token_stream().into() +} -- GitLab