From dedbbee55cdac8e56cbeabca54d4a794112d3a48 Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Thu, 9 Apr 2026 11:57:15 +0000 Subject: [PATCH] packages/ak-axum/extract/host: init (#21323) --- packages/ak-axum/src/extract/host.rs | 272 +++++++++++++++++++++++++++ packages/ak-axum/src/extract/mod.rs | 1 + packages/ak-axum/src/router.rs | 5 +- packages/ak-axum/src/tracing.rs | 1 + 4 files changed, 277 insertions(+), 2 deletions(-) create mode 100644 packages/ak-axum/src/extract/host.rs diff --git a/packages/ak-axum/src/extract/host.rs b/packages/ak-axum/src/extract/host.rs new file mode 100644 index 0000000000..c9b0ef24c4 --- /dev/null +++ b/packages/ak-axum/src/extract/host.rs @@ -0,0 +1,272 @@ +//! axum extractor and middleware to retrieve the host. +use axum::{ + Extension, RequestPartsExt as _, + extract::{FromRequestParts, Request}, + http::{ + header::{FORWARDED, HOST}, + request::Parts, + status::StatusCode, + }, + middleware::Next, + response::{IntoResponse as _, Response}, +}; +use forwarded_header_value::ForwardedHeaderValue; +use tracing::{Span, instrument}; + +use crate::extract::trusted_proxy::TrustedProxy; + +const X_FORWARDED_HOST: &str = "X-Forwarded-Host"; + +/// Request host. +/// +/// The [`host_middleware`] must be added to the router before using this extractor, +/// otherwise this will result in requests erroring. +#[derive(Clone, Debug)] +pub struct Host(pub String); + +impl FromRequestParts for Host +where + S: Send + Sync, +{ + type Rejection = as FromRequestParts>::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + Extension::::from_request_parts(parts, state) + .await + .map(|Extension(host)| host) + } +} + +/// Get the host from the request. +#[instrument(skip_all)] +async fn extract_host(parts: &mut Parts) -> Result { + let is_trusted = parts + .extract::() + .await + .unwrap_or(TrustedProxy(false)) + .0; + + if is_trusted { + if let Some(host) = parts + .headers + .get(X_FORWARDED_HOST) + .and_then(|host| host.to_str().ok()) + { + return Ok(host.to_owned()); + } + + if let Some(forwarded) = parts.headers.get(FORWARDED) + && let Ok(forwarded) = forwarded.to_str() + && let Ok(forwarded) = ForwardedHeaderValue::from_forwarded(forwarded) + { + for stanza in forwarded.iter() { + if let Some(forwarded_host) = &stanza.forwarded_host { + return Ok(forwarded_host.to_owned()); + } + } + } + } + + if let Some(host) = parts.headers.get(HOST).and_then(|host| host.to_str().ok()) { + return Ok(host.to_owned()); + } + + if let Some(host) = parts.uri.host() { + Ok(host.to_owned()) + } else { + Err((StatusCode::BAD_REQUEST, "missing host header")) + } +} + +/// Middleware required by the [`Host`] extractor. +/// +/// Use with [`axum::middleware::from_fn`]. +pub async fn host_middleware(request: Request, next: Next) -> Response { + let (mut parts, body) = request.into_parts(); + + let host = match extract_host(&mut parts).await { + Ok(host) => host, + Err(err) => return err.into_response(), + }; + Span::current().record("host", host.clone()); + parts.extensions.insert::(Host(host)); + + let request = Request::from_parts(parts, body); + + next.run(request).await +} + +#[cfg(test)] +mod tests { + use axum::{body::Body, http::Request}; + + use super::*; + + #[tokio::test] + async fn host_header() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("host", "example.com:8080") + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let result = extract_host(&mut parts).await; + + assert!(result.is_ok()); + assert_eq!( + result.expect("Host extraction should succeed"), + "example.com:8080", + ); + } + + #[tokio::test] + async fn from_uri() { + let (mut parts, _) = Request::builder() + .uri("http://example.com:8080/path") + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let result = extract_host(&mut parts).await; + + assert!(result.is_ok()); + assert_eq!( + result.expect("Host extraction should succeed"), + "example.com", + ); + } + + #[tokio::test] + async fn x_forwarded_host_trusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-host", "forwarded.example.com") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let result = extract_host(&mut parts).await; + + assert!(result.is_ok()); + assert_eq!( + result.expect("Host extraction should succeed"), + "forwarded.example.com", + ); + } + + #[tokio::test] + async fn forwarded_header_trusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("forwarded", "host=forwarded.example.com") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let result = extract_host(&mut parts).await; + + assert!(result.is_ok()); + assert_eq!( + result.expect("Host extraction should succeed"), + "forwarded.example.com", + ); + } + + #[tokio::test] + async fn forwarded_host_untrusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-host", "malicious.example.com") + .extension(TrustedProxy(false)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let result = extract_host(&mut parts).await; + + assert!(result.is_ok()); + assert_eq!( + result.expect("Host extraction should succeed"), + "example.com", + ); + } + + #[tokio::test] + async fn forwarded_header_untrusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("forwarded", "host=malicious.example.com") + .extension(TrustedProxy(false)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let result = extract_host(&mut parts).await; + + assert!(result.is_ok()); + assert_eq!( + result.expect("Host extraction should succeed"), + "example.com", + ); + } + + #[tokio::test] + async fn priority_order() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-host", "x-forwarded.example.com") + .header("forwarded", "host=forwarded.example.com") + .header("host", "host-header.example.com") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let result = extract_host(&mut parts).await; + + assert!(result.is_ok()); + assert_eq!( + result.expect("Host extraction should succeed"), + "x-forwarded.example.com", + ); + } + + #[tokio::test] + async fn no_host_found() { + let (mut parts, _) = Request::builder() + .uri("/path") + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let result = extract_host(&mut parts).await; + + assert!(result.is_err()); + assert_eq!(result.expect_err("Host extract should fail").0, 400); + } + + #[tokio::test] + async fn multiple_forwarded_stanzas() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header( + "forwarded", + "host=first.example.com, host=second.example.com", + ) + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let result = extract_host(&mut parts).await; + + assert!(result.is_ok()); + assert_eq!( + result.expect("Host extraction should succeed"), + "first.example.com", + ); + } +} diff --git a/packages/ak-axum/src/extract/mod.rs b/packages/ak-axum/src/extract/mod.rs index 3837cec844..87d8bfb631 100644 --- a/packages/ak-axum/src/extract/mod.rs +++ b/packages/ak-axum/src/extract/mod.rs @@ -1,5 +1,6 @@ //! axum extractors to get information about a request. pub mod client_ip; +pub mod host; pub mod scheme; pub mod trusted_proxy; diff --git a/packages/ak-axum/src/router.rs b/packages/ak-axum/src/router.rs index 2caceeab0c..cdc814de8d 100644 --- a/packages/ak-axum/src/router.rs +++ b/packages/ak-axum/src/router.rs @@ -7,7 +7,7 @@ use tower_http::timeout::TimeoutLayer; use crate::{ extract::{ - client_ip::client_ip_middleware, scheme::scheme_middleware, + client_ip::client_ip_middleware, host::host_middleware, scheme::scheme_middleware, trusted_proxy::trusted_proxy_middleware, }, tracing::{span_middleware, tracing_middleware}, @@ -32,7 +32,8 @@ pub fn wrap_router(router: Router, with_tracing: bool) -> Router { .layer(from_fn(span_middleware)) .layer(from_fn(trusted_proxy_middleware)) .layer(from_fn(client_ip_middleware)) - .layer(from_fn(scheme_middleware)); + .layer(from_fn(scheme_middleware)) + .layer(from_fn(host_middleware)); if with_tracing { router.layer(service_builder.layer(from_fn(tracing_middleware))) } else { diff --git a/packages/ak-axum/src/tracing.rs b/packages/ak-axum/src/tracing.rs index 0c7b69a91c..c69678aa31 100644 --- a/packages/ak-axum/src/tracing.rs +++ b/packages/ak-axum/src/tracing.rs @@ -29,6 +29,7 @@ pub(crate) async fn span_middleware(request: Request, next: Next) -> Response { method = %request.method(), remote = field::Empty, scheme = field::Empty, + host = field::Empty, http_headers = ?http_headers, ); next.run(request).instrument(span).await