diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 1b70fdec387f22be71ec5b0185a711c49d5720d0..c25c08d101f2e71747022b8cd71f6d1318059c8d 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1273,16 +1273,12 @@ async fn make_join_request( make_join_counter = make_join_counter.saturating_add(1); if let Err(ref e) = make_join_response { - trace!("make_join ErrorKind string: {:?}", e.error_code().to_string()); + trace!("make_join ErrorKind string: {:?}", e.kind().to_string()); // converting to a string is necessary (i think) because ruma is forcing us to // fill in the struct for M_INCOMPATIBLE_ROOM_VERSION - if e.error_code() - .to_string() - .contains("M_INCOMPATIBLE_ROOM_VERSION") - || e.error_code() - .to_string() - .contains("M_UNSUPPORTED_ROOM_VERSION") + if e.kind().to_string().contains("M_INCOMPATIBLE_ROOM_VERSION") + || e.kind().to_string().contains("M_UNSUPPORTED_ROOM_VERSION") { incompatible_room_version_count = incompatible_room_version_count.saturating_add(1); } diff --git a/src/core/error.rs b/src/core/error.rs deleted file mode 100644 index 0146021757333dea98306db28d270dc284f42db3..0000000000000000000000000000000000000000 --- a/src/core/error.rs +++ /dev/null @@ -1,384 +0,0 @@ -use std::{ - any::Any, - convert::Infallible, - fmt, - panic::{RefUnwindSafe, UnwindSafe}, -}; - -use bytes::BytesMut; -use http::StatusCode; -use http_body_util::Full; -use ruma::{ - api::{client::uiaa::UiaaResponse, OutgoingResponse}, - OwnedServerName, -}; - -use crate::{debug::panic_str, debug_error, error}; - -#[macro_export] -macro_rules! err { - (error!($($args:tt),+)) => {{ - $crate::error!($($args),+); - $crate::error::Error::Err(std::format!($($args),+)) - }}; - - (debug_error!($($args:tt),+)) => {{ - $crate::debug_error!($($args),+); - $crate::error::Error::Err(std::format!($($args),+)) - }}; - - ($variant:ident(error!($($args:tt),+))) => {{ - $crate::error!($($args),+); - $crate::error::Error::$variant(std::format!($($args),+)) - }}; - - ($variant:ident(debug_error!($($args:tt),+))) => {{ - $crate::debug_error!($($args),+); - $crate::error::Error::$variant(std::format!($($args),+)) - }}; - - (Config($item:literal, $($args:tt),+)) => {{ - $crate::error!(config = %$item, $($args),+); - $crate::error::Error::Config($item, std::format!($($args),+)) - }}; - - ($variant:ident($($args:tt),+)) => { - $crate::error::Error::$variant(std::format!($($args),+)) - }; - - ($string:literal$(,)? $($args:tt),*) => { - $crate::error::Error::Err(std::format!($string, $($args),*)) - }; -} - -#[macro_export] -macro_rules! Err { - ($($args:tt)*) => { - Err($crate::err!($($args)*)) - }; -} - -#[derive(thiserror::Error)] -pub enum Error { - #[error("PANIC!")] - PanicAny(Box<dyn Any + Send>), - #[error("PANIC! {0}")] - Panic(&'static str, Box<dyn Any + Send + 'static>), - - // std - #[error("{0}")] - Fmt(#[from] fmt::Error), - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - #[error("{0}")] - Utf8Error(#[from] std::str::Utf8Error), - #[error("{0}")] - FromUtf8Error(#[from] std::string::FromUtf8Error), - #[error("{0}")] - TryFromSliceError(#[from] std::array::TryFromSliceError), - #[error("{0}")] - TryFromIntError(#[from] std::num::TryFromIntError), - #[error("{0}")] - ParseIntError(#[from] std::num::ParseIntError), - #[error("{0}")] - ParseFloatError(#[from] std::num::ParseFloatError), - - // third-party - #[error("Join error: {0}")] - JoinError(#[from] tokio::task::JoinError), - #[error("Regex error: {0}")] - Regex(#[from] regex::Error), - #[error("Tracing filter error: {0}")] - TracingFilter(#[from] tracing_subscriber::filter::ParseError), - #[error("Tracing reload error: {0}")] - TracingReload(#[from] tracing_subscriber::reload::Error), - #[error("Image error: {0}")] - Image(#[from] image::error::ImageError), - #[error("Request error: {0}")] - Reqwest(#[from] reqwest::Error), - #[error("{0}")] - Extension(#[from] axum::extract::rejection::ExtensionRejection), - #[error("{0}")] - Path(#[from] axum::extract::rejection::PathRejection), - #[error("{0}")] - Http(#[from] http::Error), - - // ruma - #[error("{0}")] - IntoHttpError(#[from] ruma::api::error::IntoHttpError), - #[error("{0}")] - RumaError(#[from] ruma::api::client::error::Error), - #[error("uiaa")] - Uiaa(ruma::api::client::uiaa::UiaaInfo), - #[error("{0}")] - Mxid(#[from] ruma::IdParseError), - #[error("{0}: {1}")] - BadRequest(ruma::api::client::error::ErrorKind, &'static str), - #[error("from {0}: {1}")] - Redaction(OwnedServerName, ruma::canonical_json::RedactionError), - #[error("Remote server {0} responded with: {1}")] - Federation(OwnedServerName, ruma::api::client::error::Error), - #[error("{0} in {1}")] - InconsistentRoomState(&'static str, ruma::OwnedRoomId), - - // conduwuit - #[error("Arithmetic operation failed: {0}")] - Arithmetic(&'static str), - #[error("There was a problem with the '{0}' directive in your configuration: {1}")] - Config(&'static str, String), - #[error("{0}")] - BadDatabase(&'static str), - #[error("{0}")] - Database(String), - #[error("{0}")] - BadServerResponse(&'static str), - #[error("{0}")] - Conflict(&'static str), // This is only needed for when a room alias already exists - - // unique / untyped - #[error("{0}")] - Err(String), -} - -impl Error { - pub fn bad_database(message: &'static str) -> Self { - error!("BadDatabase: {}", message); - Self::BadDatabase(message) - } - - #[must_use] - pub fn from_panic(e: Box<dyn Any + Send>) -> Self { Self::Panic(panic_str(&e), e) } - - pub fn into_panic(self) -> Box<dyn Any + Send + 'static> { - match self { - Self::Panic(_, e) | Self::PanicAny(e) => e, - Self::JoinError(e) => e.into_panic(), - _ => std::panic::panic_any(self), - } - } - - /// Sanitizes public-facing errors that can leak sensitive information. - pub fn sanitized_string(&self) -> String { - match self { - Self::Database(..) => String::from("Database error occurred."), - Self::Io(..) => String::from("I/O error occurred."), - _ => self.to_string(), - } - } - - /// Get the panic message string. - pub fn panic_str(self) -> Option<&'static str> { self.is_panic().then_some(panic_str(&self.into_panic())) } - - /// Returns the Matrix error code / error kind - #[inline] - pub fn error_code(&self) -> ruma::api::client::error::ErrorKind { - use ruma::api::client::error::ErrorKind::Unknown; - - match self { - Self::Federation(_, error) => ruma_error_kind(error).clone(), - Self::BadRequest(kind, _) => kind.clone(), - _ => Unknown, - } - } - - /// Check if the Error is trafficking a panic object. - #[inline] - pub fn is_panic(&self) -> bool { - match &self { - Self::Panic(..) | Self::PanicAny(..) => true, - Self::JoinError(e) => e.is_panic(), - _ => false, - } - } -} - -impl UnwindSafe for Error {} -impl RefUnwindSafe for Error {} - -impl From<Infallible> for Error { - #[cold] - #[inline(never)] - fn from(_i: Infallible) -> Self { - panic!("infallible error should never exist"); - } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") } -} - -#[inline] -pub fn else_log<T, E>(error: E) -> Result<T, Infallible> -where - T: Default, - Error: From<E>, -{ - Ok(default_log(error)) -} - -#[inline] -pub fn else_debug_log<T, E>(error: E) -> Result<T, Infallible> -where - T: Default, - Error: From<E>, -{ - Ok(default_debug_log(error)) -} - -#[inline] -pub fn default_log<T, E>(error: E) -> T -where - T: Default, - Error: From<E>, -{ - let error = Error::from(error); - inspect_log(&error); - T::default() -} - -#[inline] -pub fn default_debug_log<T, E>(error: E) -> T -where - T: Default, - Error: From<E>, -{ - let error = Error::from(error); - inspect_debug_log(&error); - T::default() -} - -#[inline] -pub fn map_log<E>(error: E) -> Error -where - Error: From<E>, -{ - let error = Error::from(error); - inspect_log(&error); - error -} - -#[inline] -pub fn map_debug_log<E>(error: E) -> Error -where - Error: From<E>, -{ - let error = Error::from(error); - inspect_debug_log(&error); - error -} - -#[inline] -pub fn inspect_log<E: fmt::Display>(error: &E) { - error!("{error}"); -} - -#[inline] -pub fn inspect_debug_log<E: fmt::Debug>(error: &E) { - debug_error!("{error:?}"); -} - -#[cold] -#[inline(never)] -pub fn infallible(_e: &Infallible) { - panic!("infallible error should never exist"); -} - -impl axum::response::IntoResponse for Error { - fn into_response(self) -> axum::response::Response { - let response: UiaaResponse = self.into(); - response - .try_into_http_response::<BytesMut>() - .inspect_err(|e| error!("error response error: {e}")) - .map_or_else( - |_| StatusCode::INTERNAL_SERVER_ERROR.into_response(), - |r| r.map(BytesMut::freeze).map(Full::new).into_response(), - ) - } -} - -impl From<Error> for UiaaResponse { - fn from(error: Error) -> Self { - if let Error::Uiaa(uiaainfo) = error { - return Self::AuthResponse(uiaainfo); - } - - let kind = match &error { - Error::Federation(_, ref error) | Error::RumaError(ref error) => ruma_error_kind(error), - Error::BadRequest(kind, _) => kind, - _ => &ruma::api::client::error::ErrorKind::Unknown, - }; - - let status_code = match &error { - Error::Federation(_, ref error) | Error::RumaError(ref error) => error.status_code, - Error::BadRequest(ref kind, _) => bad_request_code(kind), - Error::Conflict(_) => StatusCode::CONFLICT, - _ => StatusCode::INTERNAL_SERVER_ERROR, - }; - - let message = match &error { - Error::Federation(ref origin, ref error) => format!("Answer from {origin}: {error}"), - Error::RumaError(ref error) => ruma_error_message(error), - _ => format!("{error}"), - }; - - let body = ruma::api::client::error::ErrorBody::Standard { - kind: kind.clone(), - message, - }; - - Self::MatrixError(ruma::api::client::error::Error { - status_code, - body, - }) - } -} - -fn bad_request_code(kind: &ruma::api::client::error::ErrorKind) -> StatusCode { - use ruma::api::client::error::ErrorKind::*; - - match kind { - GuestAccessForbidden - | ThreepidAuthFailed - | UserDeactivated - | ThreepidDenied - | WrongRoomKeysVersion { - .. - } - | Forbidden { - .. - } => StatusCode::FORBIDDEN, - - UnknownToken { - .. - } - | MissingToken - | Unauthorized => StatusCode::UNAUTHORIZED, - - LimitExceeded { - .. - } => StatusCode::TOO_MANY_REQUESTS, - - TooLarge => StatusCode::PAYLOAD_TOO_LARGE, - - NotFound | Unrecognized => StatusCode::NOT_FOUND, - - _ => StatusCode::BAD_REQUEST, - } -} - -fn ruma_error_message(error: &ruma::api::client::error::Error) -> String { - if let ruma::api::client::error::ErrorBody::Standard { - message, - .. - } = &error.body - { - return message.to_string(); - } - - format!("{error}") -} - -fn ruma_error_kind(e: &ruma::api::client::error::Error) -> &ruma::api::client::error::ErrorKind { - e.error_kind() - .unwrap_or(&ruma::api::client::error::ErrorKind::Unknown) -} diff --git a/src/core/error/err.rs b/src/core/error/err.rs new file mode 100644 index 0000000000000000000000000000000000000000..12dd4f3b717c098e7a24119b47c9017696db96dc --- /dev/null +++ b/src/core/error/err.rs @@ -0,0 +1,42 @@ +#[macro_export] +macro_rules! Err { + ($($args:tt)*) => { + Err($crate::err!($($args)*)) + }; +} + +#[macro_export] +macro_rules! err { + (error!($($args:tt),+)) => {{ + $crate::error!($($args),+); + $crate::error::Error::Err(std::format!($($args),+)) + }}; + + (debug_error!($($args:tt),+)) => {{ + $crate::debug_error!($($args),+); + $crate::error::Error::Err(std::format!($($args),+)) + }}; + + ($variant:ident(error!($($args:tt),+))) => {{ + $crate::error!($($args),+); + $crate::error::Error::$variant(std::format!($($args),+)) + }}; + + ($variant:ident(debug_error!($($args:tt),+))) => {{ + $crate::debug_error!($($args),+); + $crate::error::Error::$variant(std::format!($($args),+)) + }}; + + (Config($item:literal, $($args:tt),+)) => {{ + $crate::error!(config = %$item, $($args),+); + $crate::error::Error::Config($item, std::format!($($args),+)) + }}; + + ($variant:ident($($args:tt),+)) => { + $crate::error::Error::$variant(std::format!($($args),+)) + }; + + ($string:literal$(,)? $($args:tt),*) => { + $crate::error::Error::Err(std::format!($string, $($args),*)) + }; +} diff --git a/src/core/error/log.rs b/src/core/error/log.rs new file mode 100644 index 0000000000000000000000000000000000000000..c272bf730c42ede2dd18d18e313546bdb18bf241 --- /dev/null +++ b/src/core/error/log.rs @@ -0,0 +1,74 @@ +use std::{convert::Infallible, fmt}; + +use super::Error; +use crate::{debug_error, error}; + +#[inline] +pub fn else_log<T, E>(error: E) -> Result<T, Infallible> +where + T: Default, + Error: From<E>, +{ + Ok(default_log(error)) +} + +#[inline] +pub fn else_debug_log<T, E>(error: E) -> Result<T, Infallible> +where + T: Default, + Error: From<E>, +{ + Ok(default_debug_log(error)) +} + +#[inline] +pub fn default_log<T, E>(error: E) -> T +where + T: Default, + Error: From<E>, +{ + let error = Error::from(error); + inspect_log(&error); + T::default() +} + +#[inline] +pub fn default_debug_log<T, E>(error: E) -> T +where + T: Default, + Error: From<E>, +{ + let error = Error::from(error); + inspect_debug_log(&error); + T::default() +} + +#[inline] +pub fn map_log<E>(error: E) -> Error +where + Error: From<E>, +{ + let error = Error::from(error); + inspect_log(&error); + error +} + +#[inline] +pub fn map_debug_log<E>(error: E) -> Error +where + Error: From<E>, +{ + let error = Error::from(error); + inspect_debug_log(&error); + error +} + +#[inline] +pub fn inspect_log<E: fmt::Display>(error: &E) { + error!("{error}"); +} + +#[inline] +pub fn inspect_debug_log<E: fmt::Debug>(error: &E) { + debug_error!("{error:?}"); +} diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..cbc22f3e7ab4b67eb0fc5b7167fb42c3babf6845 --- /dev/null +++ b/src/core/error/mod.rs @@ -0,0 +1,156 @@ +mod err; +mod log; +mod panic; +mod response; + +use std::{any::Any, convert::Infallible, fmt}; + +pub use log::*; + +use crate::error; + +#[derive(thiserror::Error)] +pub enum Error { + #[error("PANIC!")] + PanicAny(Box<dyn Any + Send>), + #[error("PANIC! {0}")] + Panic(&'static str, Box<dyn Any + Send + 'static>), + + // std + #[error("{0}")] + Fmt(#[from] fmt::Error), + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("{0}")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("{0}")] + FromUtf8Error(#[from] std::string::FromUtf8Error), + #[error("{0}")] + TryFromSliceError(#[from] std::array::TryFromSliceError), + #[error("{0}")] + TryFromIntError(#[from] std::num::TryFromIntError), + #[error("{0}")] + ParseIntError(#[from] std::num::ParseIntError), + #[error("{0}")] + ParseFloatError(#[from] std::num::ParseFloatError), + + // third-party + #[error("Join error: {0}")] + JoinError(#[from] tokio::task::JoinError), + #[error("Regex error: {0}")] + Regex(#[from] regex::Error), + #[error("Tracing filter error: {0}")] + TracingFilter(#[from] tracing_subscriber::filter::ParseError), + #[error("Tracing reload error: {0}")] + TracingReload(#[from] tracing_subscriber::reload::Error), + #[error("Image error: {0}")] + Image(#[from] image::error::ImageError), + #[error("Request error: {0}")] + Reqwest(#[from] reqwest::Error), + #[error("{0}")] + Extension(#[from] axum::extract::rejection::ExtensionRejection), + #[error("{0}")] + Path(#[from] axum::extract::rejection::PathRejection), + #[error("{0}")] + Http(#[from] http::Error), + + // ruma + #[error("{0}")] + IntoHttpError(#[from] ruma::api::error::IntoHttpError), + #[error("{0}")] + RumaError(#[from] ruma::api::client::error::Error), + #[error("uiaa")] + Uiaa(ruma::api::client::uiaa::UiaaInfo), + #[error("{0}")] + Mxid(#[from] ruma::IdParseError), + #[error("{0}: {1}")] + BadRequest(ruma::api::client::error::ErrorKind, &'static str), + #[error("from {0}: {1}")] + Redaction(ruma::OwnedServerName, ruma::canonical_json::RedactionError), + #[error("Remote server {0} responded with: {1}")] + Federation(ruma::OwnedServerName, ruma::api::client::error::Error), + #[error("{0} in {1}")] + InconsistentRoomState(&'static str, ruma::OwnedRoomId), + + // conduwuit + #[error("Arithmetic operation failed: {0}")] + Arithmetic(&'static str), + #[error("There was a problem with the '{0}' directive in your configuration: {1}")] + Config(&'static str, String), + #[error("{0}")] + BadDatabase(&'static str), + #[error("{0}")] + Database(String), + #[error("{0}")] + BadServerResponse(&'static str), + #[error("{0}")] + Conflict(&'static str), // This is only needed for when a room alias already exists + + // unique / untyped + #[error("{0}")] + Err(String), +} + +impl Error { + pub fn bad_database(message: &'static str) -> Self { + error!("BadDatabase: {}", message); + Self::BadDatabase(message) + } + + /// Sanitizes public-facing errors that can leak sensitive information. + pub fn sanitized_string(&self) -> String { + match self { + Self::Database(..) => String::from("Database error occurred."), + Self::Io(..) => String::from("I/O error occurred."), + _ => self.to_string(), + } + } + + pub fn message(&self) -> String { + match self { + Self::Federation(ref origin, ref error) => format!("Answer from {origin}: {error}"), + Self::RumaError(ref error) => response::ruma_error_message(error), + _ => format!("{self}"), + } + } + + /// Returns the Matrix error code / error kind + #[inline] + pub fn kind(&self) -> ruma::api::client::error::ErrorKind { + use ruma::api::client::error::ErrorKind::Unknown; + + match self { + Self::Federation(_, error) => response::ruma_error_kind(error).clone(), + Self::BadRequest(kind, _) => kind.clone(), + _ => Unknown, + } + } + + pub fn status_code(&self) -> http::StatusCode { + match self { + Self::Federation(_, ref error) | Self::RumaError(ref error) => error.status_code, + Self::BadRequest(ref kind, _) => response::bad_request_code(kind), + Self::Conflict(_) => http::StatusCode::CONFLICT, + _ => http::StatusCode::INTERNAL_SERVER_ERROR, + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") } +} + +#[allow(clippy::fallible_impl_from)] +impl From<Infallible> for Error { + #[cold] + #[inline(never)] + fn from(_e: Infallible) -> Self { + panic!("infallible error should never exist"); + } +} + +#[cold] +#[inline(never)] +pub fn infallible(_e: &Infallible) { + panic!("infallible error should never exist"); +} diff --git a/src/core/error/panic.rs b/src/core/error/panic.rs new file mode 100644 index 0000000000000000000000000000000000000000..c070f78669ddbd484a5c8835e2bc345073dfe634 --- /dev/null +++ b/src/core/error/panic.rs @@ -0,0 +1,41 @@ +use std::{ + any::Any, + panic::{panic_any, RefUnwindSafe, UnwindSafe}, +}; + +use super::Error; +use crate::debug; + +impl UnwindSafe for Error {} +impl RefUnwindSafe for Error {} + +impl Error { + pub fn panic(self) -> ! { panic_any(self.into_panic()) } + + #[must_use] + pub fn from_panic(e: Box<dyn Any + Send>) -> Self { Self::Panic(debug::panic_str(&e), e) } + + pub fn into_panic(self) -> Box<dyn Any + Send + 'static> { + match self { + Self::Panic(_, e) | Self::PanicAny(e) => e, + Self::JoinError(e) => e.into_panic(), + _ => Box::new(self), + } + } + + /// Get the panic message string. + pub fn panic_str(self) -> Option<&'static str> { + self.is_panic() + .then_some(debug::panic_str(&self.into_panic())) + } + + /// Check if the Error is trafficking a panic object. + #[inline] + pub fn is_panic(&self) -> bool { + match &self { + Self::Panic(..) | Self::PanicAny(..) => true, + Self::JoinError(e) => e.is_panic(), + _ => false, + } + } +} diff --git a/src/core/error/response.rs b/src/core/error/response.rs new file mode 100644 index 0000000000000000000000000000000000000000..4ea76e264072786255307af24c832afb8165491d --- /dev/null +++ b/src/core/error/response.rs @@ -0,0 +1,88 @@ +use bytes::BytesMut; +use http::StatusCode; +use http_body_util::Full; +use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; + +use super::Error; +use crate::error; + +impl axum::response::IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + let response: UiaaResponse = self.into(); + response + .try_into_http_response::<BytesMut>() + .inspect_err(|e| error!("error response error: {e}")) + .map_or_else( + |_| StatusCode::INTERNAL_SERVER_ERROR.into_response(), + |r| r.map(BytesMut::freeze).map(Full::new).into_response(), + ) + } +} + +impl From<Error> for UiaaResponse { + fn from(error: Error) -> Self { + if let Error::Uiaa(uiaainfo) = error { + return Self::AuthResponse(uiaainfo); + } + + let body = ruma::api::client::error::ErrorBody::Standard { + kind: error.kind(), + message: error.message(), + }; + + Self::MatrixError(ruma::api::client::error::Error { + status_code: error.status_code(), + body, + }) + } +} + +pub(super) fn bad_request_code(kind: &ruma::api::client::error::ErrorKind) -> StatusCode { + use ruma::api::client::error::ErrorKind::*; + + match kind { + GuestAccessForbidden + | ThreepidAuthFailed + | UserDeactivated + | ThreepidDenied + | WrongRoomKeysVersion { + .. + } + | Forbidden { + .. + } => StatusCode::FORBIDDEN, + + UnknownToken { + .. + } + | MissingToken + | Unauthorized => StatusCode::UNAUTHORIZED, + + LimitExceeded { + .. + } => StatusCode::TOO_MANY_REQUESTS, + + TooLarge => StatusCode::PAYLOAD_TOO_LARGE, + + NotFound | Unrecognized => StatusCode::NOT_FOUND, + + _ => StatusCode::BAD_REQUEST, + } +} + +pub(super) fn ruma_error_message(error: &ruma::api::client::error::Error) -> String { + if let ruma::api::client::error::ErrorBody::Standard { + message, + .. + } = &error.body + { + return message.to_string(); + } + + format!("{error}") +} + +pub(super) fn ruma_error_kind(e: &ruma::api::client::error::Error) -> &ruma::api::client::error::ErrorKind { + e.error_kind() + .unwrap_or(&ruma::api::client::error::ErrorKind::Unknown) +}