diff --git a/authentik/lib/config.py b/authentik/lib/config.py index ba5da03093..de0dccf871 100644 --- a/authentik/lib/config.py +++ b/authentik/lib/config.py @@ -276,7 +276,7 @@ class ConfigLoader: try: return int(value) except (ValueError, TypeError) as exc: - if value is None or (isinstance(value, str) and value.lower() == "null"): + if value is None or (isinstance(value, str) and value.lower() in ("", "null", "none")): return None self.log("warning", "Failed to parse config as int", path=path, exc=str(exc)) return default diff --git a/authentik/lib/default.yml b/authentik/lib/default.yml index cfbb9c68ba..7f04e3a819 100644 --- a/authentik/lib/default.yml +++ b/authentik/lib/default.yml @@ -22,6 +22,7 @@ postgresql: port: 5432 password: "env://POSTGRES_PASSWORD" sslmode: disable + conn_max_age: 0 conn_health_checks: false use_pool: False test: diff --git a/authentik/lib/tests/test_config.py b/authentik/lib/tests/test_config.py index 90c3afebbc..1ed289cf7f 100644 --- a/authentik/lib/tests/test_config.py +++ b/authentik/lib/tests/test_config.py @@ -315,7 +315,7 @@ class TestConfig(TestCase): { "default": { "DISABLE_SERVER_SIDE_CURSORS": True, - "CONN_MAX_AGE": None, + "CONN_MAX_AGE": 0, "CONN_HEALTH_CHECKS": False, "ENGINE": "psqlextra.backend", "HOST": "foo", diff --git a/packages/ak-common/src/config/schema.rs b/packages/ak-common/src/config/schema.rs index 5d110051d8..e892ebd369 100644 --- a/packages/ak-common/src/config/schema.rs +++ b/packages/ak-common/src/config/schema.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, net::SocketAddr, num::NonZeroUsize}; use ipnet::IpNet; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, de::Error as _}; pub(super) const KEYS_TO_PARSE_AS_LIST: [&str; 4] = [ "listen.http", @@ -10,6 +10,32 @@ pub(super) const KEYS_TO_PARSE_AS_LIST: [&str; 4] = [ "log.http_headers", ]; +fn deserialize_optional_u64<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + // The value comes as a number from config files but as a string from env vars. + #[derive(Deserialize)] + #[serde(untagged)] + enum NumOrStr { + Num(u64), + Str(String), + } + + match Option::::deserialize(deserializer)? { + None => Ok(None), + Some(NumOrStr::Num(n)) => Ok(Some(n)), + Some(NumOrStr::Str(s)) => { + let s = s.trim(); + if s.is_empty() || s.eq_ignore_ascii_case("none") || s.eq_ignore_ascii_case("null") { + Ok(None) + } else { + s.parse().map(Some).map_err(D::Error::custom) + } + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { pub postgresql: PostgreSQLConfig, @@ -50,6 +76,7 @@ pub struct PostgreSQLConfig { pub sslcert: Option, pub sslkey: Option, + #[serde(deserialize_with = "deserialize_optional_u64")] pub conn_max_age: Option, pub conn_health_checks: bool, diff --git a/packages/ak-common/src/db.rs b/packages/ak-common/src/db.rs index 293faf4439..6ec0bceb56 100644 --- a/packages/ak-common/src/db.rs +++ b/packages/ak-common/src/db.rs @@ -68,7 +68,6 @@ pub async fn init(tasks: &mut Tasks) -> Result<()> { .min_connections(1) .max_connections(4) .acquire_time_level(LevelFilter::Trace) - .max_lifetime(config.postgresql.conn_max_age.map(Duration::from_secs)) .test_before_acquire(config.postgresql.conn_health_checks) .after_connect(|conn, _meta| { Box::pin(async move { @@ -84,6 +83,11 @@ pub async fn init(tasks: &mut Tasks) -> Result<()> { }) }); + let pool_options = match config.postgresql.conn_max_age { + Some(0) => pool_options.after_release(|_conn, _meta| Box::pin(async { Ok(false) })), + other => pool_options.max_lifetime(other.map(Duration::from_secs)), + }; + let pool = pool_options.connect_with(options).await?; DB.get_or_init(|| pool);