Skip to content
Snippets Groups Projects
axum.rs 13.6 KiB
Newer Older
🥺's avatar
🥺 committed
use std::{collections::BTreeMap, str};
Jonas Platte's avatar
Jonas Platte committed

use axum::{
	async_trait,
	body::{Full, HttpBody},
	extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader},
	headers::{
		authorization::{Bearer, Credentials},
		Authorization,
	},
	response::{IntoResponse, Response},
	BoxError, RequestExt, RequestPartsExt,
Jonas Platte's avatar
Jonas Platte committed
};
Jonas Platte's avatar
Jonas Platte committed
use bytes::{Buf, BufMut, Bytes, BytesMut};
use http::{Request, StatusCode};
Jonas Platte's avatar
Jonas Platte committed
use ruma::{
	api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
	CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId,
Jonas Platte's avatar
Jonas Platte committed
};
use serde::Deserialize;
use tracing::{debug, error, warn};
Jonas Platte's avatar
Jonas Platte committed

use super::{Ruma, RumaResponse};
Timo Kösters's avatar
Timo Kösters committed
use crate::{services, Error, Result};
#[derive(Deserialize)]
struct QueryParams {
	access_token: Option<String>,
	user_id: Option<String>,
Jonas Platte's avatar
Jonas Platte committed
#[async_trait]
Jonas Platte's avatar
Jonas Platte committed
impl<T, S, B> FromRequest<S, B> for Ruma<T>
Jonas Platte's avatar
Jonas Platte committed
where
	T: IncomingRequest,
	B: HttpBody + Send + 'static,
	B::Data: Send,
	B::Error: Into<BoxError>,
	type Rejection = Error;

	#[allow(unused_qualifications)] // async traits
	async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
		let (mut parts, mut body) = match req.with_limited_body() {
			Ok(limited_req) => {
				let (parts, body) = limited_req.into_parts();
				let body =
					to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
				(parts, body)
			},
			Err(original_req) => {
				let (parts, body) = original_req.into_parts();
				let body =
					to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
				(parts, body)
			},
		};

		let metadata = T::METADATA;
		let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?;
		let path_params: Path<Vec<String>> = parts.extract().await?;

		let query = parts.uri.query().unwrap_or_default();
		let query_params: QueryParams = match serde_html_form::from_str(query) {
			Ok(params) => params,
			Err(e) => {
				error!(%query, "Failed to deserialize query parameters: {}", e);
				return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters"));
			},
		};

