Skip to content
Snippets Groups Projects
Commit 96f6a75b authored by Jason Volk's avatar Jason Volk
Browse files

add refutable pattern function macro


Signed-off-by: default avatarJason Volk <jason@zemos.net>
parent 68f42baf
No related branches found
No related tags found
1 merge request!530de-global services
mod admin; mod admin;
mod cargo; mod cargo;
mod debug; mod debug;
mod refutable;
mod rustc; mod rustc;
mod utils; mod utils;
...@@ -19,3 +20,6 @@ pub fn recursion_depth(args: TokenStream, input: TokenStream) -> TokenStream { d ...@@ -19,3 +20,6 @@ pub fn recursion_depth(args: TokenStream, input: TokenStream) -> TokenStream { d
#[proc_macro] #[proc_macro]
pub fn rustc_flags_capture(args: TokenStream) -> TokenStream { rustc::flags_capture(args) } pub fn rustc_flags_capture(args: TokenStream) -> TokenStream { rustc::flags_capture(args) }
#[proc_macro_attribute]
pub fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { refutable::refutable(args, input) }
use proc_macro::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, AttributeArgs, FnArg::Typed, Ident, ItemFn, Pat, PatIdent, PatType, Stmt};
pub(super) fn refutable(args: TokenStream, input: TokenStream) -> TokenStream {
let _args = parse_macro_input!(args as AttributeArgs);
let mut item = parse_macro_input!(input as ItemFn);
let inputs = item.sig.inputs.clone();
let stmt = &mut item.block.stmts;
let sig = &mut item.sig;
for (i, input) in inputs.iter().enumerate() {
let Typed(PatType {
pat,
..
}) = input
else {
continue;
};
let Pat::Struct(ref pat) = **pat else {
continue;
};
let variant = &pat.path;
let fields = &pat.fields;
// new versions of syn can replace this kronecker kludge with get_mut()
for (j, input) in sig.inputs.iter_mut().enumerate() {
if i != j {
continue;
}
let Typed(PatType {
ref mut pat,
..
}) = input
else {
continue;
};
let name = format!("_args_{i}");
*pat = Box::new(Pat::Ident(PatIdent {
ident: Ident::new(&name, Span::call_site().into()),
attrs: Vec::new(),
by_ref: None,
mutability: None,
subpat: None,
}));
let field = fields.iter();
let refute = quote! {
let #variant { #( #field ),*, .. } = #name else { panic!("incorrect variant passed to function argument {i}"); };
};
stmt.insert(0, syn::parse2::<Stmt>(refute).expect("syntax error"));
}
}
item.into_token_stream().into()
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment