mirror of
https://github.com/goauthentik/authentik.git
synced 2026-06-17 19:09:11 +03:00
outpost/proxyv2: postgresstore: credential refresh (#17414)
* outpost/proxyv2: postgresstore: credential refresh * wip * mabye * mabye fix
This commit is contained in:
@@ -182,6 +182,43 @@ func (c *Config) parseScheme(rawVal string) string {
|
||||
return rawVal
|
||||
}
|
||||
|
||||
// RefreshPostgreSQLConfig re-reads PostgreSQL configuration from file:// and env:// URIs
|
||||
// This enables hot-reloading when credentials are rotated by updating the referenced files.
|
||||
// Note: Plain environment variables (without file:// or env:// prefixes) are read from the
|
||||
// process environment and will not change unless the process is restarted or os.Setenv is called.
|
||||
func (c *Config) RefreshPostgreSQLConfig() PostgreSQLConfig {
|
||||
// Start with current config as base
|
||||
refreshed := c.PostgreSQL
|
||||
|
||||
// Manually read from environment variables with proper prefix
|
||||
// We can't use env.Process directly on PostgreSQLConfig because it loses the AUTHENTIK_POSTGRESQL__ prefix
|
||||
// Map of environment variable suffix to config field pointer
|
||||
envVars := map[string]*string{
|
||||
"HOST": &refreshed.Host,
|
||||
"USER": &refreshed.User,
|
||||
"PASSWORD": &refreshed.Password,
|
||||
"NAME": &refreshed.Name,
|
||||
"SSLMODE": &refreshed.SSLMode,
|
||||
"SSLROOTCERT": &refreshed.SSLRootCert,
|
||||
"SSLCERT": &refreshed.SSLCert,
|
||||
"SSLKEY": &refreshed.SSLKey,
|
||||
"DEFAULT_SCHEMA": &refreshed.DefaultSchema,
|
||||
"CONN_OPTIONS": &refreshed.ConnOptions,
|
||||
}
|
||||
|
||||
// Read each environment variable if it exists
|
||||
for suffix, field := range envVars {
|
||||
if val, ok := os.LookupEnv("AUTHENTIK_POSTGRESQL__" + suffix); ok {
|
||||
*field = val
|
||||
}
|
||||
}
|
||||
|
||||
// Process file:// and env:// URI schemes
|
||||
c.walkScheme(&refreshed)
|
||||
|
||||
return refreshed
|
||||
}
|
||||
|
||||
func (c *Config) configureLogger() {
|
||||
switch strings.ToLower(c.LogLevel) {
|
||||
case "trace":
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
package postgresstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"goauthentik.io/internal/config"
|
||||
)
|
||||
|
||||
// RefreshableConnPool wraps sql.DB and refreshes PostgreSQL credentials on authentication errors
|
||||
// This implements gorm.ConnPool interface to allow credential rotation
|
||||
type RefreshableConnPool struct {
|
||||
mu sync.RWMutex
|
||||
db *sql.DB
|
||||
dsnBuilder func(config.PostgreSQLConfig) (string, error)
|
||||
log *log.Entry
|
||||
currentDSN string
|
||||
gormConfig *gorm.Config
|
||||
|
||||
// Connection pool settings (stored for reapplication after reconnection)
|
||||
maxIdleConns int
|
||||
maxOpenConns int
|
||||
connMaxLifetime time.Duration
|
||||
|
||||
// Reconnection management
|
||||
reconnecting sync.Mutex // Prevent concurrent reconnections
|
||||
}
|
||||
|
||||
// NewRefreshableConnPool creates a new connection pool that refreshes credentials from config
|
||||
func NewRefreshableConnPool(initialDSN string, gormConfig *gorm.Config, maxIdleConns, maxOpenConns int, connMaxLifetime time.Duration) (*RefreshableConnPool, error) {
|
||||
db, err := sql.Open("postgres", initialDSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply connection pool settings
|
||||
db.SetMaxIdleConns(maxIdleConns)
|
||||
db.SetMaxOpenConns(maxOpenConns)
|
||||
db.SetConnMaxLifetime(connMaxLifetime)
|
||||
|
||||
pool := &RefreshableConnPool{
|
||||
db: db,
|
||||
dsnBuilder: BuildDSN,
|
||||
log: log.WithField("logger", "authentik.outpost.proxyv2.postgresstore.connpool"),
|
||||
currentDSN: initialDSN,
|
||||
gormConfig: gormConfig,
|
||||
maxIdleConns: maxIdleConns,
|
||||
maxOpenConns: maxOpenConns,
|
||||
connMaxLifetime: connMaxLifetime,
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// isAuthError checks if an error is a PostgreSQL authentication error
|
||||
func isAuthError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Unwrap the error to find the underlying pgconn.PgError
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) {
|
||||
// Check for any PostgreSQL error code in Class 28 (Invalid Authorization Specification)
|
||||
// See https://www.postgresql.org/docs/current/errcodes-appendix.html
|
||||
return len(pgErr.Code) >= 2 && pgErr.Code[:2] == "28"
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// refreshCredentials checks if credentials have changed and reconnects if needed
|
||||
func (p *RefreshableConnPool) refreshCredentials(ctx context.Context) error {
|
||||
// Prevent concurrent reconnections
|
||||
p.reconnecting.Lock()
|
||||
defer p.reconnecting.Unlock()
|
||||
|
||||
// Get fresh config
|
||||
cfg := config.Get().RefreshPostgreSQLConfig()
|
||||
newDSN, err := p.dsnBuilder(cfg)
|
||||
if err != nil {
|
||||
p.log.WithError(err).Warn("Failed to build DSN with refreshed credentials")
|
||||
return err
|
||||
}
|
||||
|
||||
p.mu.RLock()
|
||||
dsnChanged := newDSN != p.currentDSN
|
||||
p.mu.RUnlock()
|
||||
|
||||
if !dsnChanged {
|
||||
p.log.Debug("Credentials unchanged, skipping reconnection")
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if newDSN == p.currentDSN {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.log.Info("PostgreSQL credentials changed, reconnecting...")
|
||||
|
||||
// Open new connection with fresh credentials
|
||||
newDB, err := sql.Open("postgres", newDSN)
|
||||
if err != nil {
|
||||
p.log.WithError(err).Error("Failed to open new database connection with refreshed credentials")
|
||||
return err
|
||||
}
|
||||
|
||||
// Reapply connection pool settings
|
||||
newDB.SetMaxIdleConns(p.maxIdleConns)
|
||||
newDB.SetMaxOpenConns(p.maxOpenConns)
|
||||
newDB.SetConnMaxLifetime(p.connMaxLifetime)
|
||||
|
||||
// Verify the connection works BEFORE closing old connection
|
||||
if err := newDB.PingContext(ctx); err != nil {
|
||||
p.log.WithError(err).Error("Failed to ping database with new credentials")
|
||||
_ = newDB.Close()
|
||||
// Old connection remains active, pool is still functional
|
||||
return err
|
||||
}
|
||||
|
||||
// Only after successful verification, swap connections
|
||||
oldDB := p.db
|
||||
p.db = newDB
|
||||
p.currentDSN = newDSN
|
||||
|
||||
// Close old connection after swap
|
||||
if oldDB != nil {
|
||||
if err := oldDB.Close(); err != nil {
|
||||
p.log.WithError(err).Warn("Failed to close old database connection")
|
||||
// Not fatal cause new connection is already active
|
||||
}
|
||||
}
|
||||
|
||||
p.log.Info("Successfully reconnected with new PostgreSQL credentials")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tryWithRefresh attempts an operation, and if it fails with an auth error, refreshes credentials and retries
|
||||
func (p *RefreshableConnPool) tryWithRefresh(ctx context.Context, op func() error) error {
|
||||
err := op()
|
||||
if err != nil && isAuthError(err) {
|
||||
p.log.WithError(err).Info("Authentication error detected, attempting to refresh credentials")
|
||||
if refreshErr := p.refreshCredentials(ctx); refreshErr == nil {
|
||||
// Retry the operation once after successful refresh
|
||||
p.log.Debug("Retrying operation after credential refresh")
|
||||
return op()
|
||||
} else {
|
||||
p.log.WithError(refreshErr).Warn("Failed to refresh credentials, returning original error")
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// PrepareContext implements gorm.ConnPool interface
|
||||
func (p *RefreshableConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||
var stmt *sql.Stmt
|
||||
err := p.tryWithRefresh(ctx, func() error {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
var err error
|
||||
stmt, err = p.db.PrepareContext(ctx, query)
|
||||
return err
|
||||
})
|
||||
return stmt, err
|
||||
}
|
||||
|
||||
// ExecContext implements gorm.ConnPool interface
|
||||
func (p *RefreshableConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
var result sql.Result
|
||||
err := p.tryWithRefresh(ctx, func() error {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
var err error
|
||||
result, err = p.db.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
// QueryContext implements gorm.ConnPool interface
|
||||
func (p *RefreshableConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
var rows *sql.Rows
|
||||
err := p.tryWithRefresh(ctx, func() error {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
var err error
|
||||
rows, err = p.db.QueryContext(ctx, query, args...)
|
||||
return err
|
||||
})
|
||||
return rows, err
|
||||
}
|
||||
|
||||
// QueryRowContext implements gorm.ConnPool interface
|
||||
func (p *RefreshableConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
// Note: sql.Row doesn't return errors until Scan() is called, so we can't detect auth errors here
|
||||
// The error will be caught in higher-level GORM operations
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.db.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// BeginTx implements gorm.TxBeginner and gorm.ConnPoolBeginner interfaces
|
||||
func (p *RefreshableConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
|
||||
var tx *sql.Tx
|
||||
err := p.tryWithRefresh(ctx, func() error {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
var err error
|
||||
tx, err = p.db.BeginTx(ctx, opts)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &refreshableTx{Tx: tx, pool: p}, nil
|
||||
}
|
||||
|
||||
// refreshableTx wraps sql.Tx to implement gorm.ConnPool
|
||||
type refreshableTx struct {
|
||||
*sql.Tx
|
||||
pool *RefreshableConnPool
|
||||
}
|
||||
|
||||
func (tx *refreshableTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||
return tx.Tx.PrepareContext(ctx, query)
|
||||
}
|
||||
|
||||
func (tx *refreshableTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
return tx.Tx.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (tx *refreshableTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
return tx.Tx.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (tx *refreshableTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
return tx.Tx.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection
|
||||
func (p *RefreshableConnPool) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.db != nil {
|
||||
return p.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping verifies the connection is alive
|
||||
func (p *RefreshableConnPool) Ping(ctx context.Context) error {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
// GetDB returns the underlying sql.DB for connection pool configuration
|
||||
func (p *RefreshableConnPool) GetDB() *sql.DB {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
// NewGORMDB creates a GORM DB instance using the refreshable connection pool
|
||||
func (p *RefreshableConnPool) NewGORMDB() (*gorm.DB, error) {
|
||||
dialector := postgres.New(postgres.Config{
|
||||
Conn: p,
|
||||
})
|
||||
return gorm.Open(dialector, p.gormConfig)
|
||||
}
|
||||
|
||||
// Ensure RefreshableConnPool implements required interfaces
|
||||
var (
|
||||
_ gorm.ConnPool = (*RefreshableConnPool)(nil)
|
||||
_ gorm.ConnPoolBeginner = (*RefreshableConnPool)(nil)
|
||||
_ driver.Pinger = (*RefreshableConnPool)(nil)
|
||||
)
|
||||
@@ -0,0 +1,415 @@
|
||||
package postgresstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"goauthentik.io/internal/config"
|
||||
)
|
||||
|
||||
func TestRefreshableConnPool_CredentialRefresh(t *testing.T) {
|
||||
// Skip if no PostgreSQL available
|
||||
if os.Getenv("AUTHENTIK_POSTGRESQL__HOST") == "" {
|
||||
t.Skip("Skipping test: no PostgreSQL configured")
|
||||
}
|
||||
|
||||
// Create a temporary file for password rotation
|
||||
tmpDir := t.TempDir()
|
||||
passwordFile := filepath.Join(tmpDir, "db_password")
|
||||
|
||||
// Write initial password
|
||||
initialPassword := os.Getenv("AUTHENTIK_POSTGRESQL__PASSWORD")
|
||||
if initialPassword == "" {
|
||||
initialPassword = "postgres"
|
||||
}
|
||||
err := os.WriteFile(passwordFile, []byte(initialPassword), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up config to use file:// URI for password
|
||||
originalPassword := os.Getenv("AUTHENTIK_POSTGRESQL__PASSWORD")
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "file://"+passwordFile))
|
||||
defer func() {
|
||||
if originalPassword != "" {
|
||||
_ = os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", originalPassword)
|
||||
} else {
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD")
|
||||
}
|
||||
}()
|
||||
|
||||
// Reload config
|
||||
cfg := &config.Config{}
|
||||
cfg.Setup()
|
||||
|
||||
// Build initial DSN
|
||||
dsn, err := BuildDSN(cfg.PostgreSQL)
|
||||
require.NoError(t, err)
|
||||
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
NowFunc: func() time.Time {
|
||||
return time.Now().UTC()
|
||||
},
|
||||
}
|
||||
|
||||
// Create refreshable connection pool
|
||||
pool, err := NewRefreshableConnPool(dsn, gormConfig, 10, 100, time.Hour)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
// Test initial connection works
|
||||
ctx := context.Background()
|
||||
err = pool.Ping(ctx)
|
||||
assert.NoError(t, err, "Initial connection should work")
|
||||
|
||||
// Create GORM DB
|
||||
db, err := pool.NewGORMDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Execute a test query
|
||||
var result int
|
||||
err = db.WithContext(ctx).Raw("SELECT 1").Scan(&result).Error
|
||||
assert.NoError(t, err, "Initial query should succeed")
|
||||
assert.Equal(t, 1, result)
|
||||
|
||||
// Simulate password change by writing to file
|
||||
// In real scenario, this would be an external process updating the file
|
||||
time.Sleep(100 * time.Millisecond) // Small delay to ensure file modification time changes
|
||||
err = os.WriteFile(passwordFile, []byte(initialPassword), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Execute another query - should trigger credential refresh check
|
||||
err = db.WithContext(ctx).Raw("SELECT 2").Scan(&result).Error
|
||||
assert.NoError(t, err, "Query after credential refresh should succeed")
|
||||
assert.Equal(t, 2, result)
|
||||
}
|
||||
|
||||
func TestRefreshableConnPool_Interfaces(t *testing.T) {
|
||||
// Verify that RefreshableConnPool implements required interfaces at compile time
|
||||
// This test will fail to compile if interfaces are not properly implemented
|
||||
var pool *RefreshableConnPool
|
||||
|
||||
// Test gorm.ConnPool interface
|
||||
var _ gorm.ConnPool = pool
|
||||
|
||||
// Test gorm.ConnPoolBeginner interface
|
||||
var _ gorm.ConnPoolBeginner = pool
|
||||
}
|
||||
|
||||
func TestRefreshableConnPool_ConcurrentAccess(t *testing.T) {
|
||||
// Skip if no PostgreSQL available
|
||||
if os.Getenv("AUTHENTIK_POSTGRESQL__HOST") == "" {
|
||||
t.Skip("Skipping test: no PostgreSQL configured")
|
||||
}
|
||||
|
||||
cfg := config.Get()
|
||||
dsn, err := BuildDSN(cfg.PostgreSQL)
|
||||
require.NoError(t, err)
|
||||
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
}
|
||||
|
||||
pool, err := NewRefreshableConnPool(dsn, gormConfig, 10, 100, time.Hour)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
db, err := pool.NewGORMDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test concurrent queries
|
||||
ctx := context.Background()
|
||||
numGoroutines := 10
|
||||
numQueries := 5
|
||||
|
||||
errChan := make(chan error, numGoroutines*numQueries)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(goroutineID int) {
|
||||
for j := 0; j < numQueries; j++ {
|
||||
var result int
|
||||
err := db.WithContext(ctx).Raw("SELECT ?", goroutineID*numQueries+j).Scan(&result).Error
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait a bit for goroutines to complete
|
||||
time.Sleep(2 * time.Second)
|
||||
close(errChan)
|
||||
|
||||
// Check for any errors
|
||||
for err := range errChan {
|
||||
assert.NoError(t, err, "Concurrent queries should succeed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshableConnPool_InvalidCredentials(t *testing.T) {
|
||||
// Create a pool with invalid credentials
|
||||
invalidDSN := "host=localhost port=5432 user=invalid password=invalid dbname=invalid sslmode=disable"
|
||||
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
}
|
||||
|
||||
pool, err := NewRefreshableConnPool(invalidDSN, gormConfig, 10, 100, time.Hour)
|
||||
if err != nil {
|
||||
// sql.Open may succeed even with invalid credentials (lazy connection)
|
||||
return
|
||||
}
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
// Ping should fail with invalid credentials
|
||||
ctx := context.Background()
|
||||
err = pool.Ping(ctx)
|
||||
assert.Error(t, err, "Ping with invalid credentials should fail")
|
||||
}
|
||||
|
||||
func TestConfig_RefreshPostgreSQLConfig_FileURI(t *testing.T) {
|
||||
// Create temporary files for testing file:// URIs
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
passwordFile := filepath.Join(tmpDir, "password")
|
||||
userFile := filepath.Join(tmpDir, "user")
|
||||
hostFile := filepath.Join(tmpDir, "host")
|
||||
|
||||
err := os.WriteFile(passwordFile, []byte("secret_password"), 0600)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(userFile, []byte("dbuser"), 0600)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(hostFile, []byte("db.example.com"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up environment variables with file:// URIs
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "file://"+passwordFile))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "file://"+userFile))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "file://"+hostFile))
|
||||
defer func() {
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD")
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER")
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST")
|
||||
}()
|
||||
|
||||
// Create and setup config
|
||||
cfg := &config.Config{}
|
||||
cfg.Setup()
|
||||
|
||||
// Test initial values are parsed correctly
|
||||
assert.Equal(t, "secret_password", cfg.PostgreSQL.Password, "Initial password should be parsed from file")
|
||||
assert.Equal(t, "dbuser", cfg.PostgreSQL.User, "Initial user should be parsed from file")
|
||||
assert.Equal(t, "db.example.com", cfg.PostgreSQL.Host, "Initial host should be parsed from file")
|
||||
|
||||
// Test RefreshPostgreSQLConfig returns same values initially
|
||||
refreshed := cfg.RefreshPostgreSQLConfig()
|
||||
assert.Equal(t, "secret_password", refreshed.Password)
|
||||
assert.Equal(t, "dbuser", refreshed.User)
|
||||
assert.Equal(t, "db.example.com", refreshed.Host)
|
||||
|
||||
// Update password file (simulating credential rotation)
|
||||
err = os.WriteFile(passwordFile, []byte("new_password"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update user file
|
||||
err = os.WriteFile(userFile, []byte("new_dbuser"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refresh should pick up new values from files
|
||||
refreshed = cfg.RefreshPostgreSQLConfig()
|
||||
assert.Equal(t, "new_password", refreshed.Password, "Password should be refreshed from file")
|
||||
assert.Equal(t, "new_dbuser", refreshed.User, "User should be refreshed from file")
|
||||
|
||||
// Original config struct should still have old values (not mutated)
|
||||
assert.Equal(t, "secret_password", cfg.PostgreSQL.Password, "Original config should not be mutated")
|
||||
}
|
||||
|
||||
func TestConfig_RefreshPostgreSQLConfig_EnvURI(t *testing.T) {
|
||||
// Test with env:// URIs (referencing other env vars)
|
||||
require.NoError(t, os.Setenv("DB_PASSWORD", "env_password"))
|
||||
require.NoError(t, os.Setenv("DB_USER", "env_user"))
|
||||
require.NoError(t, os.Setenv("DB_HOST", "env_host"))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "env://DB_PASSWORD"))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "env://DB_USER"))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "env://DB_HOST"))
|
||||
defer func() {
|
||||
_ = os.Unsetenv("DB_PASSWORD")
|
||||
_ = os.Unsetenv("DB_USER")
|
||||
_ = os.Unsetenv("DB_HOST")
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD")
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER")
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST")
|
||||
}()
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Setup()
|
||||
|
||||
// Test initial values are parsed correctly
|
||||
assert.Equal(t, "env_password", cfg.PostgreSQL.Password, "Initial password should be parsed from env")
|
||||
assert.Equal(t, "env_user", cfg.PostgreSQL.User, "Initial user should be parsed from env")
|
||||
assert.Equal(t, "env_host", cfg.PostgreSQL.Host, "Initial host should be parsed from env")
|
||||
|
||||
// Test RefreshPostgreSQLConfig
|
||||
refreshed := cfg.RefreshPostgreSQLConfig()
|
||||
assert.Equal(t, "env_password", refreshed.Password)
|
||||
assert.Equal(t, "env_user", refreshed.User)
|
||||
assert.Equal(t, "env_host", refreshed.Host)
|
||||
|
||||
// Change referenced environment variables (simulating credential rotation)
|
||||
require.NoError(t, os.Setenv("DB_PASSWORD", "new_env_password"))
|
||||
require.NoError(t, os.Setenv("DB_USER", "new_env_user"))
|
||||
|
||||
// Refresh should pick up new values
|
||||
refreshed = cfg.RefreshPostgreSQLConfig()
|
||||
assert.Equal(t, "new_env_password", refreshed.Password, "Password should be refreshed from env")
|
||||
assert.Equal(t, "new_env_user", refreshed.User, "User should be refreshed from env")
|
||||
|
||||
// Original config struct should still have old values (not mutated)
|
||||
assert.Equal(t, "env_password", cfg.PostgreSQL.Password, "Original config should not be mutated")
|
||||
}
|
||||
|
||||
func TestConfig_RefreshPostgreSQLConfig_PlainValues(t *testing.T) {
|
||||
// Test with plain values (no URI scheme)
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "plain_password"))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "plain_user"))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "localhost"))
|
||||
defer func() {
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD")
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER")
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST")
|
||||
}()
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Setup()
|
||||
|
||||
// Test initial values
|
||||
assert.Equal(t, "plain_password", cfg.PostgreSQL.Password)
|
||||
assert.Equal(t, "plain_user", cfg.PostgreSQL.User)
|
||||
assert.Equal(t, "localhost", cfg.PostgreSQL.Host)
|
||||
|
||||
// Test refresh returns same values
|
||||
refreshed := cfg.RefreshPostgreSQLConfig()
|
||||
assert.Equal(t, "plain_password", refreshed.Password)
|
||||
assert.Equal(t, "plain_user", refreshed.User)
|
||||
assert.Equal(t, "localhost", refreshed.Host)
|
||||
|
||||
// Change env vars
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "new_plain_password"))
|
||||
|
||||
// Refresh should pick up new plain value
|
||||
refreshed = cfg.RefreshPostgreSQLConfig()
|
||||
assert.Equal(t, "new_plain_password", refreshed.Password, "Plain password should be refreshed")
|
||||
}
|
||||
|
||||
func TestConfig_RefreshPostgreSQLConfig_MixedSources(t *testing.T) {
|
||||
// Test with mixed sources: file://, env://, and plain
|
||||
tmpDir := t.TempDir()
|
||||
passwordFile := filepath.Join(tmpDir, "password")
|
||||
err := os.WriteFile(passwordFile, []byte("file_password"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, os.Setenv("DB_USER_VAR", "env_user"))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__PASSWORD", "file://"+passwordFile))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__USER", "env://DB_USER_VAR"))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "plain_host"))
|
||||
defer func() {
|
||||
_ = os.Unsetenv("DB_USER_VAR")
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__PASSWORD")
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__USER")
|
||||
_ = os.Unsetenv("AUTHENTIK_POSTGRESQL__HOST")
|
||||
}()
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Setup()
|
||||
|
||||
// Test initial values
|
||||
assert.Equal(t, "file_password", cfg.PostgreSQL.Password)
|
||||
assert.Equal(t, "env_user", cfg.PostgreSQL.User)
|
||||
assert.Equal(t, "plain_host", cfg.PostgreSQL.Host)
|
||||
|
||||
// Update all sources
|
||||
err = os.WriteFile(passwordFile, []byte("new_file_password"), 0600)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.Setenv("DB_USER_VAR", "new_env_user"))
|
||||
require.NoError(t, os.Setenv("AUTHENTIK_POSTGRESQL__HOST", "new_plain_host"))
|
||||
|
||||
// Refresh should pick up all changes
|
||||
refreshed := cfg.RefreshPostgreSQLConfig()
|
||||
assert.Equal(t, "new_file_password", refreshed.Password, "File password should be refreshed")
|
||||
assert.Equal(t, "new_env_user", refreshed.User, "Env user should be refreshed")
|
||||
assert.Equal(t, "new_plain_host", refreshed.Host, "Plain host should be refreshed")
|
||||
}
|
||||
|
||||
func TestIsAuthError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "generic error",
|
||||
err: assert.AnError,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "postgres error code 28000 - invalid_authorization_specification",
|
||||
err: &pgconn.PgError{
|
||||
Code: "28000",
|
||||
Message: "invalid authorization specification",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "postgres error code 28P01 - invalid_password",
|
||||
err: &pgconn.PgError{
|
||||
Code: "28P01",
|
||||
Message: "password authentication failed for user",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "postgres error code 28P02 - invalid_password (deprecated)",
|
||||
err: &pgconn.PgError{
|
||||
Code: "28P02",
|
||||
Message: "invalid password",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "postgres error code 42P01 - undefined_table (not auth error)",
|
||||
err: &pgconn.PgError{
|
||||
Code: "42P01",
|
||||
Message: "relation does not exist",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "postgres error code 23505 - unique_violation (not auth error)",
|
||||
err: &pgconn.PgError{
|
||||
Code: "23505",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isAuthError(tt.err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
@@ -26,7 +25,8 @@ import (
|
||||
|
||||
// PostgresStore stores gorilla sessions in PostgreSQL using GORM
|
||||
type PostgresStore struct {
|
||||
db *gorm.DB
|
||||
db *gorm.DB
|
||||
pool *RefreshableConnPool // Keep reference to pool for cleanup
|
||||
// default options to use when a new session is created
|
||||
options sessions.Options
|
||||
// key prefix with which the session will be stored
|
||||
@@ -125,29 +125,32 @@ func NewPostgresStore() (*PostgresStore, error) {
|
||||
},
|
||||
}
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
|
||||
// Determine connection pool settings
|
||||
maxIdleConns := 10
|
||||
maxOpenConns := 100
|
||||
var connMaxLifetime time.Duration
|
||||
if cfg.ConnMaxAge > 0 {
|
||||
connMaxLifetime = time.Duration(cfg.ConnMaxAge) * time.Second
|
||||
} else {
|
||||
connMaxLifetime = time.Hour // Default 1 hour
|
||||
}
|
||||
|
||||
// Create refreshable connection pool
|
||||
pool, err := NewRefreshableConnPool(dsn, gormConfig, maxIdleConns, maxOpenConns, connMaxLifetime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
// Create GORM DB using the refreshable connection pool
|
||||
db, err := pool.NewGORMDB()
|
||||
if err != nil {
|
||||
_ = pool.Close()
|
||||
return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
|
||||
}
|
||||
|
||||
// Set connection pool settings
|
||||
sqlDB.SetMaxIdleConns(10)
|
||||
sqlDB.SetMaxOpenConns(100)
|
||||
|
||||
if cfg.ConnMaxAge > 0 {
|
||||
sqlDB.SetConnMaxLifetime(time.Duration(cfg.ConnMaxAge) * time.Second)
|
||||
} else {
|
||||
sqlDB.SetConnMaxLifetime(time.Hour) // Default 1 hour
|
||||
}
|
||||
|
||||
ps := &PostgresStore{
|
||||
db: db,
|
||||
db: db,
|
||||
pool: pool,
|
||||
options: sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: 86400 * 30, // 30 days default (but overwritten in postgresstore creation based on token validation)
|
||||
@@ -224,11 +227,10 @@ func (s *PostgresStore) KeyPrefix(keyPrefix string) {
|
||||
|
||||
// Close closes the PostgreSQL store
|
||||
func (s *PostgresStore) Close() error {
|
||||
sqlDB, err := s.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
if s.pool != nil {
|
||||
return s.pool.Close()
|
||||
}
|
||||
return sqlDB.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// save writes session to PostgreSQL
|
||||
|
||||
Reference in New Issue
Block a user