From 96f6a75bc82a9dabd25f84b917b5358244bd55a4 Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Fri, 26 Jul 2024 06:13:30 +0000
Subject: [PATCH] add refutable pattern function macro

Signed-off-by: Jason Volk <jason@zemos.net>
---
 src/macros/mod.rs       |  4 +++
 src/macros/refutable.rs | 61 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+)
 create mode 100644 src/macros/refutable.rs

diff --git a/src/macros/mod.rs b/src/macros/mod.rs
index b01e5275a..a0e61324a 100644
--- a/src/macros/mod.rs
+++ b/src/macros/mod.rs
@@ -1,6 +1,7 @@
 mod admin;
 mod cargo;
 mod debug;
+mod refutable;
 mod rustc;
 mod utils;
 
@@ -19,3 +20,6 @@ pub fn recursion_depth(args: TokenStream, input: TokenStream) -> TokenStream { d
 
 #[proc_macro]
 pub fn rustc_flags_capture(args: TokenStream) -> TokenStream { rustc::flags_capture(args) }
+
+#[proc_macro_attribute]
+pub fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { refutable::refutable(args, input) }
diff --git a/src/macros/refutable.rs b/src/macros/refutable.rs
new file mode 100644
index 000000000..6a6884e0e
--- /dev/null
+++ b/src/macros/refutable.rs
@@ -0,0 +1,61 @@
+use proc_macro::{Span, TokenStream};
+use quote::{quote, ToTokens};
+use syn::{parse_macro_input, AttributeArgs, FnArg::Typed, Ident, ItemFn, Pat, PatIdent, PatType, Stmt};
+
+pub(super) fn refutable(args: TokenStream, input: TokenStream) -> TokenStream {
+	let _args = parse_macro_input!(args as AttributeArgs);
+	let mut item = parse_macro_input!(input as ItemFn);
+
+	let inputs = item.sig.inputs.clone();
+	let stmt = &mut item.block.stmts;
+	let sig = &mut item.sig;
+	for (i, input) in inputs.iter().enumerate() {
+		let Typed(PatType {
+			pat,
+			..
+		}) = input
+		else {
+			continue;
+		};
+
+		let Pat::Struct(ref pat) = **pat else {
+			continue;
+		};
+
+		let variant = &pat.path;
+		let fields = &pat.fields;
+
+		// new versions of syn can replace this kronecker kludge with get_mut()
+		for (j, input) in sig.inputs.iter_mut().enumerate() {
+			if i != j {
+				continue;
+			}
+
+			let Typed(PatType {
+				ref mut pat,
+				..
+			}) = input
+			else {
+				continue;
+			};
+
+			let name = format!("_args_{i}");
+			*pat = Box::new(Pat::Ident(PatIdent {
+				ident: Ident::new(&name, Span::call_site().into()),
+				attrs: Vec::new(),
+				by_ref: None,
+				mutability: None,
+				subpat: None,
+			}));
+
+			let field = fields.iter();
+			let refute = quote! {
+				let #variant { #( #field ),*, .. } = #name else { panic!("incorrect variant passed to function argument {i}"); };
+			};
+
+			stmt.insert(0, syn::parse2::<Stmt>(refute).expect("syntax error"));
+		}
+	}
+
+	item.into_token_stream().into()
+}
-- 
GitLab