From e4dc4a1ba59d09ba037a7160ae0e39c0eeb9cc03 Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Sun, 14 Jul 2024 06:07:54 +0000
Subject: [PATCH] fix graceful shutdown on unix socket

Signed-off-by: Jason Volk <jason@zemos.net>
---
 src/router/serve/unix.rs | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/src/router/serve/unix.rs b/src/router/serve/unix.rs
index 266936dd9..b5938673c 100644
--- a/src/router/serve/unix.rs
+++ b/src/router/serve/unix.rs
@@ -3,7 +3,7 @@
 use std::{
 	net::{self, IpAddr, Ipv4Addr},
 	path::Path,
-	sync::Arc,
+	sync::{atomic::Ordering, Arc},
 };
 
 use axum::{
@@ -21,12 +21,14 @@
 	net::{unix::SocketAddr, UnixListener, UnixStream},
 	sync::broadcast::{self},
 	task::JoinSet,
+	time::{sleep, Duration},
 };
 use tower::{Service, ServiceExt};
 
 type MakeService = IntoMakeServiceWithConnectInfo<Router, net::SocketAddr>;
 
-static NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
+const NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
+const FINI_POLL_INTERVAL: Duration = Duration::from_millis(750);
 
 #[tracing::instrument(skip_all)]
 pub(super) async fn serve(server: &Arc<Server>, app: Router, mut shutdown: broadcast::Receiver<()>) -> Result<()> {
@@ -47,7 +49,7 @@ pub(super) async fn serve(server: &Arc<Server>, app: Router, mut shutdown: broad
 		}
 	}
 
-	fini(listener, tasks).await;
+	fini(server, listener, tasks).await;
 
 	Ok(())
 }
@@ -111,15 +113,26 @@ async fn init(server: &Arc<Server>) -> Result<UnixListener> {
 		return Err!("Failed to set socket {path:?} permissions: {e}");
 	}
 
-	info!("Listening at {:?}", path);
+	info!("Listening at {path:?}");
 
 	Ok(listener.unwrap())
 }
 
-async fn fini(listener: UnixListener, mut tasks: JoinSet<()>) {
+async fn fini(server: &Arc<Server>, listener: UnixListener, mut tasks: JoinSet<()>) {
 	let local = listener.local_addr();
 
+	debug!("Closing listener at {local:?} ...");
 	drop(listener);
+
+	debug!("Waiting for requests to finish...");
+	while server.metrics.requests_spawn_active.load(Ordering::Relaxed) > 0 {
+		tokio::select! {
+			_ = tasks.join_next() => {}
+			() = sleep(FINI_POLL_INTERVAL) => {}
+		}
+	}
+
+	debug!("Shutting down...");
 	tasks.shutdown().await;
 
 	if let Ok(local) = local {
-- 
GitLab