diff --git a/internal/config/config.go b/internal/config/config.go index b383a2731e..f11dbc9feb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -182,6 +182,43 @@ func (c *Config) parseScheme(rawVal string) string { return rawVal } +// RefreshPostgreSQLConfig re-reads PostgreSQL configuration from file:// and env:// URIs +// This enables hot-reloading when credentials are rotated by updating the referenced files. +// Note: Plain environment variables (without file:// or env:// prefixes) are read from the +// process environment and will not change unless the process is restarted or os.Setenv is called. +func (c *Config) RefreshPostgreSQLConfig() PostgreSQLConfig { + // Start with current config as base + refreshed := c.PostgreSQL + + // Manually read from environment variables with proper prefix + // We can't use env.Process directly on PostgreSQLConfig because it loses the AUTHENTIK_POSTGRESQL__ prefix + // Map of environment variable suffix to config field pointer + envVars := map[string]*string{ + "HOST": &refreshed.Host, + "USER": &refreshed.User, + "PASSWORD": &refreshed.Password, + "NAME": &refreshed.Name, + "SSLMODE": &refreshed.SSLMode, + "SSLROOTCERT": &refreshed.SSLRootCert, + "SSLCERT": &refreshed.SSLCert, + "SSLKEY": &refreshed.SSLKey, + "DEFAULT_SCHEMA": &refreshed.DefaultSchema, + "CONN_OPTIONS": &refreshed.ConnOptions, + } + + // Read each environment variable if it exists + for suffix, field := range envVars { + if val, ok := os.LookupEnv("AUTHENTIK_POSTGRESQL__" + suffix); ok { + *field = val + } + } + + // Process file:// and env:// URI schemes + c.walkScheme(&refreshed) + + return refreshed +} + func (c *Config) configureLogger() { switch strings.ToLower(c.LogLevel) { case "trace": diff --git a/internal/outpost/proxyv2/postgresstore/connpool.go b/internal/outpost/proxyv2/postgresstore/connpool.go new file mode 100644 index 0000000000..c9a5645ff9 --- /dev/null +++ b/internal/outpost/proxyv2/postgresstore/connpool.go @@ -0,0 +1,291 @@ +package postgresstore + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "sync" + "time" + + "github.com/jackc/pgx/v5/pgconn" + log "github.com/sirupsen/logrus" + "gorm.io/driver/postgres" + "gorm.io/gorm" + + "goauthentik.io/internal/config" +) + +// RefreshableConnPool wraps sql.DB and refreshes PostgreSQL credentials on authentication errors +// This implements gorm.ConnPool interface to allow credential rotation +type RefreshableConnPool struct { + mu sync.RWMutex + db *sql.DB + dsnBuilder func(config.PostgreSQLConfig) (string, error) + log *log.Entry + currentDSN string + gormConfig *gorm.Config + + // Connection pool settings (stored for reapplication after reconnection) + maxIdleConns int + maxOpenConns int + connMaxLifetime time.Duration + + // Reconnection management + reconnecting sync.Mutex // Prevent concurrent reconnections +} + +// NewRefreshableConnPool creates a new connection pool that refreshes credentials from config +func NewRefreshableConnPool(initialDSN string, gormConfig *gorm.Config, maxIdleConns, maxOpenConns int, connMaxLifetime time.Duration) (*RefreshableConnPool, error) { + db, err := sql.Open("postgres", initialDSN) + if err != nil { + return nil, err + } + + // Apply connection pool settings + db.SetMaxIdleConns(maxIdleConns) + db.SetMaxOpenConns(maxOpenConns) + db.SetConnMaxLifetime(connMaxLifetime) + + pool := &RefreshableConnPool{ + db: db, + dsnBuilder: BuildDSN, + log: log.WithField("logger", "authentik.outpost.proxyv2.postgresstore.connpool"), + currentDSN: initialDSN, + gormConfig: gormConfig, + maxIdleConns: maxIdleConns, + maxOpenConns: maxOpenConns, + connMaxLifetime: connMaxLifetime, + } + + return pool, nil +} + +// isAuthError checks if an error is a PostgreSQL authentication error +func isAuthError(err error) bool { + if err == nil { + return false + } + + // Unwrap the error to find the underlying pgconn.PgError + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + // Check for any PostgreSQL error code in Class 28 (Invalid Authorization Specification) + // See https://www.postgresql.org/docs/current/errcodes-appendix.html + return len(pgErr.Code) >= 2 && pgErr.Code[:2] == "28" + } + + return false +} + +// refreshCredentials checks if credentials have changed and reconnects if needed +func (p *RefreshableConnPool) refreshCredentials(ctx context.Context) error { + // Prevent concurrent reconnections + p.reconnecting.Lock() + defer p.reconnecting.Unlock() + + // Get fresh config + cfg := config.Get().RefreshPostgreSQLConfig() + newDSN, err := p.dsnBuilder(cfg) + if err != nil { + p.log.WithError(err).Warn("Failed to build DSN with refreshed credentials") + return err + } + + p.mu.RLock() + dsnChanged := newDSN != p.currentDSN + p.mu.RUnlock() + + if !dsnChanged { + p.log.Debug("Credentials unchanged, skipping reconnection") + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + // Double-check after acquiring write lock + if newDSN == p.currentDSN { + return nil + } + + p.log.Info("PostgreSQL credentials changed, reconnecting...") + + // Open new connection with fresh credentials + newDB, err := sql.Open("postgres", newDSN) + if err != nil { + p.log.WithError(err).Error("Failed to open new database connection with refreshed credentials") + return err + } + + // Reapply connection pool settings + newDB.SetMaxIdleConns(p.maxIdleConns) + newDB.SetMaxOpenConns(p.maxOpenConns) + newDB.SetConnMaxLifetime(p.connMaxLifetime) + + // Verify the connection works BEFORE closing old connection + if err := newDB.PingContext(ctx); err != nil { + p.log.WithError(err).Error("Failed to ping database with new credentials") + _ = newDB.Close() + // Old connection remains active, pool is still functional + return err + } + + // Only after successful verification, swap connections + oldDB := p.db + p.db = newDB + p.currentDSN = newDSN + + // Close old connection after swap + if oldDB != nil { + if err := oldDB.Close(); err != nil { + p.log.WithError(err).Warn("Failed to close old database connection") + // Not fatal cause new connection is already active + } + } + + p.log.Info("Successfully reconnected with new PostgreSQL credentials") + + return nil +} + +// tryWithRefresh attempts an operation, and if it fails with an auth error, refreshes credentials and retries +func (p *RefreshableConnPool) tryWithRefresh(ctx context.Context, op func() error) error { + err := op() + if err != nil && isAuthError(err) { + p.log.WithError(err).Info("Authentication error detected, attempting to refresh credentials") + if refreshErr := p.refreshCredentials(ctx); refreshErr == nil { + // Retry the operation once after successful refresh + p.log.Debug("Retrying operation after credential refresh") + return op() + } else { + p.log.WithError(refreshErr).Warn("Failed to refresh credentials, returning original error") + } + } + return err +} + +// PrepareContext implements gorm.ConnPool interface +func (p *RefreshableConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + var stmt *sql.Stmt + err := p.tryWithRefresh(ctx, func() error { + p.mu.RLock() + defer p.mu.RUnlock() + var err error + stmt, err = p.db.PrepareContext(ctx, query) + return err + }) + return stmt, err +} + +// ExecContext implements gorm.ConnPool interface +func (p *RefreshableConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + var result sql.Result + err := p.tryWithRefresh(ctx, func() error { + p.mu.RLock() + defer p.mu.RUnlock() + var err error + result, err = p.db.ExecContext(ctx, query, args...) + return err + }) + return result, err +} + +// QueryContext implements gorm.ConnPool interface +func (p *RefreshableConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + var rows *sql.Rows + err := p.tryWithRefresh(ctx, func() error { + p.mu.RLock() + defer p.mu.RUnlock() + var err error + rows, err = p.db.QueryContext(ctx, query, args...) + return err + }) + return rows, err +} + +// QueryRowContext implements gorm.ConnPool interface +func (p *RefreshableConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + // Note: sql.Row doesn't return errors until Scan() is called, so we can't detect auth errors here + // The error will be caught in higher-level GORM operations + p.mu.RLock() + defer p.mu.RUnlock() + return p.db.QueryRowContext(ctx, query, args...) +} + +// BeginTx implements gorm.TxBeginner and gorm.ConnPoolBeginner interfaces +func (p *RefreshableConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { + var tx *sql.Tx + err := p.tryWithRefresh(ctx, func() error { + p.mu.RLock() + defer p.mu.RUnlock() + var err error + tx, err = p.db.BeginTx(ctx, opts) + return err + }) + if err != nil { + return nil, err + } + return &refreshableTx{Tx: tx, pool: p}, nil +} + +// refreshableTx wraps sql.Tx to implement gorm.ConnPool +type refreshableTx struct { + *sql.Tx + pool *RefreshableConnPool +} + +func (tx *refreshableTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return tx.Tx.PrepareContext(ctx, query) +} + +func (tx *refreshableTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return tx.Tx.ExecContext(ctx, query, args...) +} + +func (tx *refreshableTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return tx.Tx.QueryContext(ctx, query, args...) +} + +func (tx *refreshableTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return tx.Tx.QueryRowContext(ctx, query, args...) +} + +// Close closes the underlying database connection +func (p *RefreshableConnPool) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + if p.db != nil { + return p.db.Close() + } + return nil +} + +// Ping verifies the connection is alive +func (p *RefreshableConnPool) Ping(ctx context.Context) error { + p.mu.RLock() + defer p.mu.RUnlock() + return p.db.PingContext(ctx) +} + +// GetDB returns the underlying sql.DB for connection pool configuration +func (p *RefreshableConnPool) GetDB() *sql.DB { + p.mu.RLock() + defer p.mu.RUnlock() + return p.db +} + +// NewGORMDB creates a GORM DB instance using the refreshable connection pool +func (p *RefreshableConnPool) NewGORMDB() (*gorm.DB, error) { + dialector := postgres.New(postgres.Config{ + Conn: p, + }) + return gorm.Open(dialector, p.gormConfig) +} + +// Ensure RefreshableConnPool implements required interfaces +var ( + _ gorm.ConnPool = (*RefreshableConnPool)(nil) + _ gorm.ConnPoolBeginner = (*RefreshableConnPool)(nil) + _ driver.Pinger = (*RefreshableConnPool)(nil) +) diff --git a/internal/outpost/proxyv2/postgresstore/connpool_test.go b/internal/outpost/proxyv2/postgresstore/connpool_test.go new file mode 100644 index 0000000000..bcb63698b4 --- /dev/null +++ b/internal/outpost/proxyv2/postgresstore/connpool_test.go @@ -0,0 +1,415 @@ +package postgresstore + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "goauthentik.io/internal/config" +) + +func TestRefreshableConnPool_CredentialRefresh(t *testing.T) { + // Skip if no PostgreSQL available + if os.Getenv("AUTHENTIK_POSTGRESQL__HOST") == "" { + t.Skip("Skipping test: no PostgreSQL configured") + } + + // Create a temporary file for password rotation + tmpDir := t.TempDir() + passwordFile := filepath.Join(tmpDir, "db_password") + + // Write initial password + initialPassword := os.Getenv("AUTHENTIK_POSTGRESQL__PASSWORD") + if initialPassword == "" { + initialPassword = "postgres" + } + err := os.WriteFile(passwordFile, []byte(initialPassword), 0600) + require.NoError(t, err) + + // Set up config to use file:// URI for password + originalPassword := os.Getenv("AUTHENTIK_POSTGRESQL__PASSWORD") + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "file://"+passwordFile)) + defer func() { + if originalPassword != "" { + _ = os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", originalPassword) + } else { + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD") + } + }() + + // Reload config + cfg := &config.Config{} + cfg.Setup() + + // Build initial DSN + dsn, err := BuildDSN(cfg.PostgreSQL) + require.NoError(t, err) + + gormConfig := &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + NowFunc: func() time.Time { + return time.Now().UTC() + }, + } + + // Create refreshable connection pool + pool, err := NewRefreshableConnPool(dsn, gormConfig, 10, 100, time.Hour) + require.NoError(t, err) + defer func() { _ = pool.Close() }() + + // Test initial connection works + ctx := context.Background() + err = pool.Ping(ctx) + assert.NoError(t, err, "Initial connection should work") + + // Create GORM DB + db, err := pool.NewGORMDB() + require.NoError(t, err) + + // Execute a test query + var result int + err = db.WithContext(ctx).Raw("SELECT 1").Scan(&result).Error + assert.NoError(t, err, "Initial query should succeed") + assert.Equal(t, 1, result) + + // Simulate password change by writing to file + // In real scenario, this would be an external process updating the file + time.Sleep(100 * time.Millisecond) // Small delay to ensure file modification time changes + err = os.WriteFile(passwordFile, []byte(initialPassword), 0600) + require.NoError(t, err) + + // Execute another query - should trigger credential refresh check + err = db.WithContext(ctx).Raw("SELECT 2").Scan(&result).Error + assert.NoError(t, err, "Query after credential refresh should succeed") + assert.Equal(t, 2, result) +} + +func TestRefreshableConnPool_Interfaces(t *testing.T) { + // Verify that RefreshableConnPool implements required interfaces at compile time + // This test will fail to compile if interfaces are not properly implemented + var pool *RefreshableConnPool + + // Test gorm.ConnPool interface + var _ gorm.ConnPool = pool + + // Test gorm.ConnPoolBeginner interface + var _ gorm.ConnPoolBeginner = pool +} + +func TestRefreshableConnPool_ConcurrentAccess(t *testing.T) { + // Skip if no PostgreSQL available + if os.Getenv("AUTHENTIK_POSTGRESQL__HOST") == "" { + t.Skip("Skipping test: no PostgreSQL configured") + } + + cfg := config.Get() + dsn, err := BuildDSN(cfg.PostgreSQL) + require.NoError(t, err) + + gormConfig := &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + } + + pool, err := NewRefreshableConnPool(dsn, gormConfig, 10, 100, time.Hour) + require.NoError(t, err) + defer func() { _ = pool.Close() }() + + db, err := pool.NewGORMDB() + require.NoError(t, err) + + // Test concurrent queries + ctx := context.Background() + numGoroutines := 10 + numQueries := 5 + + errChan := make(chan error, numGoroutines*numQueries) + + for i := 0; i < numGoroutines; i++ { + go func(goroutineID int) { + for j := 0; j < numQueries; j++ { + var result int + err := db.WithContext(ctx).Raw("SELECT ?", goroutineID*numQueries+j).Scan(&result).Error + if err != nil { + errChan <- err + } + } + }(i) + } + + // Wait a bit for goroutines to complete + time.Sleep(2 * time.Second) + close(errChan) + + // Check for any errors + for err := range errChan { + assert.NoError(t, err, "Concurrent queries should succeed") + } +} + +func TestRefreshableConnPool_InvalidCredentials(t *testing.T) { + // Create a pool with invalid credentials + invalidDSN := "host=localhost port=5432 user=invalid password=invalid dbname=invalid sslmode=disable" + + gormConfig := &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + } + + pool, err := NewRefreshableConnPool(invalidDSN, gormConfig, 10, 100, time.Hour) + if err != nil { + // sql.Open may succeed even with invalid credentials (lazy connection) + return + } + defer func() { _ = pool.Close() }() + + // Ping should fail with invalid credentials + ctx := context.Background() + err = pool.Ping(ctx) + assert.Error(t, err, "Ping with invalid credentials should fail") +} + +func TestConfig_RefreshPostgreSQLConfig_FileURI(t *testing.T) { + // Create temporary files for testing file:// URIs + tmpDir := t.TempDir() + + passwordFile := filepath.Join(tmpDir, "password") + userFile := filepath.Join(tmpDir, "user") + hostFile := filepath.Join(tmpDir, "host") + + err := os.WriteFile(passwordFile, []byte("secret_password"), 0600) + require.NoError(t, err) + err = os.WriteFile(userFile, []byte("dbuser"), 0600) + require.NoError(t, err) + err = os.WriteFile(hostFile, []byte("db.example.com"), 0600) + require.NoError(t, err) + + // Set up environment variables with file:// URIs + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "file://"+passwordFile)) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "file://"+userFile)) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "file://"+hostFile)) + defer func() { + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD") + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER") + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST") + }() + + // Create and setup config + cfg := &config.Config{} + cfg.Setup() + + // Test initial values are parsed correctly + assert.Equal(t, "secret_password", cfg.PostgreSQL.Password, "Initial password should be parsed from file") + assert.Equal(t, "dbuser", cfg.PostgreSQL.User, "Initial user should be parsed from file") + assert.Equal(t, "db.example.com", cfg.PostgreSQL.Host, "Initial host should be parsed from file") + + // Test RefreshPostgreSQLConfig returns same values initially + refreshed := cfg.RefreshPostgreSQLConfig() + assert.Equal(t, "secret_password", refreshed.Password) + assert.Equal(t, "dbuser", refreshed.User) + assert.Equal(t, "db.example.com", refreshed.Host) + + // Update password file (simulating credential rotation) + err = os.WriteFile(passwordFile, []byte("new_password"), 0600) + require.NoError(t, err) + + // Update user file + err = os.WriteFile(userFile, []byte("new_dbuser"), 0600) + require.NoError(t, err) + + // Refresh should pick up new values from files + refreshed = cfg.RefreshPostgreSQLConfig() + assert.Equal(t, "new_password", refreshed.Password, "Password should be refreshed from file") + assert.Equal(t, "new_dbuser", refreshed.User, "User should be refreshed from file") + + // Original config struct should still have old values (not mutated) + assert.Equal(t, "secret_password", cfg.PostgreSQL.Password, "Original config should not be mutated") +} + +func TestConfig_RefreshPostgreSQLConfig_EnvURI(t *testing.T) { + // Test with env:// URIs (referencing other env vars) + require.NoError(t, os.Setenv("DB_PASSWORD", "env_password")) + require.NoError(t, os.Setenv("DB_USER", "env_user")) + require.NoError(t, os.Setenv("DB_HOST", "env_host")) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "env://DB_PASSWORD")) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "env://DB_USER")) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "env://DB_HOST")) + defer func() { + _ = os.Unsetenv("DB_PASSWORD") + _ = os.Unsetenv("DB_USER") + _ = os.Unsetenv("DB_HOST") + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD") + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER") + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST") + }() + + cfg := &config.Config{} + cfg.Setup() + + // Test initial values are parsed correctly + assert.Equal(t, "env_password", cfg.PostgreSQL.Password, "Initial password should be parsed from env") + assert.Equal(t, "env_user", cfg.PostgreSQL.User, "Initial user should be parsed from env") + assert.Equal(t, "env_host", cfg.PostgreSQL.Host, "Initial host should be parsed from env") + + // Test RefreshPostgreSQLConfig + refreshed := cfg.RefreshPostgreSQLConfig() + assert.Equal(t, "env_password", refreshed.Password) + assert.Equal(t, "env_user", refreshed.User) + assert.Equal(t, "env_host", refreshed.Host) + + // Change referenced environment variables (simulating credential rotation) + require.NoError(t, os.Setenv("DB_PASSWORD", "new_env_password")) + require.NoError(t, os.Setenv("DB_USER", "new_env_user")) + + // Refresh should pick up new values + refreshed = cfg.RefreshPostgreSQLConfig() + assert.Equal(t, "new_env_password", refreshed.Password, "Password should be refreshed from env") + assert.Equal(t, "new_env_user", refreshed.User, "User should be refreshed from env") + + // Original config struct should still have old values (not mutated) + assert.Equal(t, "env_password", cfg.PostgreSQL.Password, "Original config should not be mutated") +} + +func TestConfig_RefreshPostgreSQLConfig_PlainValues(t *testing.T) { + // Test with plain values (no URI scheme) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "plain_password")) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "plain_user")) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "localhost")) + defer func() { + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD") + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER") + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST") + }() + + cfg := &config.Config{} + cfg.Setup() + + // Test initial values + assert.Equal(t, "plain_password", cfg.PostgreSQL.Password) + assert.Equal(t, "plain_user", cfg.PostgreSQL.User) + assert.Equal(t, "localhost", cfg.PostgreSQL.Host) + + // Test refresh returns same values + refreshed := cfg.RefreshPostgreSQLConfig() + assert.Equal(t, "plain_password", refreshed.Password) + assert.Equal(t, "plain_user", refreshed.User) + assert.Equal(t, "localhost", refreshed.Host) + + // Change env vars + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "new_plain_password")) + + // Refresh should pick up new plain value + refreshed = cfg.RefreshPostgreSQLConfig() + assert.Equal(t, "new_plain_password", refreshed.Password, "Plain password should be refreshed") +} + +func TestConfig_RefreshPostgreSQLConfig_MixedSources(t *testing.T) { + // Test with mixed sources: file://, env://, and plain + tmpDir := t.TempDir() + passwordFile := filepath.Join(tmpDir, "password") + err := os.WriteFile(passwordFile, []byte("file_password"), 0600) + require.NoError(t, err) + + require.NoError(t, os.Setenv("DB_USER_VAR", "env_user")) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "file://"+passwordFile)) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "env://DB_USER_VAR")) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "plain_host")) + defer func() { + _ = os.Unsetenv("DB_USER_VAR") + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD") + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER") + _ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST") + }() + + cfg := &config.Config{} + cfg.Setup() + + // Test initial values + assert.Equal(t, "file_password", cfg.PostgreSQL.Password) + assert.Equal(t, "env_user", cfg.PostgreSQL.User) + assert.Equal(t, "plain_host", cfg.PostgreSQL.Host) + + // Update all sources + err = os.WriteFile(passwordFile, []byte("new_file_password"), 0600) + require.NoError(t, err) + require.NoError(t, os.Setenv("DB_USER_VAR", "new_env_user")) + require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "new_plain_host")) + + // Refresh should pick up all changes + refreshed := cfg.RefreshPostgreSQLConfig() + assert.Equal(t, "new_file_password", refreshed.Password, "File password should be refreshed") + assert.Equal(t, "new_env_user", refreshed.User, "Env user should be refreshed") + assert.Equal(t, "new_plain_host", refreshed.Host, "Plain host should be refreshed") +} + +func TestIsAuthError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "generic error", + err: assert.AnError, + expected: false, + }, + { + name: "postgres error code 28000 - invalid_authorization_specification", + err: &pgconn.PgError{ + Code: "28000", + Message: "invalid authorization specification", + }, + expected: true, + }, + { + name: "postgres error code 28P01 - invalid_password", + err: &pgconn.PgError{ + Code: "28P01", + Message: "password authentication failed for user", + }, + expected: true, + }, + { + name: "postgres error code 28P02 - invalid_password (deprecated)", + err: &pgconn.PgError{ + Code: "28P02", + Message: "invalid password", + }, + expected: true, + }, + { + name: "postgres error code 42P01 - undefined_table (not auth error)", + err: &pgconn.PgError{ + Code: "42P01", + Message: "relation does not exist", + }, + expected: false, + }, + { + name: "postgres error code 23505 - unique_violation (not auth error)", + err: &pgconn.PgError{ + Code: "23505", + Message: "duplicate key value violates unique constraint", + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isAuthError(tt.err) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/outpost/proxyv2/postgresstore/postgresstore.go b/internal/outpost/proxyv2/postgresstore/postgresstore.go index 6567a34f77..4d89266389 100644 --- a/internal/outpost/proxyv2/postgresstore/postgresstore.go +++ b/internal/outpost/proxyv2/postgresstore/postgresstore.go @@ -14,7 +14,6 @@ import ( "github.com/gorilla/sessions" "github.com/mitchellh/mapstructure" log "github.com/sirupsen/logrus" - "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/logger" @@ -26,7 +25,8 @@ import ( // PostgresStore stores gorilla sessions in PostgreSQL using GORM type PostgresStore struct { - db *gorm.DB + db *gorm.DB + pool *RefreshableConnPool // Keep reference to pool for cleanup // default options to use when a new session is created options sessions.Options // key prefix with which the session will be stored @@ -125,29 +125,32 @@ func NewPostgresStore() (*PostgresStore, error) { }, } - db, err := gorm.Open(postgres.Open(dsn), gormConfig) + // Determine connection pool settings + maxIdleConns := 10 + maxOpenConns := 100 + var connMaxLifetime time.Duration + if cfg.ConnMaxAge > 0 { + connMaxLifetime = time.Duration(cfg.ConnMaxAge) * time.Second + } else { + connMaxLifetime = time.Hour // Default 1 hour + } + + // Create refreshable connection pool + pool, err := NewRefreshableConnPool(dsn, gormConfig, maxIdleConns, maxOpenConns, connMaxLifetime) if err != nil { + return nil, fmt.Errorf("failed to create connection pool: %w", err) + } + + // Create GORM DB using the refreshable connection pool + db, err := pool.NewGORMDB() + if err != nil { + _ = pool.Close() return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err) } - // Configure connection pool - sqlDB, err := db.DB() - if err != nil { - return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err) - } - - // Set connection pool settings - sqlDB.SetMaxIdleConns(10) - sqlDB.SetMaxOpenConns(100) - - if cfg.ConnMaxAge > 0 { - sqlDB.SetConnMaxLifetime(time.Duration(cfg.ConnMaxAge) * time.Second) - } else { - sqlDB.SetConnMaxLifetime(time.Hour) // Default 1 hour - } - ps := &PostgresStore{ - db: db, + db: db, + pool: pool, options: sessions.Options{ Path: "/", MaxAge: 86400 * 30, // 30 days default (but overwritten in postgresstore creation based on token validation) @@ -224,11 +227,10 @@ func (s *PostgresStore) KeyPrefix(keyPrefix string) { // Close closes the PostgreSQL store func (s *PostgresStore) Close() error { - sqlDB, err := s.db.DB() - if err != nil { - return err + if s.pool != nil { + return s.pool.Close() } - return sqlDB.Close() + return nil } // save writes session to PostgreSQL