Skip to content
Snippets Groups Projects
Commit 57e6af6e authored by Jason Volk's avatar Jason Volk Committed by 🥺
Browse files

split sending/send base functions


Signed-off-by: default avatarJason Volk <jason@zemos.net>
parent f919fa87
No related branches found
No related tags found
No related merge requests found
...@@ -234,7 +234,7 @@ pub(crate) async fn send_federation_request<T>(&self, dest: &ServerName, request ...@@ -234,7 +234,7 @@ pub(crate) async fn send_federation_request<T>(&self, dest: &ServerName, request
let permit = self.maximum_requests.acquire().await; let permit = self.maximum_requests.acquire().await;
let timeout = Duration::from_secs(self.timeout); let timeout = Duration::from_secs(self.timeout);
let client = &services().globals.client.federation; let client = &services().globals.client.federation;
let response = tokio::time::timeout(timeout, send::send_request(client, dest, request)) let response = tokio::time::timeout(timeout, send::send(client, dest, request))
.await .await
.map_err(|_| { .map_err(|_| {
warn!("Timeout after 300 seconds waiting for server response of {dest}"); warn!("Timeout after 300 seconds waiting for server response of {dest}");
...@@ -795,7 +795,7 @@ async fn send_events_dest_normal( ...@@ -795,7 +795,7 @@ async fn send_events_dest_normal(
let permit = services().sending.maximum_requests.acquire().await; let permit = services().sending.maximum_requests.acquire().await;
let client = &services().globals.client.sender; let client = &services().globals.client.sender;
let response = send::send_request( let response = send::send(
client, client,
server_name, server_name,
send_transaction_message::v1::Request { send_transaction_message::v1::Request {
......
...@@ -51,7 +51,7 @@ struct ActualDest { ...@@ -51,7 +51,7 @@ struct ActualDest {
} }
#[tracing::instrument(skip_all, name = "send")] #[tracing::instrument(skip_all, name = "send")]
pub(crate) async fn send_request<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse> pub(crate) async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse>
where where
T: OutgoingRequest + Debug, T: OutgoingRequest + Debug,
{ {
...@@ -59,22 +59,19 @@ pub(crate) async fn send_request<T>(client: &Client, dest: &ServerName, req: T) ...@@ -59,22 +59,19 @@ pub(crate) async fn send_request<T>(client: &Client, dest: &ServerName, req: T)
return Err(Error::bad_config("Federation is disabled.")); return Err(Error::bad_config("Federation is disabled."));
} }
trace!("Preparing to send request");
validate_dest(dest)?;
let actual = get_actual_dest(dest).await?; let actual = get_actual_dest(dest).await?;
let mut http_request = req let request = prepare::<T>(dest, &actual, req).await?;
.try_into_http_request::<Vec<u8>>(&actual.string, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_5]) execute::<T>(client, dest, &actual, request).await
.map_err(|e| { }
debug_warn!("Failed to find destination {}: {}", actual.string, e);
Error::BadServerResponse("Invalid destination")
})?;
sign_request::<T>(dest, &mut http_request); async fn execute<T>(
let request = Request::try_from(http_request)?; client: &Client, dest: &ServerName, actual: &ActualDest, request: Request,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
{
let method = request.method().clone(); let method = request.method().clone();
let url = request.url().clone(); let url = request.url().clone();
validate_url(&url)?;
debug!( debug!(
method = ?method, method = ?method,
url = ?url, url = ?url,
...@@ -82,12 +79,32 @@ pub(crate) async fn send_request<T>(client: &Client, dest: &ServerName, req: T) ...@@ -82,12 +79,32 @@ pub(crate) async fn send_request<T>(client: &Client, dest: &ServerName, req: T)
); );
match client.execute(request).await { match client.execute(request).await {
Ok(response) => handle_response::<T>(dest, actual, &method, &url, response).await, Ok(response) => handle_response::<T>(dest, actual, &method, &url, response).await,
Err(e) => handle_error::<T>(dest, &actual, &method, &url, e), Err(e) => handle_error::<T>(dest, actual, &method, &url, e),
} }
} }
async fn prepare<T>(dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request>
where
T: OutgoingRequest + Debug,
{
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5];
trace!("Preparing request");
let mut http_request = req
.try_into_http_request::<Vec<u8>>(&actual.string, SendAccessToken::IfRequired(""), &VERSIONS)
.map_err(|_e| Error::BadServerResponse("Invalid destination"))?;
sign_request::<T>(dest, &mut http_request);
let request = Request::try_from(http_request)?;
validate_url(request.url())?;
Ok(request)
}
async fn handle_response<T>( async fn handle_response<T>(
dest: &ServerName, actual: ActualDest, method: &Method, url: &Url, mut response: Response, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response,
) -> Result<T::IncomingResponse> ) -> Result<T::IncomingResponse>
where where
T: OutgoingRequest + Debug, T: OutgoingRequest + Debug,
...@@ -126,7 +143,7 @@ async fn handle_response<T>( ...@@ -126,7 +143,7 @@ async fn handle_response<T>(
.actual_destinations() .actual_destinations()
.write() .write()
.await .await
.insert(OwnedServerName::from(dest), (actual.dest, actual.host)); .insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone()));
} }
match response { match response {
...@@ -176,6 +193,7 @@ async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> { ...@@ -176,6 +193,7 @@ async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> {
result result
} else { } else {
cached = false; cached = false;
validate_dest(server_name)?;
resolve_actual_dest(server_name).await? resolve_actual_dest(server_name).await?
}; };
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment