From 2a027264b3a4019f0f6ff0d87ba0b951f1a236d0 Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Mon, 27 Apr 2026 18:40:50 +0200 Subject: [PATCH] packages/ak-axum/accept/catch_panic: add acceptor to catch panics in lower acceptors, streams and services (#21860) --- Cargo.lock | 1 + packages/ak-axum/Cargo.toml | 1 + packages/ak-axum/src/accept/catch_panic.rs | 737 +++++++++++++++++++++ packages/ak-axum/src/accept/mod.rs | 1 + packages/ak-axum/src/server.rs | 23 +- packages/ak-common/src/arbiter.rs | 4 +- 6 files changed, 759 insertions(+), 8 deletions(-) create mode 100644 packages/ak-axum/src/accept/catch_panic.rs diff --git a/Cargo.lock b/Cargo.lock index c8d0de7a1d..10273d476c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -216,6 +216,7 @@ dependencies = [ "eyre", "forwarded-header-value", "futures", + "pin-project-lite", "tokio", "tokio-rustls", "tower", diff --git a/packages/ak-axum/Cargo.toml b/packages/ak-axum/Cargo.toml index 67f5df6c7c..8aaa07970b 100644 --- a/packages/ak-axum/Cargo.toml +++ b/packages/ak-axum/Cargo.toml @@ -18,6 +18,7 @@ durstr.workspace = true eyre.workspace = true forwarded-header-value.workspace = true futures.workspace = true +pin-project-lite.workspace = true tokio-rustls.workspace = true tokio.workspace = true tower-http.workspace = true diff --git a/packages/ak-axum/src/accept/catch_panic.rs b/packages/ak-axum/src/accept/catch_panic.rs new file mode 100644 index 0000000000..152d8781db --- /dev/null +++ b/packages/ak-axum/src/accept/catch_panic.rs @@ -0,0 +1,737 @@ +//! axum-server acceptor that catches panics and shuts down the application. + +use std::{ + any::Any, + io::{self, IoSlice}, + panic::{AssertUnwindSafe, catch_unwind, resume_unwind}, + task::{Context, Poll}, +}; + +use ak_common::Arbiter; +use axum_server::accept::Accept; +use futures::{FutureExt as _, future::BoxFuture}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tower::Service; +use tracing::{error, instrument}; + +fn extract_panic_msg<'a>(panic: &'a Box) -> &'a str { + panic + .downcast_ref::<&str>() + .copied() + .or_else(|| panic.downcast_ref::().map(String::as_str)) + .unwrap_or("unknown panic message") +} + +/// Acceptor catching panics from the underlying acceptor. +/// +/// Also wraps the stream and service to catch panics. +#[derive(Clone)] +pub(crate) struct CatchPanicAcceptor { + inner: A, + arbiter: Arbiter, +} + +impl CatchPanicAcceptor { + pub(crate) fn new(inner: A, arbiter: Arbiter) -> Self { + Self { inner, arbiter } + } +} + +impl Accept for CatchPanicAcceptor +where + A: Accept + Clone + Send + 'static, + A::Stream: AsyncRead + AsyncWrite + Send, + A::Service: Send, + A::Future: Send, + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: Send + 'static, +{ + type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>; + type Service = CatchPanicService; + type Stream = CatchPanicStream; + + #[instrument(skip_all)] + fn accept(&self, stream: I, service: S) -> Self::Future { + let acceptor = self.inner.clone(); + let arbiter = self.arbiter.clone(); + + Box::pin(async move { + match AssertUnwindSafe(acceptor.accept(stream, service)) + .catch_unwind() + .await + { + Ok(result) => { + let (stream, service) = result?; + Ok(( + CatchPanicStream::new(stream, arbiter.clone()), + CatchPanicService::new(service, arbiter), + )) + } + Err(panic) => { + error!( + panic = extract_panic_msg(&panic), + "acceptor panicked, shutting down immediately" + ); + arbiter.do_fast_shutdown().await; + resume_unwind(panic); + } + } + }) + } +} + +pin_project! { + /// A stream wrapper that catches panics from the underlying stream. + pub(crate) struct CatchPanicStream { + #[pin] + inner: S, + arbiter: Arbiter, + } +} + +impl CatchPanicStream { + pub(crate) fn new(inner: S, arbiter: Arbiter) -> Self { + Self { inner, arbiter } + } +} + +impl AsyncRead for CatchPanicStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.project(); + + match catch_unwind(AssertUnwindSafe(|| this.inner.poll_read(cx, buf))) { + Ok(result) => result, + Err(panic) => { + error!( + panic = extract_panic_msg(&panic), + "stream poll_read panicked, shutting down immediately" + ); + let arbiter = this.arbiter.clone(); + tokio::spawn(async move { arbiter.do_fast_shutdown().await }); + resume_unwind(panic); + } + } + } +} + +impl AsyncWrite for CatchPanicStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + + match catch_unwind(AssertUnwindSafe(|| this.inner.poll_write(cx, buf))) { + Ok(result) => result, + Err(panic) => { + error!( + panic = extract_panic_msg(&panic), + "stream poll_write panicked, shutting down immediately" + ); + let arbiter = this.arbiter.clone(); + tokio::spawn(async move { arbiter.do_fast_shutdown().await }); + resume_unwind(panic); + } + } + } + + fn poll_write_vectored( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let this = self.project(); + + match catch_unwind(AssertUnwindSafe(|| { + this.inner.poll_write_vectored(cx, bufs) + })) { + Ok(result) => result, + Err(panic) => { + error!( + panic = extract_panic_msg(&panic), + "stream poll_write_vectored panicked, shutting down immediately" + ); + let arbiter = this.arbiter.clone(); + tokio::spawn(async move { arbiter.do_fast_shutdown().await }); + resume_unwind(panic) + } + } + } + + fn is_write_vectored(&self) -> bool { + match catch_unwind(AssertUnwindSafe(|| self.inner.is_write_vectored())) { + Ok(result) => result, + Err(panic) => { + error!( + panic = extract_panic_msg(&panic), + "stream is_write_vectored panicked, shutting down immediately" + ); + let arbiter = self.arbiter.clone(); + tokio::spawn(async move { arbiter.do_fast_shutdown().await }); + resume_unwind(panic); + } + } + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match catch_unwind(AssertUnwindSafe(|| this.inner.poll_flush(cx))) { + Ok(result) => result, + Err(panic) => { + error!( + panic = extract_panic_msg(&panic), + "stream poll_flush panicked, shutting down immediately" + ); + let arbiter = this.arbiter.clone(); + tokio::spawn(async move { arbiter.do_fast_shutdown().await }); + resume_unwind(panic); + } + } + } + + fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match catch_unwind(AssertUnwindSafe(|| this.inner.poll_shutdown(cx))) { + Ok(result) => result, + Err(panic) => { + error!( + panic = extract_panic_msg(&panic), + "stream poll_shutdown panicked, shutting down immediately" + ); + let arbiter = this.arbiter.clone(); + tokio::spawn(async move { arbiter.do_fast_shutdown().await }); + resume_unwind(panic); + } + } + } +} + +/// A panic wrapper that catches panics from the underlying service. +#[derive(Clone)] +pub(crate) struct CatchPanicService { + inner: S, + arbiter: Arbiter, +} + +impl CatchPanicService { + pub(crate) fn new(inner: S, arbiter: Arbiter) -> Self { + Self { inner, arbiter } + } +} + +impl Service for CatchPanicService +where + S: Service, +{ + type Error = S::Error; + type Future = CatchPanicFuture; + type Response = S::Response; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + let inner = &mut self.inner; + + match catch_unwind(AssertUnwindSafe(|| inner.poll_ready(cx))) { + Ok(result) => result, + Err(panic) => { + error!( + panic = extract_panic_msg(&panic), + "service poll_ready panicked, shutting down immediately" + ); + let arbiter = self.arbiter.clone(); + tokio::spawn(async move { arbiter.do_fast_shutdown().await }); + resume_unwind(panic); + } + } + } + + fn call(&mut self, req: R) -> Self::Future { + let inner = &mut self.inner; + + match catch_unwind(AssertUnwindSafe(|| inner.call(req))) { + Ok(future) => CatchPanicFuture::new(future, self.arbiter.clone()), + Err(panic) => { + error!( + panic = extract_panic_msg(&panic), + "service call panicked, shutting down immediately" + ); + let arbiter = self.arbiter.clone(); + tokio::spawn(async move { arbiter.do_fast_shutdown().await }); + resume_unwind(panic); + } + } + } +} + +pin_project! { + /// A Future wrapper that catches panics from the inner future. + pub(crate) struct CatchPanicFuture { + #[pin] + inner: F, + arbiter: Arbiter, + } +} + +impl CatchPanicFuture { + fn new(inner: F, arbiter: Arbiter) -> Self { + Self { inner, arbiter } + } +} + +impl Future for CatchPanicFuture { + type Output = F::Output; + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match catch_unwind(AssertUnwindSafe(|| this.inner.poll(cx))) { + Ok(result) => result, + Err(panic) => { + error!( + panic = extract_panic_msg(&panic), + "service future panicked, shutting down immediately" + ); + let arbiter = this.arbiter.clone(); + tokio::spawn(async move { arbiter.do_fast_shutdown().await }); + resume_unwind(panic); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + convert::Infallible, + io, + panic::{AssertUnwindSafe, panic_any}, + task::{Context, Poll}, + }; + + use ak_common::{Arbiter, Tasks}; + use axum_server::accept::Accept; + use futures::{ + FutureExt as _, + future::{BoxFuture, poll_fn}, + }; + use tokio::{ + io::{AsyncReadExt as _, AsyncWriteExt as _, DuplexStream, ReadBuf, duplex}, + time::{Duration, timeout}, + }; + use tower::Service; + + use super::{CatchPanicAcceptor, CatchPanicService, CatchPanicStream}; + + fn duplex_stream() -> DuplexStream { + let (stream, _peer) = duplex(1024); + stream + } + + /// Returns `true` if the arbiter's fast-shutdown has already been triggered. + async fn fast_shutdown_triggered(arbiter: &Arbiter) -> bool { + timeout(Duration::from_millis(50), arbiter.fast_shutdown()) + .await + .is_ok() + } + + #[derive(Clone)] + struct OkAcceptor; + + impl Accept for OkAcceptor { + type Future = BoxFuture<'static, io::Result<(I, S)>>; + type Service = S; + type Stream = I; + + fn accept(&self, stream: I, service: S) -> Self::Future { + Box::pin(async move { Ok((stream, service)) }) + } + } + + #[derive(Clone)] + struct ErrorAcceptor; + + impl Accept for ErrorAcceptor { + type Future = BoxFuture<'static, io::Result<(I, S)>>; + type Service = S; + type Stream = I; + + fn accept(&self, _stream: I, _service: S) -> Self::Future { + Box::pin(async move { Err(io::Error::other("inner error")) }) + } + } + + /// Panics with a `&'static str` payload. + #[derive(Clone)] + struct PanicStrAcceptor; + + impl Accept for PanicStrAcceptor { + type Future = BoxFuture<'static, io::Result<(I, S)>>; + type Service = S; + type Stream = I; + + fn accept(&self, _stream: I, _service: S) -> Self::Future { + Box::pin(async move { panic!("str panic message") }) + } + } + + /// Panics with a `String` payload. + #[derive(Clone)] + struct PanicStringAcceptor; + + impl Accept for PanicStringAcceptor { + type Future = BoxFuture<'static, io::Result<(I, S)>>; + type Service = S; + type Stream = I; + + fn accept(&self, _stream: I, _service: S) -> Self::Future { + Box::pin(async move { + let msg = "string panic message".to_owned(); + panic_any(msg) + }) + } + } + + /// Panics with a payload that is neither `&str` nor `String`. + #[derive(Clone)] + struct PanicUnknownAcceptor; + + impl Accept for PanicUnknownAcceptor { + type Future = BoxFuture<'static, io::Result<(I, S)>>; + type Service = S; + type Stream = I; + + fn accept(&self, _stream: I, _service: S) -> Self::Future { + Box::pin(async move { panic_any(42u32) }) + } + } + + struct PanicStream; + + impl tokio::io::AsyncRead for PanicStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll> { + panic!("poll_read panic") + } + } + + impl tokio::io::AsyncWrite for PanicStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + panic!("poll_write panic") + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + panic!("poll_flush panic") + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + panic!("poll_shutdown panic") + } + } + + #[derive(Clone)] + struct OkService; + + impl Service<()> for OkService { + type Error = Infallible; + type Future = futures::future::Ready>; + type Response = (); + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: ()) -> Self::Future { + futures::future::ready(Ok(())) + } + } + + struct PanicPollReadyService; + + impl Service<()> for PanicPollReadyService { + type Error = Infallible; + type Future = futures::future::Ready>; + type Response = (); + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + panic!("poll_ready panic") + } + + fn call(&mut self, _req: ()) -> Self::Future { + unreachable!() + } + } + + struct PanicCallBodyService; + + impl Service<()> for PanicCallBodyService { + type Error = Infallible; + type Future = futures::future::Ready>; + type Response = (); + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: ()) -> Self::Future { + panic!("call body panic") + } + } + + struct PanicFuture; + + impl Future for PanicFuture { + type Output = Result<(), Infallible>; + + fn poll(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + panic!("future panic") + } + } + + struct PanicCallFutureService; + + impl Service<()> for PanicCallFutureService { + type Error = Infallible; + type Future = PanicFuture; + type Response = (); + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: ()) -> Self::Future { + PanicFuture + } + } + + #[tokio::test] + async fn acceptor_passes_through_success() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let acceptor = CatchPanicAcceptor::new(OkAcceptor, arbiter.clone()); + + let result = acceptor.accept(duplex_stream(), OkService).await; + + assert!(result.is_ok()); + assert!(!fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn acceptor_passes_through_error() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let acceptor = CatchPanicAcceptor::new(ErrorAcceptor, arbiter.clone()); + + let result = acceptor.accept(duplex_stream(), OkService).await; + + assert!(result.is_err()); + assert_eq!(result.err().unwrap().to_string(), "inner error"); + assert!(!fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn acceptor_catches_str_panic_and_shuts_down() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let acceptor = CatchPanicAcceptor::new(PanicStrAcceptor, arbiter.clone()); + + let result = AssertUnwindSafe(acceptor.accept(duplex_stream(), OkService)) + .catch_unwind() + .await; + + assert!(result.is_err()); + assert!(fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn acceptor_catches_string_panic_and_shuts_down() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let acceptor = CatchPanicAcceptor::new(PanicStringAcceptor, arbiter.clone()); + + let result = AssertUnwindSafe(acceptor.accept(duplex_stream(), OkService)) + .catch_unwind() + .await; + + assert!(result.is_err()); + assert!(fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn acceptor_catches_unknown_panic_and_shuts_down() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let acceptor = CatchPanicAcceptor::new(PanicUnknownAcceptor, arbiter.clone()); + + let result = AssertUnwindSafe(acceptor.accept(duplex_stream(), OkService)) + .catch_unwind() + .await; + + assert!(result.is_err()); + assert!(fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn stream_poll_read_passes_through() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let (mut a, mut b) = duplex(1024); + b.write_all(b"hello").await.unwrap(); + + let mut stream = CatchPanicStream::new(&mut a, arbiter.clone()); + let mut buf = [0u8; 5]; + let result = stream.read(&mut buf).await; + + assert!(result.is_ok()); + assert_eq!(&buf, b"hello"); + assert!(!fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn stream_poll_read_panic_returns_error_and_shuts_down() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone()); + + let result = AssertUnwindSafe(stream.read(&mut [0u8; 10])) + .catch_unwind() + .await; + + assert!(result.is_err()); + assert!(fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn stream_poll_write_passes_through() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let (mut a, _b) = duplex(1024); + + let mut stream = CatchPanicStream::new(&mut a, arbiter.clone()); + let result = stream.write_all(b"hello").await; + + assert!(result.is_ok()); + assert!(!fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn stream_poll_write_panic_returns_error_and_shuts_down() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone()); + + let result = AssertUnwindSafe(stream.write(b"hello")) + .catch_unwind() + .await; + + assert!(result.is_err()); + assert!(fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn stream_poll_flush_panic_returns_error_and_shuts_down() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone()); + + let result = AssertUnwindSafe(stream.flush()).catch_unwind().await; + + assert!(result.is_err()); + assert!(fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn stream_poll_shutdown_panic_returns_error_and_shuts_down() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let mut stream = CatchPanicStream::new(PanicStream, arbiter.clone()); + + let result = AssertUnwindSafe(stream.shutdown()).catch_unwind().await; + + assert!(result.is_err()); + assert!(fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn service_poll_ready_passes_through() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let mut service = CatchPanicService::new(OkService, arbiter.clone()); + + let result = poll_fn(|cx| service.poll_ready(cx)).await; + + assert!(result.is_ok()); + assert!(!fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn service_poll_ready_panic_re_panics_and_shuts_down() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let mut service = CatchPanicService::new(PanicPollReadyService, arbiter.clone()); + + let result = AssertUnwindSafe(poll_fn(|cx| service.poll_ready(cx))) + .catch_unwind() + .await; + + assert!(result.is_err()); + assert!(fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn service_call_passes_through() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let mut service = CatchPanicService::new(OkService, arbiter.clone()); + + let result = service.call(()).await; + + assert!(result.is_ok()); + assert!(!fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn service_call_body_panic_re_panics_and_shuts_down() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let mut service = CatchPanicService::new(PanicCallBodyService, arbiter.clone()); + + let result = AssertUnwindSafe(async { service.call(()).await }) + .catch_unwind() + .await; + + assert!(result.is_err()); + assert!(fast_shutdown_triggered(&arbiter).await); + } + + #[tokio::test] + async fn service_call_future_panic_re_panics_and_shuts_down() { + let tasks = Tasks::new().expect("failed to create tasks"); + let arbiter = tasks.arbiter(); + let mut service = CatchPanicService::new(PanicCallFutureService, arbiter.clone()); + + let result = AssertUnwindSafe(service.call(())).catch_unwind().await; + + assert!(result.is_err()); + assert!(fast_shutdown_triggered(&arbiter).await); + } +} diff --git a/packages/ak-axum/src/accept/mod.rs b/packages/ak-axum/src/accept/mod.rs index 2f4d0266b4..92cbcfc02a 100644 --- a/packages/ak-axum/src/accept/mod.rs +++ b/packages/ak-axum/src/accept/mod.rs @@ -1,2 +1,3 @@ +pub mod catch_panic; pub mod proxy_protocol; pub mod tls; diff --git a/packages/ak-axum/src/server.rs b/packages/ak-axum/src/server.rs index 4a2246f385..5debeeddb1 100644 --- a/packages/ak-axum/src/server.rs +++ b/packages/ak-axum/src/server.rs @@ -12,7 +12,9 @@ use axum_server::{ use eyre::Result; use tracing::{info, trace}; -use crate::accept::{proxy_protocol::ProxyProtocolAcceptor, tls::TlsAcceptor}; +use crate::accept::{ + catch_panic::CatchPanicAcceptor, proxy_protocol::ProxyProtocolAcceptor, tls::TlsAcceptor, +}; async fn run_plain( arbiter: Arbiter, @@ -27,7 +29,10 @@ async fn run_plain( arbiter.add_net_handle(handle.clone()).await; let res = axum_server::Server::bind(addr) - .acceptor(ProxyProtocolAcceptor::new().acceptor(DefaultAcceptor::new())) + .acceptor(CatchPanicAcceptor::new( + ProxyProtocolAcceptor::new().acceptor(DefaultAcceptor::new()), + arbiter.clone(), + )) .handle(handle) .serve(router.into_make_service_with_connect_info::()) .await; @@ -80,7 +85,10 @@ pub(crate) async fn run_unix( } } let res = axum_server::Server::bind(addr.clone()) - .acceptor(DefaultAcceptor::new()) + .acceptor(CatchPanicAcceptor::new( + DefaultAcceptor::new(), + arbiter.clone(), + )) .handle(handle) .serve(router.into_make_service()) .await; @@ -133,9 +141,12 @@ async fn run_tls( arbiter.add_net_handle(handle.clone()).await; axum_server::Server::bind(addr) - .acceptor(ProxyProtocolAcceptor::new().acceptor(TlsAcceptor::new( - RustlsAcceptor::new(config).acceptor(DefaultAcceptor::new()), - ))) + .acceptor(CatchPanicAcceptor::new( + ProxyProtocolAcceptor::new().acceptor(TlsAcceptor::new( + RustlsAcceptor::new(config).acceptor(DefaultAcceptor::new()), + )), + arbiter.clone(), + )) .handle(handle) .serve(router.into_make_service_with_connect_info::()) .await?; diff --git a/packages/ak-common/src/arbiter.rs b/packages/ak-common/src/arbiter.rs index 5f69b58757..1e92e20dcf 100644 --- a/packages/ak-common/src/arbiter.rs +++ b/packages/ak-common/src/arbiter.rs @@ -235,7 +235,7 @@ impl Arbiter { } /// Shutdown the application immediately. - async fn do_fast_shutdown(&self) { + pub async fn do_fast_shutdown(&self) { info!("arbiter has been told to shutdown immediately"); self.unix_handles .lock() @@ -253,7 +253,7 @@ impl Arbiter { } /// Shutdown the application gracefully. - async fn do_graceful_shutdown(&self) { + pub async fn do_graceful_shutdown(&self) { info!("arbiter has been told to shutdown gracefully"); // Match the value in lifecycle/gunicorn.conf.py for graceful shutdown let timeout = Some(Duration::from_secs(30 + 5));