		let token = match &auth_header {
			Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()),
			None => query_params.access_token.as_deref(),
		};

		let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();

		let appservices = services().appservice.all().unwrap();
		let appservice_registration =
			appservices.iter().find(|(_id, registration)| Some(registration.as_token.as_str()) == token);

		let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) =
			appservice_registration
		{
			match metadata.authentication {
				// TODO: verify if just or'ing `AuthScheme::AppserviceToken` is correct here
				AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken => {
					let user_id = query_params.user_id.map_or_else(
						|| {
							UserId::parse_with_server_name(
								registration.sender_localpart.as_str(),
								services().globals.server_name(),
							)
							.unwrap()
						},
						|s| UserId::parse(s).unwrap(),
					);

					if !services().users.exists(&user_id)? {
						return Err(Error::BadRequest(ErrorKind::Forbidden, "User does not exist."));
					}

					// TODO: Check if appservice is allowed to be that user
					(Some(user_id), None, None, true)
				},
				AuthScheme::ServerSignatures => (None, None, None, true),
				AuthScheme::None => (None, None, None, true),
			}
		} else {
			match metadata.authentication {
				AuthScheme::AccessToken => {
					let token = match token {
						Some(token) => token,
						_ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
					};

					match services().users.find_from_token(token)? {
						None => {
							return Err(Error::BadRequest(
								ErrorKind::UnknownToken {
									soft_logout: false,
								},
								"Unknown access token.",
							))
						},
						Some((user_id, device_id)) => {
							(Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
						},
					}
				},
				AuthScheme::AccessTokenOptional => {
					let token = token.unwrap_or("");

					if token.is_empty() {
						(None, None, None, false)
					} else {
						match services().users.find_from_token(token)? {
							None => {
								return Err(Error::BadRequest(
									ErrorKind::UnknownToken {
										soft_logout: false,
									},
									"Unknown access token.",
								))
							},
							Some((user_id, device_id)) => {
								(Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
							},
						}
					}
				},
				// treat non-appservice registrations as None authentication
				AuthScheme::AppserviceToken => (None, None, None, false),
				AuthScheme::ServerSignatures => {
					if !services().globals.allow_federation() {
						return Err(Error::bad_config("Federation is disabled."));
					}

					let TypedHeader(Authorization(x_matrix)) =
						parts.extract::<TypedHeader<Authorization<XMatrix>>>().await.map_err(|e| {
							warn!("Missing or invalid Authorization header: {}", e);

							let msg = match e.reason() {
								TypedHeaderRejectionReason::Missing => "Missing Authorization header.",
								TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.",
								_ => "Unknown header-related error",
							};

							Error::BadRequest(ErrorKind::Forbidden, msg)
						})?;

					let origin_signatures =
						BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig))]);

					let signatures = BTreeMap::from_iter([(
						x_matrix.origin.as_str().to_owned(),
						CanonicalJsonValue::Object(origin_signatures),
					)]);

					let server_destination = services().globals.server_name().as_str().to_owned();

					if let Some(destination) = x_matrix.destination.as_ref() {
						if destination != &server_destination {
							return Err(Error::BadRequest(ErrorKind::Forbidden, "Invalid authorization."));
						}
					}

					let mut request_map = BTreeMap::from_iter([
						("method".to_owned(), CanonicalJsonValue::String(parts.method.to_string())),
						("uri".to_owned(), CanonicalJsonValue::String(parts.uri.to_string())),
						(
							"origin".to_owned(),
							CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()),
						),
						("destination".to_owned(), CanonicalJsonValue::String(server_destination)),
						("signatures".to_owned(), CanonicalJsonValue::Object(signatures)),
					]);

					if let Some(json_body) = &json_body {
						request_map.insert("content".to_owned(), json_body.clone());
					};

					let keys_result = services()
						.rooms
						.event_handler
						.fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.clone()])
						.await;

					let keys = match keys_result {
						Ok(b) => b,
						Err(e) => {
							warn!("Failed to fetch signing keys: {}", e);
							return Err(Error::BadRequest(ErrorKind::Forbidden, "Failed to fetch signing keys."));
						},
					};

					let pub_key_map = BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]);

					match ruma::signatures::verify_json(&pub_key_map, &request_map) {
						Ok(()) => (None, None, Some(x_matrix.origin), false),
						Err(e) => {
							warn!(
								"Failed to verify json request from {}: {}\n{:?}",
								x_matrix.origin, e, request_map
							);

							if parts.uri.to_string().contains('@') {
								warn!(
									"Request uri contained '@' character. Make sure your reverse proxy gives Conduit \
									 the raw uri (apache: use nocanon)"
								);
							}

							return Err(Error::BadRequest(
								ErrorKind::Forbidden,
								"Failed to verify X-Matrix signatures.",
							));
						},
					}
				},
				AuthScheme::None => match parts.uri.path() {
					// allow_public_room_directory_without_auth
					"/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => {
						if !services().globals.config.allow_public_room_directory_without_auth {
							let token = match token {
								Some(token) => token,
								_ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
							};

							match services().users.find_from_token(token)? {
								None => {
									return Err(Error::BadRequest(
										ErrorKind::UnknownToken {
											soft_logout: false,
										},
										"Unknown access token.",
									))
								},
								Some((user_id, device_id)) => {
									(Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
								},
							}
						} else {
							(None, None, None, false)
						}
					},
					_ => (None, None, None, false),
				},
			}
		};

		let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
		*http_request.headers_mut().unwrap() = parts.headers;

		if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
			let user_id = sender_user.clone().unwrap_or_else(|| {
				UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid")
			});

			let uiaa_request = json_body
				.get("auth")
				.and_then(|auth| auth.as_object())
				.and_then(|auth| auth.get("session"))
				.and_then(|session| session.as_str())
				.and_then(|session| {
					services().uiaa.get_uiaa_request(
						&user_id,
						&sender_device.clone().unwrap_or_else(|| "".into()),
						session,
					)
				});

			if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
				for (key, value) in initial_request {
					json_body.entry(key).or_insert(value);
				}
			}

			let mut buf = BytesMut::new().writer();
			serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail");
			body = buf.into_inner().freeze();
		}

		let http_request = http_request.body(&*body).unwrap();

		debug!("{:?}", http_request);

		let body = T::try_from_http_request(http_request, &path_params).map_err(|e| {
			warn!("try_from_http_request failed: {:?}", e);
			debug!("JSON body: {:?}", json_body);
			Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.")
		})?;

		Ok(Ruma {
			body,
			sender_user,
			sender_device,
			sender_servername,
			from_appservice,
			json_body,
		})
	}
