packages/ak-axum/extract/scheme: init (#21322)

This commit is contained in:
Marc 'risson' Schmitt
2026-04-08 14:39:58 +00:00
committed by GitHub
parent 2b8313ee91
commit d4e651d893
7 changed files with 263 additions and 2 deletions
Generated
+1
View File
@@ -133,6 +133,7 @@ dependencies = [
"client-ip",
"durstr",
"eyre",
"forwarded-header-value",
"futures",
"tokio",
"tokio-rustls",
+1
View File
@@ -34,6 +34,7 @@ console-subscriber = "= 0.5.0"
dotenvy = "= 0.15.7"
durstr = "= 0.5.1"
eyre = "= 0.6.12"
forwarded-header-value = "= 0.1.1"
futures = "= 0.3.32"
glob = "= 0.3.3"
ipnet = { version = "= 2.12.0", features = ["serde"] }
+1
View File
@@ -16,6 +16,7 @@ axum.workspace = true
client-ip.workspace = true
durstr.workspace = true
eyre.workspace = true
forwarded-header-value.workspace = true
futures.workspace = true
tokio-rustls.workspace = true
tokio.workspace = true
+1
View File
@@ -1,4 +1,5 @@
//! axum extractors to get information about a request.
pub mod client_ip;
pub mod scheme;
pub mod trusted_proxy;
+252
View File
@@ -0,0 +1,252 @@
//! axum extractor and middleware to get the request scheme.
use axum::{
Extension, RequestPartsExt as _,
extract::{FromRequestParts, Request},
http::{self, header::FORWARDED, request::Parts},
middleware::Next,
response::Response,
};
use forwarded_header_value::{ForwardedHeaderValue, Protocol};
use tracing::{Span, instrument};
use crate::{
accept::{proxy_protocol::ProxyProtocolState, tls::TlsState},
extract::trusted_proxy::TrustedProxy,
};
const X_FORWARDED_PROTO: &str = "X-Forwarded-Proto";
const X_FORWARDED_SCHEME: &str = "X-Forwarded-Scheme";
/// Request scheme.
///
/// The [`scheme_middleware`] must be added to the router before using this extractor,
/// otherwise this will result in requests erroring.
#[derive(Clone, Debug)]
pub struct Scheme(pub http::uri::Scheme);
impl<S> FromRequestParts<S> for Scheme
where
S: Send + Sync,
{
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
Extension::<Self>::from_request_parts(parts, state)
.await
.map(|Extension(scheme)| scheme)
}
}
/// Get the scheme from the request.
#[instrument(skip_all)]
async fn extract_scheme(parts: &mut Parts) -> http::uri::Scheme {
let is_trusted = parts
.extract::<TrustedProxy>()
.await
.unwrap_or(TrustedProxy(false))
.0;
if is_trusted {
if let Some(proto) = parts.headers.get(X_FORWARDED_PROTO)
&& let Ok(proto) = proto.to_str()
&& let Ok(scheme) = proto.to_lowercase().as_str().try_into()
{
return scheme;
}
if let Some(proto) = parts.headers.get(X_FORWARDED_SCHEME)
&& let Ok(proto) = proto.to_str()
&& let Ok(scheme) = proto.to_lowercase().as_str().try_into()
{
return scheme;
}
if let Some(forwarded) = parts.headers.get(FORWARDED)
&& let Ok(forwarded) = forwarded.to_str()
&& let Ok(forwarded) = ForwardedHeaderValue::from_forwarded(forwarded)
{
for stanza in forwarded.iter() {
if let Some(forwarded_proto) = &stanza.forwarded_proto {
let scheme = match forwarded_proto {
Protocol::Http => http::uri::Scheme::HTTP,
Protocol::Https => http::uri::Scheme::HTTPS,
};
return scheme;
}
}
}
if let Ok(Extension(proxy_protocol_state)) =
parts.extract::<Extension<ProxyProtocolState>>().await
&& let Some(header) = &proxy_protocol_state.header
&& header.ssl().is_some()
{
return http::uri::Scheme::HTTPS;
}
}
if parts.extract::<Extension<TlsState>>().await.is_ok() {
http::uri::Scheme::HTTPS
} else {
http::uri::Scheme::HTTP
}
}
/// Middleware required by the [`Scheme`] extractor.
///
/// Use with [`axum::middleware::from_fn`].
pub async fn scheme_middleware(request: Request, next: Next) -> Response {
let (mut parts, body) = request.into_parts();
let scheme = extract_scheme(&mut parts).await;
Span::current().record("scheme", scheme.to_string());
parts.extensions.insert::<Scheme>(Scheme(scheme));
let request = Request::from_parts(parts, body);
next.run(request).await
}
#[cfg(test)]
mod tests {
use axum::{body::Body, http::Request};
use super::*;
#[tokio::test]
async fn x_forwarded_proto_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-proto", "https")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTPS,);
}
#[tokio::test]
async fn x_forwarded_scheme_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-scheme", "https")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTPS,);
}
#[tokio::test]
async fn forwarded_header_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("forwarded", "proto=https")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTPS,);
}
#[tokio::test]
async fn x_forwarded_proto_untrusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-proto", "https")
.extension(TrustedProxy(false))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTP,);
}
#[tokio::test]
async fn scheme_from_tls_state() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.extension(TlsState {
peer_certificates: None,
})
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTPS,);
}
#[tokio::test]
async fn scheme_defaults_to_http() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTP,);
}
#[tokio::test]
async fn priority_order() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-proto", "http")
.header("x-forwarded-scheme", "https")
.header("forwarded", "proto=https")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTP,);
}
#[tokio::test]
async fn multiple_forwarded_stanzas() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("forwarded", "proto=http, proto=https")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTP,);
}
#[tokio::test]
async fn test_scheme_case_insensitive() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-proto", "HTTPS")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTPS,);
}
}
+6 -2
View File
@@ -6,7 +6,10 @@ use tower::ServiceBuilder;
use tower_http::timeout::TimeoutLayer;
use crate::{
extract::{client_ip::client_ip_middleware, trusted_proxy::trusted_proxy_middleware},
extract::{
client_ip::client_ip_middleware, scheme::scheme_middleware,
trusted_proxy::trusted_proxy_middleware,
},
tracing::{span_middleware, tracing_middleware},
};
@@ -28,7 +31,8 @@ pub fn wrap_router(router: Router, with_tracing: bool) -> Router {
))
.layer(from_fn(span_middleware))
.layer(from_fn(trusted_proxy_middleware))
.layer(from_fn(client_ip_middleware));
.layer(from_fn(client_ip_middleware))
.layer(from_fn(scheme_middleware));
if with_tracing {
router.layer(service_builder.layer(from_fn(tracing_middleware)))
} else {
+1
View File
@@ -28,6 +28,7 @@ pub(crate) async fn span_middleware(request: Request, next: Next) -> Response {
path = %request.uri(),
method = %request.method(),
remote = field::Empty,
scheme = field::Empty,
http_headers = ?http_headers,
);
next.run(request).instrument(span).await