diff --git a/src/main.rs b/src/main.rs index 6481e017e12cb886b175c1a01927fd7a8c788582..6e889175413f197aa3d7fb4ffb58516a3b855202 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ #[cfg(unix)] use std::os::unix::fs::PermissionsExt as _; /* not unix specific, just only for UNIX sockets stuff and *nix * container checks */ -use std::{io, net::SocketAddr, sync::atomic, time::Duration}; +use std::{any::Any, io, net::SocketAddr, sync::atomic, time::Duration}; use axum::{ extract::{DefaultBodyLimit, MatchedPath}, @@ -35,7 +35,7 @@ trace::{DefaultOnFailure, TraceLayer}, ServiceBuilderExt as _, }; -use tracing::{debug, error, info, warn, Level}; +use tracing::{debug, error, info, level_filters::LevelFilter, warn, Level}; use tracing_subscriber::{prelude::*, reload, EnvFilter, Registry}; mod routes; @@ -275,7 +275,6 @@ async fn build(server: &Server) -> io::Result<axum::routing::IntoMakeService<Rou let middlewares = base_middlewares .sensitive_headers([header::AUTHORIZATION]) .sensitive_request_headers([x_forwarded_for].into()) - .layer(CatchPanicLayer::new()) .layer(axum::middleware::from_fn(request_spawn)) .layer( TraceLayer::new_for_http() @@ -290,7 +289,8 @@ async fn build(server: &Server) -> io::Result<axum::routing::IntoMakeService<Rou .max_request_size .try_into() .expect("failed to convert max request size"), - )); + )) + .layer(CatchPanicLayer::custom(catch_panic_layer)); #[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] { @@ -590,3 +590,39 @@ fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { Ok(()) } + +#[allow(clippy::needless_pass_by_value)] +fn catch_panic_layer(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_util::Full<bytes::Bytes>> { + let details = if cfg!(debug_assertions) || LevelFilter::current() == LevelFilter::TRACE { + if let Some(s) = err.downcast_ref::<String>() { + s.clone() + } else if let Some(s) = err.downcast_ref::<&str>() { + s.to_string() + } else { + "Unknown internal server error occurred.".to_owned() + } + } else { + "Internal server error occurred.".to_owned() + }; + + let body = if cfg!(debug_assertions) || LevelFilter::current() == LevelFilter::TRACE { + serde_json::json!({ + "errcode": "M_UNKNOWN", + "error": "M_UNKNOWN: Internal server error occurred", + "details": details, + }) + .to_string() + } else { + serde_json::json!({ + "errcode": "M_UNKNOWN", + "error": "M_UNKNOWN: Internal server error occurred", + }) + .to_string() + }; + + http::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header(header::CONTENT_TYPE, "application/json") + .body(http_body_util::Full::from(body)) + .expect("Failed to create response for our panic catcher?") +}