Jonas Platte's avatar
Jonas Platte committed
}

struct XMatrix {
	origin: OwnedServerName,
	destination: Option<String>,
	key: String, // KeyName?
	sig: String,
Jonas Platte's avatar
Jonas Platte committed
}

impl Credentials for XMatrix {
	const SCHEME: &'static str = "X-Matrix";

	fn decode(value: &http::HeaderValue) -> Option<Self> {
		debug_assert!(
			value.as_bytes().starts_with(b"X-Matrix "),
			"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
		);

		let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]).ok()?.trim_start();

		let mut origin = None;
		let mut destination = None;
		let mut key = None;
		let mut sig = None;

		for entry in parameters.split_terminator(',') {
			let (name, value) = entry.split_once('=')?;

			// It's not at all clear why some fields are quoted and others not in the spec,
			// let's simply accept either form for every field.
			let value = value.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')).unwrap_or(value);

			// FIXME: Catch multiple fields of the same name
			match name {
				"origin" => origin = Some(value.try_into().ok()?),
				"key" => key = Some(value.to_owned()),
				"sig" => sig = Some(value.to_owned()),
				"destination" => destination = Some(value.to_owned()),
				_ => debug!("Unexpected field `{}` in X-Matrix Authorization header", name),
			}
		}

		Some(Self {
			origin: origin?,
			key: key?,
			sig: sig?,
			destination,
		})
	}

	fn encode(&self) -> http::HeaderValue { todo!() }
Timo Kösters's avatar
Timo Kösters committed
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
	fn into_response(self) -> Response {
		match self.0.try_into_http_response::<BytesMut>() {
			Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(),
			Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
		}
	}
Jonas Platte's avatar
Jonas Platte committed

// copied from hyper under the following license:
// Copyright (c) 2014-2021 Sean McArthur

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error>
where
	T: HttpBody,
Jonas Platte's avatar
Jonas Platte committed
{
	futures_util::pin_mut!(body);

	// If there's only 1 chunk, we can just return Buf::to_bytes()
	let mut first = if let Some(buf) = body.data().await {
		buf?
	} else {
		return Ok(Bytes::new());
	};

	let second = if let Some(buf) = body.data().await {
		buf?
	} else {
		return Ok(first.copy_to_bytes(first.remaining()));
	};

	// With more than 1 buf, we gotta flatten into a Vec first.
	let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize;
	let mut vec = Vec::with_capacity(cap);
	vec.put(first);
	vec.put(second);

	while let Some(buf) = body.data().await {
		vec.put(buf?);
	}

	Ok(vec.into())
Jonas Platte's avatar
Jonas Platte committed
}