outpost/proxyv2: handle PostgreSQL passwords with spaces and special characters

And modify / add some more tests and a bit of refactoring
This commit is contained in:
Dominic R
2025-11-17 21:27:51 -05:00
parent 35329991ef
commit 86c0128dae
3 changed files with 451 additions and 79 deletions
@@ -21,7 +21,6 @@ import (
type RefreshableConnPool struct {
mu sync.RWMutex
db *sql.DB
dsnBuilder func(config.PostgreSQLConfig) (string, error)
log *log.Entry
currentDSN string
gormConfig *gorm.Config
@@ -49,7 +48,6 @@ func NewRefreshableConnPool(initialDSN string, gormConfig *gorm.Config, maxIdleC
pool := &RefreshableConnPool{
db: db,
dsnBuilder: BuildDSN,
log: log.WithField("logger", "authentik.outpost.proxyv2.postgresstore.connpool"),
currentDSN: initialDSN,
gormConfig: gormConfig,
@@ -86,7 +84,7 @@ func (p *RefreshableConnPool) refreshCredentials(ctx context.Context) error {
// Get fresh config
cfg := config.Get().RefreshPostgreSQLConfig()
newDSN, err := p.dsnBuilder(cfg)
newDSN, err := BuildDSN(cfg)
if err != nil {
p.log.WithError(err).Warn("Failed to build DSN with refreshed credentials")
return err
@@ -2,16 +2,20 @@ package postgresstore
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/stdlib"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus"
_ "gorm.io/driver/postgres"
@@ -49,60 +53,121 @@ func (ProxySession) TableName() string {
return "authentik_providers_proxy_proxysession"
}
// BuildDSN constructs a PostgreSQL connection string
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
// BuildConnConfig constructs a pgx.ConnConfig from PostgreSQL configuration.
func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
// Validate required fields
if cfg.Host == "" {
return "", fmt.Errorf("PostgreSQL host is required")
return nil, fmt.Errorf("PostgreSQL host is required")
}
if cfg.User == "" {
return "", fmt.Errorf("PostgreSQL user is required")
return nil, fmt.Errorf("PostgreSQL user is required")
}
if cfg.Name == "" {
return "", fmt.Errorf("PostgreSQL database name is required")
return nil, fmt.Errorf("PostgreSQL database name is required")
}
if cfg.Port <= 0 {
return "", fmt.Errorf("PostgreSQL port must be positive")
return nil, fmt.Errorf("PostgreSQL port must be positive")
}
// Build DSN string with all parameters
dsnParts := []string{
"host=" + cfg.Host,
fmt.Sprintf("port=%d", cfg.Port),
"user=" + cfg.User,
"dbname=" + cfg.Name,
// Start with a default config
connConfig, err := pgx.ParseConfig("")
if err != nil {
return nil, fmt.Errorf("failed to create default config: %w", err)
}
if cfg.Password != "" {
dsnParts = append(dsnParts, "password="+cfg.Password)
}
// Set connection parameters
connConfig.Host = cfg.Host
connConfig.Port = uint16(cfg.Port)
connConfig.User = cfg.User
connConfig.Password = cfg.Password
connConfig.Database = cfg.Name
// Add SSL mode
// Configure TLS/SSL
if cfg.SSLMode != "" {
dsnParts = append(dsnParts, "sslmode="+cfg.SSLMode)
switch cfg.SSLMode {
case "disable":
connConfig.TLSConfig = nil
case "require", "verify-ca", "verify-full":
tlsConfig := &tls.Config{}
// Load root CA certificate if provided
if cfg.SSLRootCert != "" {
caCert, err := os.ReadFile(cfg.SSLRootCert)
if err != nil {
return nil, fmt.Errorf("failed to read SSL root certificate: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("failed to parse SSL root certificate")
}
tlsConfig.RootCAs = caCertPool
}
// Load client certificate and key if provided
if cfg.SSLCert != "" && cfg.SSLKey != "" {
cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey)
if err != nil {
return nil, fmt.Errorf("failed to load SSL client certificate: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
// Set verification mode
switch cfg.SSLMode {
case "require":
// Don't verify the server certificate (just encrypt)
tlsConfig.InsecureSkipVerify = true
case "verify-ca":
// Verify the certificate is signed by a trusted CA
tlsConfig.InsecureSkipVerify = false
case "verify-full":
// Verify the certificate and hostname
tlsConfig.InsecureSkipVerify = false
tlsConfig.ServerName = cfg.Host
}
connConfig.TLSConfig = tlsConfig
}
}
// Add SSL certificates if provided
if cfg.SSLRootCert != "" {
dsnParts = append(dsnParts, "sslrootcert="+cfg.SSLRootCert)
}
if cfg.SSLCert != "" {
dsnParts = append(dsnParts, "sslcert="+cfg.SSLCert)
}
if cfg.SSLKey != "" {
dsnParts = append(dsnParts, "sslkey="+cfg.SSLKey)
// Set runtime params
if connConfig.RuntimeParams == nil {
connConfig.RuntimeParams = make(map[string]string)
}
if cfg.DefaultSchema != "" {
dsnParts = append(dsnParts, "search_path="+cfg.DefaultSchema)
connConfig.RuntimeParams["search_path"] = cfg.DefaultSchema
}
// Add connection options if specified
// Parse and apply connection options if specified
if cfg.ConnOptions != "" {
dsnParts = append(dsnParts, cfg.ConnOptions)
// Parse key=value pairs from ConnOptions
// Format: "key1=value1 key2=value2"
pairs := strings.Split(cfg.ConnOptions, " ")
for _, pair := range pairs {
if pair == "" {
continue
}
kv := strings.SplitN(pair, "=", 2)
if len(kv) == 2 {
connConfig.RuntimeParams[kv[0]] = kv[1]
}
}
}
// Join parts with spaces
return strings.Join(dsnParts, " "), nil
return connConfig, nil
}
// BuildDSN constructs a PostgreSQL connection string from a ConnConfig.
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
connConfig, err := BuildConnConfig(cfg)
if err != nil {
return "", err
}
// Register the config and get a connection string
// (This approach lets pgx handle all the escaping internally which is quite convenient for say spaces in the password)
return stdlib.RegisterConnConfig(connConfig), nil
}
// SetupGORMWithRefreshablePool creates a GORM DB with a refreshable connection pool.
@@ -2,14 +2,23 @@ package postgresstore
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
@@ -541,11 +550,11 @@ func TestBuildDSN_Validation(t *testing.T) {
}
}
func TestBuildDSN(t *testing.T) {
func TestBuildConnConfig(t *testing.T) {
tests := []struct {
name string
cfg config.PostgreSQLConfig
expected string
validate func(*testing.T, *pgx.ConnConfig)
}{
{
name: "basic configuration",
@@ -555,10 +564,16 @@ func TestBuildDSN(t *testing.T) {
User: "testuser",
Name: "testdb",
},
expected: "host=localhost port=5432 user=testuser dbname=testdb",
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, "localhost", cc.Host)
assert.Equal(t, uint16(5432), cc.Port)
assert.Equal(t, "testuser", cc.User)
assert.Equal(t, "testdb", cc.Database)
assert.Equal(t, "", cc.Password)
},
},
{
name: "with password",
name: "with simple password",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
@@ -566,7 +581,87 @@ func TestBuildDSN(t *testing.T) {
Password: "testpass",
Name: "testdb",
},
expected: "host=localhost port=5432 user=testuser dbname=testdb password=testpass",
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, "testpass", cc.Password)
},
},
{
name: "with password containing spaces",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Password: "my secure password",
Name: "testdb",
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, "my secure password", cc.Password)
},
},
{
name: "with password containing single quotes",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Password: "pass'word",
Name: "testdb",
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, "pass'word", cc.Password)
},
},
{
name: "with password containing backslashes",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Password: `pass\word`,
Name: "testdb",
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, `pass\word`, cc.Password)
},
},
{
name: "with password containing special characters",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Password: `p@ss w0rd!#$%^&*()`,
Name: "testdb",
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, `p@ss w0rd!#$%^&*()`, cc.Password)
},
},
{
name: "with password containing quotes and backslashes",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Password: `my'pass\word"here`,
Name: "testdb",
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, `my'pass\word"here`, cc.Password)
},
},
{
name: "with passphrase (multiple spaces)",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Password: "the quick brown fox jumps over",
Name: "testdb",
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, "the quick brown fox jumps over", cc.Password)
},
},
{
name: "with sslmode=disable",
@@ -577,10 +672,12 @@ func TestBuildDSN(t *testing.T) {
Name: "testdb",
SSLMode: "disable",
},
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=disable",
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Nil(t, cc.TLSConfig)
},
},
{
name: "with sslmode=require",
name: "with sslmode=require (no certs)",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
@@ -588,32 +685,10 @@ func TestBuildDSN(t *testing.T) {
Name: "testdb",
SSLMode: "require",
},
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=require",
},
{
name: "with sslmode=prefer",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Name: "testdb",
SSLMode: "prefer",
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.NotNil(t, cc.TLSConfig)
assert.True(t, cc.TLSConfig.InsecureSkipVerify)
},
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=prefer",
},
{
name: "with SSL certificates",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Name: "testdb",
SSLMode: "verify-full",
SSLRootCert: "/path/to/root.crt",
SSLCert: "/path/to/client.crt",
SSLKey: "/path/to/client.key",
},
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=verify-full sslrootcert=/path/to/root.crt sslcert=/path/to/client.crt sslkey=/path/to/client.key",
},
{
name: "with custom schema",
@@ -624,7 +699,9 @@ func TestBuildDSN(t *testing.T) {
Name: "testdb",
DefaultSchema: "custom_schema",
},
expected: "host=localhost port=5432 user=testuser dbname=testdb search_path=custom_schema",
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, "custom_schema", cc.RuntimeParams["search_path"])
},
},
{
name: "with connection options",
@@ -633,34 +710,192 @@ func TestBuildDSN(t *testing.T) {
Port: 5432,
User: "testuser",
Name: "testdb",
ConnOptions: "connect_timeout=10",
ConnOptions: "connect_timeout=10 application_name=authentik",
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, "10", cc.RuntimeParams["connect_timeout"])
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
},
expected: "host=localhost port=5432 user=testuser dbname=testdb connect_timeout=10",
},
{
name: "full configuration",
name: "full configuration with special password",
cfg: config.PostgreSQLConfig{
Host: "db.example.com",
Port: 5433,
User: "admin",
Password: "secret",
Password: "my super secret password!@#",
Name: "production",
SSLMode: "verify-full",
SSLRootCert: "/certs/root.crt",
SSLCert: "/certs/client.crt",
SSLKey: "/certs/client.key",
SSLMode: "require",
DefaultSchema: "app_schema",
ConnOptions: "application_name=authentik",
},
expected: "host=db.example.com port=5433 user=admin dbname=production password=secret sslmode=verify-full sslrootcert=/certs/root.crt sslcert=/certs/client.crt sslkey=/certs/client.key search_path=app_schema application_name=authentik",
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, "db.example.com", cc.Host)
assert.Equal(t, uint16(5433), cc.Port)
assert.Equal(t, "admin", cc.User)
assert.Equal(t, "my super secret password!@#", cc.Password)
assert.Equal(t, "production", cc.Database)
assert.Equal(t, "app_schema", cc.RuntimeParams["search_path"])
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := BuildDSN(tt.cfg)
result, err := BuildConnConfig(tt.cfg)
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
require.NotNil(t, result)
tt.validate(t, result)
})
}
}
// TestBuildConnConfig_WithSSLCertificates tests SSL certificate configuration
func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
rootCertPath, clientCertPath, clientKeyPath, cleanup := generateTestCerts(t)
defer cleanup()
tests := []struct {
name string
cfg config.PostgreSQLConfig
validate func(*testing.T, *pgx.ConnConfig)
}{
{
name: "verify-full with all certificates",
cfg: config.PostgreSQLConfig{
Host: "db.example.com",
Port: 5432,
User: "testuser",
Password: "my secure password",
Name: "testdb",
SSLMode: "verify-full",
SSLRootCert: rootCertPath,
SSLCert: clientCertPath,
SSLKey: clientKeyPath,
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
require.NotNil(t, cc.TLSConfig)
assert.False(t, cc.TLSConfig.InsecureSkipVerify)
assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName)
assert.NotNil(t, cc.TLSConfig.RootCAs)
assert.Len(t, cc.TLSConfig.Certificates, 1)
},
},
{
name: "verify-ca with root cert only",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Name: "testdb",
SSLMode: "verify-ca",
SSLRootCert: rootCertPath,
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
require.NotNil(t, cc.TLSConfig)
assert.False(t, cc.TLSConfig.InsecureSkipVerify)
assert.NotNil(t, cc.TLSConfig.RootCAs)
assert.Empty(t, cc.TLSConfig.Certificates)
},
},
{
name: "require with client cert",
cfg: config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Name: "testdb",
SSLMode: "require",
SSLCert: clientCertPath,
SSLKey: clientKeyPath,
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
require.NotNil(t, cc.TLSConfig)
assert.True(t, cc.TLSConfig.InsecureSkipVerify)
assert.Len(t, cc.TLSConfig.Certificates, 1)
},
},
{
name: "full configuration with SSL and special password",
cfg: config.PostgreSQLConfig{
Host: "db.example.com",
Port: 5433,
User: "admin",
Password: "my super secret password!@#",
Name: "production",
SSLMode: "verify-full",
SSLRootCert: rootCertPath,
SSLCert: clientCertPath,
SSLKey: clientKeyPath,
DefaultSchema: "app_schema",
ConnOptions: "application_name=authentik",
},
validate: func(t *testing.T, cc *pgx.ConnConfig) {
assert.Equal(t, "db.example.com", cc.Host)
assert.Equal(t, uint16(5433), cc.Port)
assert.Equal(t, "admin", cc.User)
assert.Equal(t, "my super secret password!@#", cc.Password)
assert.Equal(t, "production", cc.Database)
require.NotNil(t, cc.TLSConfig)
assert.False(t, cc.TLSConfig.InsecureSkipVerify)
assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName)
assert.NotNil(t, cc.TLSConfig.RootCAs)
assert.Len(t, cc.TLSConfig.Certificates, 1)
assert.Equal(t, "app_schema", cc.RuntimeParams["search_path"])
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := BuildConnConfig(tt.cfg)
require.NoError(t, err)
require.NotNil(t, result)
tt.validate(t, result)
})
}
}
// TestBuildDSN_WithSpecialPasswords tests that BuildDSN can handle passwords with special characters
// by verifying the DSN can actually be used to connect to a database
func TestBuildDSN_WithSpecialPasswords(t *testing.T) {
tests := []struct {
name string
password string
}{
{"space in password", "my password"},
{"multiple spaces", "the quick brown fox"},
{"single quote", "pass'word"},
{"backslash", `pass\word`},
{"double quote", `pass"word`},
{"special chars", `p@ss!#$%^&*()`},
{"mixed special", `my'pass\word"here`},
{"unicode", "pässwörd"},
{"leading/trailing spaces", " password "},
{"tab character", "pass\tword"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := config.PostgreSQLConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Password: tt.password,
Name: "testdb",
}
// Test that BuildDSN doesn't error
dsn, err := BuildDSN(cfg)
require.NoError(t, err)
require.NotEmpty(t, dsn)
// Test that BuildConnConfig preserves the password exactly
connConfig, err := BuildConnConfig(cfg)
require.NoError(t, err)
assert.Equal(t, tt.password, connConfig.Password, "Password should be preserved exactly")
})
}
}
@@ -715,3 +950,77 @@ func createSessionData(t *testing.T, claims map[string]interface{}) string {
require.NoError(t, err)
return string(sessionDataJSON)
}
// generateTestCerts creates temporary SSL certificates for testing
func generateTestCerts(t *testing.T) (rootCertPath, clientCertPath, clientKeyPath string, cleanup func()) {
tmpDir := t.TempDir()
// Generate CA certificate
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
require.NoError(t, err)
// Write CA certificate
rootCertPath = filepath.Join(tmpDir, "root.crt")
rootCertFile, err := os.Create(rootCertPath)
require.NoError(t, err)
defer rootCertFile.Close()
err = pem.Encode(rootCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})
require.NoError(t, err)
// Generate client key
clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
// Generate client certificate
clientTemplate := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Client"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
clientCertDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey)
require.NoError(t, err)
// Write client certificate
clientCertPath = filepath.Join(tmpDir, "client.crt")
clientCertFile, err := os.Create(clientCertPath)
require.NoError(t, err)
defer clientCertFile.Close()
err = pem.Encode(clientCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER})
require.NoError(t, err)
// Write client key
clientKeyPath = filepath.Join(tmpDir, "client.key")
clientKeyFile, err := os.Create(clientKeyPath)
require.NoError(t, err)
defer clientKeyFile.Close()
clientKeyBytes := x509.MarshalPKCS1PrivateKey(clientKey)
err = pem.Encode(clientKeyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: clientKeyBytes})
require.NoError(t, err)
cleanup = func() {
// TempDir cleanup is automatic in Go tests
}
return rootCertPath, clientCertPath, clientKeyPath, cleanup
}