diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 3adfde9774d62480d96d90deb40260536b470fd4..069fe60e4bacbcdcdbfc9e39ea748731eda4170f 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -53,6 +53,8 @@ pub enum Error { Path(#[from] axum::extract::rejection::PathRejection), #[error("{0}")] Http(#[from] http::Error), + #[error("{0}")] + HttpHeader(#[from] http::header::InvalidHeaderValue), // ruma #[error("{0}")] diff --git a/src/router/layers.rs b/src/router/layers.rs index db664b38aa4779e2109c13a8f20f4ca3cd6b14f1..8c2e114b0e0ba6b21793216788780a6c8377649e 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -1,11 +1,11 @@ -use std::{any::Any, io, sync::Arc, time::Duration}; +use std::{any::Any, sync::Arc, time::Duration}; use axum::{ extract::{DefaultBodyLimit, MatchedPath}, Router, }; use axum_client_ip::SecureClientIpSource; -use conduit::Server; +use conduit::{Result, Server}; use http::{ header::{self, HeaderName}, HeaderValue, Method, StatusCode, @@ -22,11 +22,19 @@ use crate::{request, router}; -const CONDUWUIT_CSP: &str = "sandbox; default-src 'none'; font-src 'none'; script-src 'none'; frame-ancestors 'none'; \ - form-action 'none'; base-uri 'none';"; -const CONDUWUIT_PERMISSIONS_POLICY: &str = "interest-cohort=(),browsing-topics=()"; +const CONDUWUIT_CSP: &[&str] = &[ + "sandbox", + "default-src 'none'", + "font-src 'none'", + "script-src 'none'", + "frame-ancestors 'none'", + "form-action 'none'", + "base-uri 'none'", +]; -pub(crate) fn build(server: &Arc<Server>) -> io::Result<Router> { +const CONDUWUIT_PERMISSIONS_POLICY: &[&str] = &["interest-cohort=()", "browsing-topics=()"]; + +pub(crate) fn build(server: &Arc<Server>) -> Result<Router> { let layers = ServiceBuilder::new(); #[cfg(feature = "sentry_telemetry")] @@ -65,11 +73,11 @@ pub(crate) fn build(server: &Arc<Server>) -> io::Result<Router> { )) .layer(SetResponseHeaderLayer::if_not_present( HeaderName::from_static("permissions-policy"), - HeaderValue::from_static(CONDUWUIT_PERMISSIONS_POLICY), + HeaderValue::from_str(&CONDUWUIT_PERMISSIONS_POLICY.join(","))?, )) .layer(SetResponseHeaderLayer::if_not_present( header::CONTENT_SECURITY_POLICY, - HeaderValue::from_static(CONDUWUIT_CSP), + HeaderValue::from_str(&CONDUWUIT_CSP.join("; "))?, )) .layer(cors_layer(server)) .layer(body_limit_layer(server)) diff --git a/src/router/request.rs b/src/router/request.rs index 9256fb9c98e7db340a4f8f7a5c740595248192bb..851bd1684949f369460a89206f6be3875fbe3bb1 100644 --- a/src/router/request.rs +++ b/src/router/request.rs @@ -1,6 +1,9 @@ use std::sync::{atomic::Ordering, Arc}; -use axum::{extract::State, response::IntoResponse}; +use axum::{ + extract::State, + response::{IntoResponse, Response}, +}; use conduit::{debug, debug_error, debug_warn, defer, error, trace, Error, Result, Server}; use http::{Method, StatusCode, Uri}; use ruma::api::client::error::{Error as RumaError, ErrorBody, ErrorKind}; @@ -8,7 +11,7 @@ #[tracing::instrument(skip_all, level = "debug")] pub(crate) async fn spawn( State(server): State<Arc<Server>>, req: http::Request<axum::body::Body>, next: axum::middleware::Next, -) -> Result<axum::response::Response, StatusCode> { +) -> Result<Response, StatusCode> { if !server.running() { debug_warn!("unavailable pending shutdown"); return Err(StatusCode::SERVICE_UNAVAILABLE); @@ -30,7 +33,7 @@ pub(crate) async fn spawn( #[tracing::instrument(skip_all, name = "handle")] pub(crate) async fn handle( State(server): State<Arc<Server>>, req: http::Request<axum::body::Body>, next: axum::middleware::Next, -) -> Result<axum::response::Response, StatusCode> { +) -> Result<Response, StatusCode> { if !server.running() { debug_warn!( method = %req.method(), @@ -57,9 +60,7 @@ pub(crate) async fn handle( handle_result(&method, &uri, result) } -fn handle_result( - method: &Method, uri: &Uri, result: axum::response::Response, -) -> Result<axum::response::Response, StatusCode> { +fn handle_result(method: &Method, uri: &Uri, result: Response) -> Result<Response, StatusCode> { handle_result_log(method, uri, &result); match result.status() { StatusCode::METHOD_NOT_ALLOWED => handle_result_405(method, uri, &result), @@ -67,9 +68,7 @@ fn handle_result( } } -fn handle_result_405( - _method: &Method, _uri: &Uri, result: &axum::response::Response, -) -> Result<axum::response::Response, StatusCode> { +fn handle_result_405(_method: &Method, _uri: &Uri, result: &Response) -> Result<Response, StatusCode> { let error = Error::RumaError(RumaError { status_code: result.status(), body: ErrorBody::Standard { @@ -81,7 +80,7 @@ fn handle_result_405( Ok(error.into_response()) } -fn handle_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) { +fn handle_result_log(method: &Method, uri: &Uri, result: &Response) { let status = result.status(); let reason = status.canonical_reason().unwrap_or("Unknown Reason"); let code = status.as_u16();