packages/ak-axum/extract/host: init (#21323)

This commit is contained in:
Marc 'risson' Schmitt
2026-04-09 11:57:15 +00:00
committed by GitHub
parent 165297dcd4
commit dedbbee55c
4 changed files with 277 additions and 2 deletions
+272
View File
@@ -0,0 +1,272 @@
//! axum extractor and middleware to retrieve the host.
use axum::{
Extension, RequestPartsExt as _,
extract::{FromRequestParts, Request},
http::{
header::{FORWARDED, HOST},
request::Parts,
status::StatusCode,
},
middleware::Next,
response::{IntoResponse as _, Response},
};
use forwarded_header_value::ForwardedHeaderValue;
use tracing::{Span, instrument};
use crate::extract::trusted_proxy::TrustedProxy;
const X_FORWARDED_HOST: &str = "X-Forwarded-Host";
/// Request host.
///
/// The [`host_middleware`] must be added to the router before using this extractor,
/// otherwise this will result in requests erroring.
#[derive(Clone, Debug)]
pub struct Host(pub String);
impl<S> FromRequestParts<S> for Host
where
S: Send + Sync,
{
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
Extension::<Self>::from_request_parts(parts, state)
.await
.map(|Extension(host)| host)
}
}
/// Get the host from the request.
#[instrument(skip_all)]
async fn extract_host(parts: &mut Parts) -> Result<String, (StatusCode, &'static str)> {
let is_trusted = parts
.extract::<TrustedProxy>()
.await
.unwrap_or(TrustedProxy(false))
.0;
if is_trusted {
if let Some(host) = parts
.headers
.get(X_FORWARDED_HOST)
.and_then(|host| host.to_str().ok())
{
return Ok(host.to_owned());
}
if let Some(forwarded) = parts.headers.get(FORWARDED)
&& let Ok(forwarded) = forwarded.to_str()
&& let Ok(forwarded) = ForwardedHeaderValue::from_forwarded(forwarded)
{
for stanza in forwarded.iter() {
if let Some(forwarded_host) = &stanza.forwarded_host {
return Ok(forwarded_host.to_owned());
}
}
}
}
if let Some(host) = parts.headers.get(HOST).and_then(|host| host.to_str().ok()) {
return Ok(host.to_owned());
}
if let Some(host) = parts.uri.host() {
Ok(host.to_owned())
} else {
Err((StatusCode::BAD_REQUEST, "missing host header"))
}
}
/// Middleware required by the [`Host`] extractor.
///
/// Use with [`axum::middleware::from_fn`].
pub async fn host_middleware(request: Request, next: Next) -> Response {
let (mut parts, body) = request.into_parts();
let host = match extract_host(&mut parts).await {
Ok(host) => host,
Err(err) => return err.into_response(),
};
Span::current().record("host", host.clone());
parts.extensions.insert::<Host>(Host(host));
let request = Request::from_parts(parts, body);
next.run(request).await
}
#[cfg(test)]
mod tests {
use axum::{body::Body, http::Request};
use super::*;
#[tokio::test]
async fn host_header() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("host", "example.com:8080")
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"example.com:8080",
);
}
#[tokio::test]
async fn from_uri() {
let (mut parts, _) = Request::builder()
.uri("http://example.com:8080/path")
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"example.com",
);
}
#[tokio::test]
async fn x_forwarded_host_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-host", "forwarded.example.com")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"forwarded.example.com",
);
}
#[tokio::test]
async fn forwarded_header_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("forwarded", "host=forwarded.example.com")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"forwarded.example.com",
);
}
#[tokio::test]
async fn forwarded_host_untrusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-host", "malicious.example.com")
.extension(TrustedProxy(false))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"example.com",
);
}
#[tokio::test]
async fn forwarded_header_untrusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("forwarded", "host=malicious.example.com")
.extension(TrustedProxy(false))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"example.com",
);
}
#[tokio::test]
async fn priority_order() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-host", "x-forwarded.example.com")
.header("forwarded", "host=forwarded.example.com")
.header("host", "host-header.example.com")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"x-forwarded.example.com",
);
}
#[tokio::test]
async fn no_host_found() {
let (mut parts, _) = Request::builder()
.uri("/path")
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_err());
assert_eq!(result.expect_err("Host extract should fail").0, 400);
}
#[tokio::test]
async fn multiple_forwarded_stanzas() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header(
"forwarded",
"host=first.example.com, host=second.example.com",
)
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"first.example.com",
);
}
}
+1
View File
@@ -1,5 +1,6 @@
//! axum extractors to get information about a request.
pub mod client_ip;
pub mod host;
pub mod scheme;
pub mod trusted_proxy;
+3 -2
View File
@@ -7,7 +7,7 @@ use tower_http::timeout::TimeoutLayer;
use crate::{
extract::{
client_ip::client_ip_middleware, scheme::scheme_middleware,
client_ip::client_ip_middleware, host::host_middleware, scheme::scheme_middleware,
trusted_proxy::trusted_proxy_middleware,
},
tracing::{span_middleware, tracing_middleware},
@@ -32,7 +32,8 @@ pub fn wrap_router(router: Router, with_tracing: bool) -> Router {
.layer(from_fn(span_middleware))
.layer(from_fn(trusted_proxy_middleware))
.layer(from_fn(client_ip_middleware))
.layer(from_fn(scheme_middleware));
.layer(from_fn(scheme_middleware))
.layer(from_fn(host_middleware));
if with_tracing {
router.layer(service_builder.layer(from_fn(tracing_middleware)))
} else {
+1
View File
@@ -29,6 +29,7 @@ pub(crate) async fn span_middleware(request: Request, next: Next) -> Response {
method = %request.method(),
remote = field::Empty,
scheme = field::Empty,
host = field::Empty,
http_headers = ?http_headers,
);
next.run(request).instrument(span).await