outpost/proxyv2: postgresstore: credential refresh (#17414)

* outpost/proxyv2: postgresstore: credential refresh

* wip

* mabye

* mabye fix
This commit is contained in:
Dominic R
2025-10-15 09:22:27 -04:00
committed by GitHub
parent d0b69bafac
commit 06bfcf04e3
4 changed files with 769 additions and 24 deletions
+37
View File
@@ -182,6 +182,43 @@ func (c *Config) parseScheme(rawVal string) string {
return rawVal
}
// RefreshPostgreSQLConfig re-reads PostgreSQL configuration from file:// and env:// URIs
// This enables hot-reloading when credentials are rotated by updating the referenced files.
// Note: Plain environment variables (without file:// or env:// prefixes) are read from the
// process environment and will not change unless the process is restarted or os.Setenv is called.
func (c *Config) RefreshPostgreSQLConfig() PostgreSQLConfig {
// Start with current config as base
refreshed := c.PostgreSQL
// Manually read from environment variables with proper prefix
// We can't use env.Process directly on PostgreSQLConfig because it loses the AUTHENTIK_POSTGRESQL__ prefix
// Map of environment variable suffix to config field pointer
envVars := map[string]*string{
"HOST": &refreshed.Host,
"USER": &refreshed.User,
"PASSWORD": &refreshed.Password,
"NAME": &refreshed.Name,
"SSLMODE": &refreshed.SSLMode,
"SSLROOTCERT": &refreshed.SSLRootCert,
"SSLCERT": &refreshed.SSLCert,
"SSLKEY": &refreshed.SSLKey,
"DEFAULT_SCHEMA": &refreshed.DefaultSchema,
"CONN_OPTIONS": &refreshed.ConnOptions,
}
// Read each environment variable if it exists
for suffix, field := range envVars {
if val, ok := os.LookupEnv("AUTHENTIK_POSTGRESQL__" + suffix); ok {
*field = val
}
}
// Process file:// and env:// URI schemes
c.walkScheme(&refreshed)
return refreshed
}
func (c *Config) configureLogger() {
switch strings.ToLower(c.LogLevel) {
case "trace":
@@ -0,0 +1,291 @@
package postgresstore
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"sync"
"time"
"github.com/jackc/pgx/v5/pgconn"
log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"goauthentik.io/internal/config"
)
// RefreshableConnPool wraps sql.DB and refreshes PostgreSQL credentials on authentication errors
// This implements gorm.ConnPool interface to allow credential rotation
type RefreshableConnPool struct {
mu sync.RWMutex
db *sql.DB
dsnBuilder func(config.PostgreSQLConfig) (string, error)
log *log.Entry
currentDSN string
gormConfig *gorm.Config
// Connection pool settings (stored for reapplication after reconnection)
maxIdleConns int
maxOpenConns int
connMaxLifetime time.Duration
// Reconnection management
reconnecting sync.Mutex // Prevent concurrent reconnections
}
// NewRefreshableConnPool creates a new connection pool that refreshes credentials from config
func NewRefreshableConnPool(initialDSN string, gormConfig *gorm.Config, maxIdleConns, maxOpenConns int, connMaxLifetime time.Duration) (*RefreshableConnPool, error) {
db, err := sql.Open("postgres", initialDSN)
if err != nil {
return nil, err
}
// Apply connection pool settings
db.SetMaxIdleConns(maxIdleConns)
db.SetMaxOpenConns(maxOpenConns)
db.SetConnMaxLifetime(connMaxLifetime)
pool := &RefreshableConnPool{
db: db,
dsnBuilder: BuildDSN,
log: log.WithField("logger", "authentik.outpost.proxyv2.postgresstore.connpool"),
currentDSN: initialDSN,
gormConfig: gormConfig,
maxIdleConns: maxIdleConns,
maxOpenConns: maxOpenConns,
connMaxLifetime: connMaxLifetime,
}
return pool, nil
}
// isAuthError checks if an error is a PostgreSQL authentication error
func isAuthError(err error) bool {
if err == nil {
return false
}
// Unwrap the error to find the underlying pgconn.PgError
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
// Check for any PostgreSQL error code in Class 28 (Invalid Authorization Specification)
// See https://www.postgresql.org/docs/current/errcodes-appendix.html
return len(pgErr.Code) >= 2 && pgErr.Code[:2] == "28"
}
return false
}
// refreshCredentials checks if credentials have changed and reconnects if needed
func (p *RefreshableConnPool) refreshCredentials(ctx context.Context) error {
// Prevent concurrent reconnections
p.reconnecting.Lock()
defer p.reconnecting.Unlock()
// Get fresh config
cfg := config.Get().RefreshPostgreSQLConfig()
newDSN, err := p.dsnBuilder(cfg)
if err != nil {
p.log.WithError(err).Warn("Failed to build DSN with refreshed credentials")
return err
}
p.mu.RLock()
dsnChanged := newDSN != p.currentDSN
p.mu.RUnlock()
if !dsnChanged {
p.log.Debug("Credentials unchanged, skipping reconnection")
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
// Double-check after acquiring write lock
if newDSN == p.currentDSN {
return nil
}
p.log.Info("PostgreSQL credentials changed, reconnecting...")
// Open new connection with fresh credentials
newDB, err := sql.Open("postgres", newDSN)
if err != nil {
p.log.WithError(err).Error("Failed to open new database connection with refreshed credentials")
return err
}
// Reapply connection pool settings
newDB.SetMaxIdleConns(p.maxIdleConns)
newDB.SetMaxOpenConns(p.maxOpenConns)
newDB.SetConnMaxLifetime(p.connMaxLifetime)
// Verify the connection works BEFORE closing old connection
if err := newDB.PingContext(ctx); err != nil {
p.log.WithError(err).Error("Failed to ping database with new credentials")
_ = newDB.Close()
// Old connection remains active, pool is still functional
return err
}
// Only after successful verification, swap connections
oldDB := p.db
p.db = newDB
p.currentDSN = newDSN
// Close old connection after swap
if oldDB != nil {
if err := oldDB.Close(); err != nil {
p.log.WithError(err).Warn("Failed to close old database connection")
// Not fatal cause new connection is already active
}
}
p.log.Info("Successfully reconnected with new PostgreSQL credentials")
return nil
}
// tryWithRefresh attempts an operation, and if it fails with an auth error, refreshes credentials and retries
func (p *RefreshableConnPool) tryWithRefresh(ctx context.Context, op func() error) error {
err := op()
if err != nil && isAuthError(err) {
p.log.WithError(err).Info("Authentication error detected, attempting to refresh credentials")
if refreshErr := p.refreshCredentials(ctx); refreshErr == nil {
// Retry the operation once after successful refresh
p.log.Debug("Retrying operation after credential refresh")
return op()
} else {
p.log.WithError(refreshErr).Warn("Failed to refresh credentials, returning original error")
}
}
return err
}
// PrepareContext implements gorm.ConnPool interface
func (p *RefreshableConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
var stmt *sql.Stmt
err := p.tryWithRefresh(ctx, func() error {
p.mu.RLock()
defer p.mu.RUnlock()
var err error
stmt, err = p.db.PrepareContext(ctx, query)
return err
})
return stmt, err
}
// ExecContext implements gorm.ConnPool interface
func (p *RefreshableConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
var result sql.Result
err := p.tryWithRefresh(ctx, func() error {
p.mu.RLock()
defer p.mu.RUnlock()
var err error
result, err = p.db.ExecContext(ctx, query, args...)
return err
})
return result, err
}
// QueryContext implements gorm.ConnPool interface
func (p *RefreshableConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
var rows *sql.Rows
err := p.tryWithRefresh(ctx, func() error {
p.mu.RLock()
defer p.mu.RUnlock()
var err error
rows, err = p.db.QueryContext(ctx, query, args...)
return err
})
return rows, err
}
// QueryRowContext implements gorm.ConnPool interface
func (p *RefreshableConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
// Note: sql.Row doesn't return errors until Scan() is called, so we can't detect auth errors here
// The error will be caught in higher-level GORM operations
p.mu.RLock()
defer p.mu.RUnlock()
return p.db.QueryRowContext(ctx, query, args...)
}
// BeginTx implements gorm.TxBeginner and gorm.ConnPoolBeginner interfaces
func (p *RefreshableConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
var tx *sql.Tx
err := p.tryWithRefresh(ctx, func() error {
p.mu.RLock()
defer p.mu.RUnlock()
var err error
tx, err = p.db.BeginTx(ctx, opts)
return err
})
if err != nil {
return nil, err
}
return &refreshableTx{Tx: tx, pool: p}, nil
}
// refreshableTx wraps sql.Tx to implement gorm.ConnPool
type refreshableTx struct {
*sql.Tx
pool *RefreshableConnPool
}
func (tx *refreshableTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return tx.Tx.PrepareContext(ctx, query)
}
func (tx *refreshableTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return tx.Tx.ExecContext(ctx, query, args...)
}
func (tx *refreshableTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return tx.Tx.QueryContext(ctx, query, args...)
}
func (tx *refreshableTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
return tx.Tx.QueryRowContext(ctx, query, args...)
}
// Close closes the underlying database connection
func (p *RefreshableConnPool) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.db != nil {
return p.db.Close()
}
return nil
}
// Ping verifies the connection is alive
func (p *RefreshableConnPool) Ping(ctx context.Context) error {
p.mu.RLock()
defer p.mu.RUnlock()
return p.db.PingContext(ctx)
}
// GetDB returns the underlying sql.DB for connection pool configuration
func (p *RefreshableConnPool) GetDB() *sql.DB {
p.mu.RLock()
defer p.mu.RUnlock()
return p.db
}
// NewGORMDB creates a GORM DB instance using the refreshable connection pool
func (p *RefreshableConnPool) NewGORMDB() (*gorm.DB, error) {
dialector := postgres.New(postgres.Config{
Conn: p,
})
return gorm.Open(dialector, p.gormConfig)
}
// Ensure RefreshableConnPool implements required interfaces
var (
_ gorm.ConnPool = (*RefreshableConnPool)(nil)
_ gorm.ConnPoolBeginner = (*RefreshableConnPool)(nil)
_ driver.Pinger = (*RefreshableConnPool)(nil)
)
@@ -0,0 +1,415 @@
package postgresstore
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"goauthentik.io/internal/config"
)
func TestRefreshableConnPool_CredentialRefresh(t *testing.T) {
// Skip if no PostgreSQL available
if os.Getenv("AUTHENTIK_POSTGRESQL__HOST") == "" {
t.Skip("Skipping test: no PostgreSQL configured")
}
// Create a temporary file for password rotation
tmpDir := t.TempDir()
passwordFile := filepath.Join(tmpDir, "db_password")
// Write initial password
initialPassword := os.Getenv("AUTHENTIK_POSTGRESQL__PASSWORD")
if initialPassword == "" {
initialPassword = "postgres"
}
err := os.WriteFile(passwordFile, []byte(initialPassword), 0600)
require.NoError(t, err)
// Set up config to use file:// URI for password
originalPassword := os.Getenv("AUTHENTIK_POSTGRESQL__PASSWORD")
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "file://"+passwordFile))
defer func() {
if originalPassword != "" {
_ = os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", originalPassword)
} else {
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD")
}
}()
// Reload config
cfg := &config.Config{}
cfg.Setup()
// Build initial DSN
dsn, err := BuildDSN(cfg.PostgreSQL)
require.NoError(t, err)
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
NowFunc: func() time.Time {
return time.Now().UTC()
},
}
// Create refreshable connection pool
pool, err := NewRefreshableConnPool(dsn, gormConfig, 10, 100, time.Hour)
require.NoError(t, err)
defer func() { _ = pool.Close() }()
// Test initial connection works
ctx := context.Background()
err = pool.Ping(ctx)
assert.NoError(t, err, "Initial connection should work")
// Create GORM DB
db, err := pool.NewGORMDB()
require.NoError(t, err)
// Execute a test query
var result int
err = db.WithContext(ctx).Raw("SELECT 1").Scan(&result).Error
assert.NoError(t, err, "Initial query should succeed")
assert.Equal(t, 1, result)
// Simulate password change by writing to file
// In real scenario, this would be an external process updating the file
time.Sleep(100 * time.Millisecond) // Small delay to ensure file modification time changes
err = os.WriteFile(passwordFile, []byte(initialPassword), 0600)
require.NoError(t, err)
// Execute another query - should trigger credential refresh check
err = db.WithContext(ctx).Raw("SELECT 2").Scan(&result).Error
assert.NoError(t, err, "Query after credential refresh should succeed")
assert.Equal(t, 2, result)
}
func TestRefreshableConnPool_Interfaces(t *testing.T) {
// Verify that RefreshableConnPool implements required interfaces at compile time
// This test will fail to compile if interfaces are not properly implemented
var pool *RefreshableConnPool
// Test gorm.ConnPool interface
var _ gorm.ConnPool = pool
// Test gorm.ConnPoolBeginner interface
var _ gorm.ConnPoolBeginner = pool
}
func TestRefreshableConnPool_ConcurrentAccess(t *testing.T) {
// Skip if no PostgreSQL available
if os.Getenv("AUTHENTIK_POSTGRESQL__HOST") == "" {
t.Skip("Skipping test: no PostgreSQL configured")
}
cfg := config.Get()
dsn, err := BuildDSN(cfg.PostgreSQL)
require.NoError(t, err)
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
}
pool, err := NewRefreshableConnPool(dsn, gormConfig, 10, 100, time.Hour)
require.NoError(t, err)
defer func() { _ = pool.Close() }()
db, err := pool.NewGORMDB()
require.NoError(t, err)
// Test concurrent queries
ctx := context.Background()
numGoroutines := 10
numQueries := 5
errChan := make(chan error, numGoroutines*numQueries)
for i := 0; i < numGoroutines; i++ {
go func(goroutineID int) {
for j := 0; j < numQueries; j++ {
var result int
err := db.WithContext(ctx).Raw("SELECT ?", goroutineID*numQueries+j).Scan(&result).Error
if err != nil {
errChan <- err
}
}
}(i)
}
// Wait a bit for goroutines to complete
time.Sleep(2 * time.Second)
close(errChan)
// Check for any errors
for err := range errChan {
assert.NoError(t, err, "Concurrent queries should succeed")
}
}
func TestRefreshableConnPool_InvalidCredentials(t *testing.T) {
// Create a pool with invalid credentials
invalidDSN := "host=localhost port=5432 user=invalid password=invalid dbname=invalid sslmode=disable"
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
}
pool, err := NewRefreshableConnPool(invalidDSN, gormConfig, 10, 100, time.Hour)
if err != nil {
// sql.Open may succeed even with invalid credentials (lazy connection)
return
}
defer func() { _ = pool.Close() }()
// Ping should fail with invalid credentials
ctx := context.Background()
err = pool.Ping(ctx)
assert.Error(t, err, "Ping with invalid credentials should fail")
}
func TestConfig_RefreshPostgreSQLConfig_FileURI(t *testing.T) {
// Create temporary files for testing file:// URIs
tmpDir := t.TempDir()
passwordFile := filepath.Join(tmpDir, "password")
userFile := filepath.Join(tmpDir, "user")
hostFile := filepath.Join(tmpDir, "host")
err := os.WriteFile(passwordFile, []byte("secret_password"), 0600)
require.NoError(t, err)
err = os.WriteFile(userFile, []byte("dbuser"), 0600)
require.NoError(t, err)
err = os.WriteFile(hostFile, []byte("db.example.com"), 0600)
require.NoError(t, err)
// Set up environment variables with file:// URIs
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "file://"+passwordFile))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "file://"+userFile))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "file://"+hostFile))
defer func() {
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD")
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER")
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST")
}()
// Create and setup config
cfg := &config.Config{}
cfg.Setup()
// Test initial values are parsed correctly
assert.Equal(t, "secret_password", cfg.PostgreSQL.Password, "Initial password should be parsed from file")
assert.Equal(t, "dbuser", cfg.PostgreSQL.User, "Initial user should be parsed from file")
assert.Equal(t, "db.example.com", cfg.PostgreSQL.Host, "Initial host should be parsed from file")
// Test RefreshPostgreSQLConfig returns same values initially
refreshed := cfg.RefreshPostgreSQLConfig()
assert.Equal(t, "secret_password", refreshed.Password)
assert.Equal(t, "dbuser", refreshed.User)
assert.Equal(t, "db.example.com", refreshed.Host)
// Update password file (simulating credential rotation)
err = os.WriteFile(passwordFile, []byte("new_password"), 0600)
require.NoError(t, err)
// Update user file
err = os.WriteFile(userFile, []byte("new_dbuser"), 0600)
require.NoError(t, err)
// Refresh should pick up new values from files
refreshed = cfg.RefreshPostgreSQLConfig()
assert.Equal(t, "new_password", refreshed.Password, "Password should be refreshed from file")
assert.Equal(t, "new_dbuser", refreshed.User, "User should be refreshed from file")
// Original config struct should still have old values (not mutated)
assert.Equal(t, "secret_password", cfg.PostgreSQL.Password, "Original config should not be mutated")
}
func TestConfig_RefreshPostgreSQLConfig_EnvURI(t *testing.T) {
// Test with env:// URIs (referencing other env vars)
require.NoError(t, os.Setenv("DB_PASSWORD", "env_password"))
require.NoError(t, os.Setenv("DB_USER", "env_user"))
require.NoError(t, os.Setenv("DB_HOST", "env_host"))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "env://DB_PASSWORD"))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "env://DB_USER"))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "env://DB_HOST"))
defer func() {
_ = os.Unsetenv("DB_PASSWORD")
_ = os.Unsetenv("DB_USER")
_ = os.Unsetenv("DB_HOST")
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD")
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER")
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST")
}()
cfg := &config.Config{}
cfg.Setup()
// Test initial values are parsed correctly
assert.Equal(t, "env_password", cfg.PostgreSQL.Password, "Initial password should be parsed from env")
assert.Equal(t, "env_user", cfg.PostgreSQL.User, "Initial user should be parsed from env")
assert.Equal(t, "env_host", cfg.PostgreSQL.Host, "Initial host should be parsed from env")
// Test RefreshPostgreSQLConfig
refreshed := cfg.RefreshPostgreSQLConfig()
assert.Equal(t, "env_password", refreshed.Password)
assert.Equal(t, "env_user", refreshed.User)
assert.Equal(t, "env_host", refreshed.Host)
// Change referenced environment variables (simulating credential rotation)
require.NoError(t, os.Setenv("DB_PASSWORD", "new_env_password"))
require.NoError(t, os.Setenv("DB_USER", "new_env_user"))
// Refresh should pick up new values
refreshed = cfg.RefreshPostgreSQLConfig()
assert.Equal(t, "new_env_password", refreshed.Password, "Password should be refreshed from env")
assert.Equal(t, "new_env_user", refreshed.User, "User should be refreshed from env")
// Original config struct should still have old values (not mutated)
assert.Equal(t, "env_password", cfg.PostgreSQL.Password, "Original config should not be mutated")
}
func TestConfig_RefreshPostgreSQLConfig_PlainValues(t *testing.T) {
// Test with plain values (no URI scheme)
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "plain_password"))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "plain_user"))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "localhost"))
defer func() {
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD")
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER")
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST")
}()
cfg := &config.Config{}
cfg.Setup()
// Test initial values
assert.Equal(t, "plain_password", cfg.PostgreSQL.Password)
assert.Equal(t, "plain_user", cfg.PostgreSQL.User)
assert.Equal(t, "localhost", cfg.PostgreSQL.Host)
// Test refresh returns same values
refreshed := cfg.RefreshPostgreSQLConfig()
assert.Equal(t, "plain_password", refreshed.Password)
assert.Equal(t, "plain_user", refreshed.User)
assert.Equal(t, "localhost", refreshed.Host)
// Change env vars
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "new_plain_password"))
// Refresh should pick up new plain value
refreshed = cfg.RefreshPostgreSQLConfig()
assert.Equal(t, "new_plain_password", refreshed.Password, "Plain password should be refreshed")
}
func TestConfig_RefreshPostgreSQLConfig_MixedSources(t *testing.T) {
// Test with mixed sources: file://, env://, and plain
tmpDir := t.TempDir()
passwordFile := filepath.Join(tmpDir, "password")
err := os.WriteFile(passwordFile, []byte("file_password"), 0600)
require.NoError(t, err)
require.NoError(t, os.Setenv("DB_USER_VAR", "env_user"))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "file://"+passwordFile))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "env://DB_USER_VAR"))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "plain_host"))
defer func() {
_ = os.Unsetenv("DB_USER_VAR")
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD")
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER")
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST")
}()
cfg := &config.Config{}
cfg.Setup()
// Test initial values
assert.Equal(t, "file_password", cfg.PostgreSQL.Password)
assert.Equal(t, "env_user", cfg.PostgreSQL.User)
assert.Equal(t, "plain_host", cfg.PostgreSQL.Host)
// Update all sources
err = os.WriteFile(passwordFile, []byte("new_file_password"), 0600)
require.NoError(t, err)
require.NoError(t, os.Setenv("DB_USER_VAR", "new_env_user"))
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "new_plain_host"))
// Refresh should pick up all changes
refreshed := cfg.RefreshPostgreSQLConfig()
assert.Equal(t, "new_file_password", refreshed.Password, "File password should be refreshed")
assert.Equal(t, "new_env_user", refreshed.User, "Env user should be refreshed")
assert.Equal(t, "new_plain_host", refreshed.Host, "Plain host should be refreshed")
}
func TestIsAuthError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "generic error",
err: assert.AnError,
expected: false,
},
{
name: "postgres error code 28000 - invalid_authorization_specification",
err: &pgconn.PgError{
Code: "28000",
Message: "invalid authorization specification",
},
expected: true,
},
{
name: "postgres error code 28P01 - invalid_password",
err: &pgconn.PgError{
Code: "28P01",
Message: "password authentication failed for user",
},
expected: true,
},
{
name: "postgres error code 28P02 - invalid_password (deprecated)",
err: &pgconn.PgError{
Code: "28P02",
Message: "invalid password",
},
expected: true,
},
{
name: "postgres error code 42P01 - undefined_table (not auth error)",
err: &pgconn.PgError{
Code: "42P01",
Message: "relation does not exist",
},
expected: false,
},
{
name: "postgres error code 23505 - unique_violation (not auth error)",
err: &pgconn.PgError{
Code: "23505",
Message: "duplicate key value violates unique constraint",
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isAuthError(tt.err)
assert.Equal(t, tt.expected, result)
})
}
}
@@ -14,7 +14,6 @@ import (
"github.com/gorilla/sessions"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
@@ -26,7 +25,8 @@ import (
// PostgresStore stores gorilla sessions in PostgreSQL using GORM
type PostgresStore struct {
db *gorm.DB
db *gorm.DB
pool *RefreshableConnPool // Keep reference to pool for cleanup
// default options to use when a new session is created
options sessions.Options
// key prefix with which the session will be stored
@@ -125,29 +125,32 @@ func NewPostgresStore() (*PostgresStore, error) {
},
}
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
// Determine connection pool settings
maxIdleConns := 10
maxOpenConns := 100
var connMaxLifetime time.Duration
if cfg.ConnMaxAge > 0 {
connMaxLifetime = time.Duration(cfg.ConnMaxAge) * time.Second
} else {
connMaxLifetime = time.Hour // Default 1 hour
}
// Create refreshable connection pool
pool, err := NewRefreshableConnPool(dsn, gormConfig, maxIdleConns, maxOpenConns, connMaxLifetime)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}
// Create GORM DB using the refreshable connection pool
db, err := pool.NewGORMDB()
if err != nil {
_ = pool.Close()
return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
}
// Configure connection pool
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
}
// Set connection pool settings
sqlDB.SetMaxIdleConns(10)
sqlDB.SetMaxOpenConns(100)
if cfg.ConnMaxAge > 0 {
sqlDB.SetConnMaxLifetime(time.Duration(cfg.ConnMaxAge) * time.Second)
} else {
sqlDB.SetConnMaxLifetime(time.Hour) // Default 1 hour
}
ps := &PostgresStore{
db: db,
db: db,
pool: pool,
options: sessions.Options{
Path: "/",
MaxAge: 86400 * 30, // 30 days default (but overwritten in postgresstore creation based on token validation)
@@ -224,11 +227,10 @@ func (s *PostgresStore) KeyPrefix(keyPrefix string) {
// Close closes the PostgreSQL store
func (s *PostgresStore) Close() error {
sqlDB, err := s.db.DB()
if err != nil {
return err
if s.pool != nil {
return s.pool.Close()
}
return sqlDB.Close()
return nil
}
// save writes session to PostgreSQL