From 0627b46f4018e5a589bc106ebae24cb60cc91ae2 Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Wed, 10 Jul 2024 06:35:11 +0000
Subject: [PATCH] add panic suite to Error

Signed-off-by: Jason Volk <jason@zemos.net>
---
 src/api/server/send.rs |  2 +-
 src/core/debug.rs      |  7 +++--
 src/core/error.rs      | 71 ++++++++++++++++++++++++++++++++----------
 3 files changed, 60 insertions(+), 20 deletions(-)

diff --git a/src/api/server/send.rs b/src/api/server/send.rs
index 08caf1b44..122f564f2 100644
--- a/src/api/server/send.rs
+++ b/src/api/server/send.rs
@@ -84,7 +84,7 @@ pub(crate) async fn send_transaction_message_route(
 	Ok(send_transaction_message::v1::Response {
 		pdus: resolved_map
 			.into_iter()
-			.map(|(e, r)| (e, r.map_err(|e| e.sanitized_error())))
+			.map(|(e, r)| (e, r.map_err(|e| e.sanitized_string())))
 			.collect(),
 	})
 }
diff --git a/src/core/debug.rs b/src/core/debug.rs
index 1f855e520..14d0be87a 100644
--- a/src/core/debug.rs
+++ b/src/core/debug.rs
@@ -1,6 +1,4 @@
-#![allow(dead_code)] // this is a developer's toolbox
-
-use std::panic;
+use std::{any::Any, panic};
 
 /// Export all of the ancillary tools from here as well.
 pub use crate::utils::debug::*;
@@ -79,3 +77,6 @@ pub fn trap() {
 		std::arch::asm!("int3");
 	}
 }
+
+#[must_use]
+pub fn panic_str(p: &Box<dyn Any + Send>) -> &'static str { p.downcast_ref::<&str>().copied().unwrap_or_default() }
diff --git a/src/core/error.rs b/src/core/error.rs
index 7069d1e63..26bac6e06 100644
--- a/src/core/error.rs
+++ b/src/core/error.rs
@@ -1,4 +1,9 @@
-use std::{convert::Infallible, fmt};
+use std::{
+	any::Any,
+	convert::Infallible,
+	fmt,
+	panic::{RefUnwindSafe, UnwindSafe},
+};
 
 use bytes::BytesMut;
 use http::StatusCode;
@@ -8,10 +13,15 @@
 	OwnedServerName,
 };
 
-use crate::{debug_error, error};
+use crate::{debug::panic_str, debug_error, error};
 
 #[derive(thiserror::Error)]
 pub enum Error {
+	#[error("PANIC!")]
+	PanicAny(Box<dyn Any + Send>),
+	#[error("PANIC! {0}")]
+	Panic(&'static str, Box<dyn Any + Send + 'static>),
+
 	// std
 	#[error("{0}")]
 	Fmt(#[from] fmt::Error),
@@ -31,6 +41,8 @@ pub enum Error {
 	ParseFloatError(#[from] std::num::ParseFloatError),
 
 	// third-party
+	#[error("Join error: {0}")]
+	JoinError(#[from] tokio::task::JoinError),
 	#[error("Regex error: {0}")]
 	Regex(#[from] regex::Error),
 	#[error("Tracing filter error: {0}")]
@@ -94,6 +106,29 @@ pub fn bad_config(message: &str) -> Self {
 		Self::BadConfig(message.to_owned())
 	}
 
+	#[must_use]
+	pub fn from_panic(e: Box<dyn Any + Send>) -> Self { Self::Panic(panic_str(&e), e) }
+
+	pub fn into_panic(self) -> Box<dyn Any + Send + 'static> {
+		match self {
+			Self::Panic(_, e) | Self::PanicAny(e) => e,
+			Self::JoinError(e) => e.into_panic(),
+			_ => std::panic::panic_any(self),
+		}
+	}
+
+	/// Sanitizes public-facing errors that can leak sensitive information.
+	pub fn sanitized_string(&self) -> String {
+		match self {
+			Self::Database(..) => String::from("Database error occurred."),
+			Self::Io(..) => String::from("I/O error occurred."),
+			_ => self.to_string(),
+		}
+	}
+
+	/// Get the panic message string.
+	pub fn panic_str(self) -> Option<&'static str> { self.is_panic().then_some(panic_str(&self.into_panic())) }
+
 	/// Returns the Matrix error code / error kind
 	#[inline]
 	pub fn error_code(&self) -> ruma::api::client::error::ErrorKind {
@@ -106,16 +141,28 @@ pub fn error_code(&self) -> ruma::api::client::error::ErrorKind {
 		}
 	}
 
-	/// Sanitizes public-facing errors that can leak sensitive information.
-	pub fn sanitized_error(&self) -> String {
-		match self {
-			Self::Database(..) => String::from("Database error occurred."),
-			Self::Io(..) => String::from("I/O error occurred."),
-			_ => self.to_string(),
+	/// Check if the Error is trafficking a panic object.
+	#[inline]
+	pub fn is_panic(&self) -> bool {
+		match &self {
+			Self::Panic(..) | Self::PanicAny(..) => true,
+			Self::JoinError(e) => e.is_panic(),
+			_ => false,
 		}
 	}
 }
 
+impl UnwindSafe for Error {}
+impl RefUnwindSafe for Error {}
+
+impl From<Infallible> for Error {
+	fn from(i: Infallible) -> Self { match i {} }
+}
+
+impl fmt::Debug for Error {
+	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") }
+}
+
 #[inline]
 pub fn else_log<T, E>(error: E) -> Result<T, Infallible>
 where
@@ -186,14 +233,6 @@ pub fn inspect_debug_log<E: fmt::Debug>(error: &E) {
 	debug_error!("{error:?}");
 }
 
-impl fmt::Debug for Error {
-	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") }
-}
-
-impl From<Infallible> for Error {
-	fn from(i: Infallible) -> Self { match i {} }
-}
-
 impl axum::response::IntoResponse for Error {
 	fn into_response(self) -> axum::response::Response {
 		let response: UiaaResponse = self.into();
-- 
GitLab