mirror of
https://github.com/goauthentik/authentik.git
synced 2026-06-17 19:09:11 +03:00
outpost/proxyv2: handle PostgreSQL passwords with spaces and special characters
And modify / add some more tests and a bit of refactoring
This commit is contained in:
@@ -21,7 +21,6 @@ import (
|
||||
type RefreshableConnPool struct {
|
||||
mu sync.RWMutex
|
||||
db *sql.DB
|
||||
dsnBuilder func(config.PostgreSQLConfig) (string, error)
|
||||
log *log.Entry
|
||||
currentDSN string
|
||||
gormConfig *gorm.Config
|
||||
@@ -49,7 +48,6 @@ func NewRefreshableConnPool(initialDSN string, gormConfig *gorm.Config, maxIdleC
|
||||
|
||||
pool := &RefreshableConnPool{
|
||||
db: db,
|
||||
dsnBuilder: BuildDSN,
|
||||
log: log.WithField("logger", "authentik.outpost.proxyv2.postgresstore.connpool"),
|
||||
currentDSN: initialDSN,
|
||||
gormConfig: gormConfig,
|
||||
@@ -86,7 +84,7 @@ func (p *RefreshableConnPool) refreshCredentials(ctx context.Context) error {
|
||||
|
||||
// Get fresh config
|
||||
cfg := config.Get().RefreshPostgreSQLConfig()
|
||||
newDSN, err := p.dsnBuilder(cfg)
|
||||
newDSN, err := BuildDSN(cfg)
|
||||
if err != nil {
|
||||
p.log.WithError(err).Warn("Failed to build DSN with refreshed credentials")
|
||||
return err
|
||||
|
||||
@@ -2,16 +2,20 @@ package postgresstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
log "github.com/sirupsen/logrus"
|
||||
_ "gorm.io/driver/postgres"
|
||||
@@ -49,60 +53,121 @@ func (ProxySession) TableName() string {
|
||||
return "authentik_providers_proxy_proxysession"
|
||||
}
|
||||
|
||||
// BuildDSN constructs a PostgreSQL connection string
|
||||
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
|
||||
// BuildConnConfig constructs a pgx.ConnConfig from PostgreSQL configuration.
|
||||
func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
// Validate required fields
|
||||
if cfg.Host == "" {
|
||||
return "", fmt.Errorf("PostgreSQL host is required")
|
||||
return nil, fmt.Errorf("PostgreSQL host is required")
|
||||
}
|
||||
if cfg.User == "" {
|
||||
return "", fmt.Errorf("PostgreSQL user is required")
|
||||
return nil, fmt.Errorf("PostgreSQL user is required")
|
||||
}
|
||||
if cfg.Name == "" {
|
||||
return "", fmt.Errorf("PostgreSQL database name is required")
|
||||
return nil, fmt.Errorf("PostgreSQL database name is required")
|
||||
}
|
||||
if cfg.Port <= 0 {
|
||||
return "", fmt.Errorf("PostgreSQL port must be positive")
|
||||
return nil, fmt.Errorf("PostgreSQL port must be positive")
|
||||
}
|
||||
|
||||
// Build DSN string with all parameters
|
||||
dsnParts := []string{
|
||||
"host=" + cfg.Host,
|
||||
fmt.Sprintf("port=%d", cfg.Port),
|
||||
"user=" + cfg.User,
|
||||
"dbname=" + cfg.Name,
|
||||
// Start with a default config
|
||||
connConfig, err := pgx.ParseConfig("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create default config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Password != "" {
|
||||
dsnParts = append(dsnParts, "password="+cfg.Password)
|
||||
}
|
||||
// Set connection parameters
|
||||
connConfig.Host = cfg.Host
|
||||
connConfig.Port = uint16(cfg.Port)
|
||||
connConfig.User = cfg.User
|
||||
connConfig.Password = cfg.Password
|
||||
connConfig.Database = cfg.Name
|
||||
|
||||
// Add SSL mode
|
||||
// Configure TLS/SSL
|
||||
if cfg.SSLMode != "" {
|
||||
dsnParts = append(dsnParts, "sslmode="+cfg.SSLMode)
|
||||
switch cfg.SSLMode {
|
||||
case "disable":
|
||||
connConfig.TLSConfig = nil
|
||||
case "require", "verify-ca", "verify-full":
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
// Load root CA certificate if provided
|
||||
if cfg.SSLRootCert != "" {
|
||||
caCert, err := os.ReadFile(cfg.SSLRootCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read SSL root certificate: %w", err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse SSL root certificate")
|
||||
}
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
}
|
||||
|
||||
// Load client certificate and key if provided
|
||||
if cfg.SSLCert != "" && cfg.SSLKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load SSL client certificate: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
// Set verification mode
|
||||
switch cfg.SSLMode {
|
||||
case "require":
|
||||
// Don't verify the server certificate (just encrypt)
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
case "verify-ca":
|
||||
// Verify the certificate is signed by a trusted CA
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
case "verify-full":
|
||||
// Verify the certificate and hostname
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
tlsConfig.ServerName = cfg.Host
|
||||
}
|
||||
|
||||
connConfig.TLSConfig = tlsConfig
|
||||
}
|
||||
}
|
||||
|
||||
// Add SSL certificates if provided
|
||||
if cfg.SSLRootCert != "" {
|
||||
dsnParts = append(dsnParts, "sslrootcert="+cfg.SSLRootCert)
|
||||
}
|
||||
if cfg.SSLCert != "" {
|
||||
dsnParts = append(dsnParts, "sslcert="+cfg.SSLCert)
|
||||
}
|
||||
if cfg.SSLKey != "" {
|
||||
dsnParts = append(dsnParts, "sslkey="+cfg.SSLKey)
|
||||
// Set runtime params
|
||||
if connConfig.RuntimeParams == nil {
|
||||
connConfig.RuntimeParams = make(map[string]string)
|
||||
}
|
||||
|
||||
if cfg.DefaultSchema != "" {
|
||||
dsnParts = append(dsnParts, "search_path="+cfg.DefaultSchema)
|
||||
connConfig.RuntimeParams["search_path"] = cfg.DefaultSchema
|
||||
}
|
||||
|
||||
// Add connection options if specified
|
||||
// Parse and apply connection options if specified
|
||||
if cfg.ConnOptions != "" {
|
||||
dsnParts = append(dsnParts, cfg.ConnOptions)
|
||||
// Parse key=value pairs from ConnOptions
|
||||
// Format: "key1=value1 key2=value2"
|
||||
pairs := strings.Split(cfg.ConnOptions, " ")
|
||||
for _, pair := range pairs {
|
||||
if pair == "" {
|
||||
continue
|
||||
}
|
||||
kv := strings.SplitN(pair, "=", 2)
|
||||
if len(kv) == 2 {
|
||||
connConfig.RuntimeParams[kv[0]] = kv[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Join parts with spaces
|
||||
return strings.Join(dsnParts, " "), nil
|
||||
return connConfig, nil
|
||||
}
|
||||
|
||||
// BuildDSN constructs a PostgreSQL connection string from a ConnConfig.
|
||||
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
|
||||
connConfig, err := BuildConnConfig(cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Register the config and get a connection string
|
||||
// (This approach lets pgx handle all the escaping internally which is quite convenient for say spaces in the password)
|
||||
return stdlib.RegisterConnConfig(connConfig), nil
|
||||
}
|
||||
|
||||
// SetupGORMWithRefreshablePool creates a GORM DB with a refreshable connection pool.
|
||||
|
||||
@@ -2,14 +2,23 @@ package postgresstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
@@ -541,11 +550,11 @@ func TestBuildDSN_Validation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDSN(t *testing.T) {
|
||||
func TestBuildConnConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.PostgreSQLConfig
|
||||
expected string
|
||||
validate func(*testing.T, *pgx.ConnConfig)
|
||||
}{
|
||||
{
|
||||
name: "basic configuration",
|
||||
@@ -555,10 +564,16 @@ func TestBuildDSN(t *testing.T) {
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "localhost", cc.Host)
|
||||
assert.Equal(t, uint16(5432), cc.Port)
|
||||
assert.Equal(t, "testuser", cc.User)
|
||||
assert.Equal(t, "testdb", cc.Database)
|
||||
assert.Equal(t, "", cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password",
|
||||
name: "with simple password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
@@ -566,7 +581,87 @@ func TestBuildDSN(t *testing.T) {
|
||||
Password: "testpass",
|
||||
Name: "testdb",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb password=testpass",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "testpass", cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password containing spaces",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: "my secure password",
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "my secure password", cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password containing single quotes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: "pass'word",
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "pass'word", cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password containing backslashes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: `pass\word`,
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, `pass\word`, cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password containing special characters",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: `p@ss w0rd!#$%^&*()`,
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, `p@ss w0rd!#$%^&*()`, cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with password containing quotes and backslashes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: `my'pass\word"here`,
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, `my'pass\word"here`, cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with passphrase (multiple spaces)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: "the quick brown fox jumps over",
|
||||
Name: "testdb",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "the quick brown fox jumps over", cc.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with sslmode=disable",
|
||||
@@ -577,10 +672,12 @@ func TestBuildDSN(t *testing.T) {
|
||||
Name: "testdb",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=disable",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Nil(t, cc.TLSConfig)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with sslmode=require",
|
||||
name: "with sslmode=require (no certs)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
@@ -588,32 +685,10 @@ func TestBuildDSN(t *testing.T) {
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=require",
|
||||
},
|
||||
{
|
||||
name: "with sslmode=prefer",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "prefer",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.NotNil(t, cc.TLSConfig)
|
||||
assert.True(t, cc.TLSConfig.InsecureSkipVerify)
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=prefer",
|
||||
},
|
||||
{
|
||||
name: "with SSL certificates",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "verify-full",
|
||||
SSLRootCert: "/path/to/root.crt",
|
||||
SSLCert: "/path/to/client.crt",
|
||||
SSLKey: "/path/to/client.key",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=verify-full sslrootcert=/path/to/root.crt sslcert=/path/to/client.crt sslkey=/path/to/client.key",
|
||||
},
|
||||
{
|
||||
name: "with custom schema",
|
||||
@@ -624,7 +699,9 @@ func TestBuildDSN(t *testing.T) {
|
||||
Name: "testdb",
|
||||
DefaultSchema: "custom_schema",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb search_path=custom_schema",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "custom_schema", cc.RuntimeParams["search_path"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with connection options",
|
||||
@@ -633,34 +710,192 @@ func TestBuildDSN(t *testing.T) {
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
ConnOptions: "connect_timeout=10",
|
||||
ConnOptions: "connect_timeout=10 application_name=authentik",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "10", cc.RuntimeParams["connect_timeout"])
|
||||
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
expected: "host=localhost port=5432 user=testuser dbname=testdb connect_timeout=10",
|
||||
},
|
||||
{
|
||||
name: "full configuration",
|
||||
name: "full configuration with special password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5433,
|
||||
User: "admin",
|
||||
Password: "secret",
|
||||
Password: "my super secret password!@#",
|
||||
Name: "production",
|
||||
SSLMode: "verify-full",
|
||||
SSLRootCert: "/certs/root.crt",
|
||||
SSLCert: "/certs/client.crt",
|
||||
SSLKey: "/certs/client.key",
|
||||
SSLMode: "require",
|
||||
DefaultSchema: "app_schema",
|
||||
ConnOptions: "application_name=authentik",
|
||||
},
|
||||
expected: "host=db.example.com port=5433 user=admin dbname=production password=secret sslmode=verify-full sslrootcert=/certs/root.crt sslcert=/certs/client.crt sslkey=/certs/client.key search_path=app_schema application_name=authentik",
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "db.example.com", cc.Host)
|
||||
assert.Equal(t, uint16(5433), cc.Port)
|
||||
assert.Equal(t, "admin", cc.User)
|
||||
assert.Equal(t, "my super secret password!@#", cc.Password)
|
||||
assert.Equal(t, "production", cc.Database)
|
||||
assert.Equal(t, "app_schema", cc.RuntimeParams["search_path"])
|
||||
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := BuildDSN(tt.cfg)
|
||||
result, err := BuildConnConfig(tt.cfg)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
require.NotNil(t, result)
|
||||
tt.validate(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_WithSSLCertificates tests SSL certificate configuration
|
||||
func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
rootCertPath, clientCertPath, clientKeyPath, cleanup := generateTestCerts(t)
|
||||
defer cleanup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.PostgreSQLConfig
|
||||
validate func(*testing.T, *pgx.ConnConfig)
|
||||
}{
|
||||
{
|
||||
name: "verify-full with all certificates",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: "my secure password",
|
||||
Name: "testdb",
|
||||
SSLMode: "verify-full",
|
||||
SSLRootCert: rootCertPath,
|
||||
SSLCert: clientCertPath,
|
||||
SSLKey: clientKeyPath,
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
require.NotNil(t, cc.TLSConfig)
|
||||
assert.False(t, cc.TLSConfig.InsecureSkipVerify)
|
||||
assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName)
|
||||
assert.NotNil(t, cc.TLSConfig.RootCAs)
|
||||
assert.Len(t, cc.TLSConfig.Certificates, 1)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "verify-ca with root cert only",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "verify-ca",
|
||||
SSLRootCert: rootCertPath,
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
require.NotNil(t, cc.TLSConfig)
|
||||
assert.False(t, cc.TLSConfig.InsecureSkipVerify)
|
||||
assert.NotNil(t, cc.TLSConfig.RootCAs)
|
||||
assert.Empty(t, cc.TLSConfig.Certificates)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "require with client cert",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
SSLCert: clientCertPath,
|
||||
SSLKey: clientKeyPath,
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
require.NotNil(t, cc.TLSConfig)
|
||||
assert.True(t, cc.TLSConfig.InsecureSkipVerify)
|
||||
assert.Len(t, cc.TLSConfig.Certificates, 1)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full configuration with SSL and special password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5433,
|
||||
User: "admin",
|
||||
Password: "my super secret password!@#",
|
||||
Name: "production",
|
||||
SSLMode: "verify-full",
|
||||
SSLRootCert: rootCertPath,
|
||||
SSLCert: clientCertPath,
|
||||
SSLKey: clientKeyPath,
|
||||
DefaultSchema: "app_schema",
|
||||
ConnOptions: "application_name=authentik",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "db.example.com", cc.Host)
|
||||
assert.Equal(t, uint16(5433), cc.Port)
|
||||
assert.Equal(t, "admin", cc.User)
|
||||
assert.Equal(t, "my super secret password!@#", cc.Password)
|
||||
assert.Equal(t, "production", cc.Database)
|
||||
require.NotNil(t, cc.TLSConfig)
|
||||
assert.False(t, cc.TLSConfig.InsecureSkipVerify)
|
||||
assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName)
|
||||
assert.NotNil(t, cc.TLSConfig.RootCAs)
|
||||
assert.Len(t, cc.TLSConfig.Certificates, 1)
|
||||
assert.Equal(t, "app_schema", cc.RuntimeParams["search_path"])
|
||||
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := BuildConnConfig(tt.cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
tt.validate(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildDSN_WithSpecialPasswords tests that BuildDSN can handle passwords with special characters
|
||||
// by verifying the DSN can actually be used to connect to a database
|
||||
func TestBuildDSN_WithSpecialPasswords(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
}{
|
||||
{"space in password", "my password"},
|
||||
{"multiple spaces", "the quick brown fox"},
|
||||
{"single quote", "pass'word"},
|
||||
{"backslash", `pass\word`},
|
||||
{"double quote", `pass"word`},
|
||||
{"special chars", `p@ss!#$%^&*()`},
|
||||
{"mixed special", `my'pass\word"here`},
|
||||
{"unicode", "pässwörd"},
|
||||
{"leading/trailing spaces", " password "},
|
||||
{"tab character", "pass\tword"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: tt.password,
|
||||
Name: "testdb",
|
||||
}
|
||||
|
||||
// Test that BuildDSN doesn't error
|
||||
dsn, err := BuildDSN(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, dsn)
|
||||
|
||||
// Test that BuildConnConfig preserves the password exactly
|
||||
connConfig, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.password, connConfig.Password, "Password should be preserved exactly")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -715,3 +950,77 @@ func createSessionData(t *testing.T, claims map[string]interface{}) string {
|
||||
require.NoError(t, err)
|
||||
return string(sessionDataJSON)
|
||||
}
|
||||
|
||||
// generateTestCerts creates temporary SSL certificates for testing
|
||||
func generateTestCerts(t *testing.T) (rootCertPath, clientCertPath, clientKeyPath string, cleanup func()) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Generate CA certificate
|
||||
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test CA"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write CA certificate
|
||||
rootCertPath = filepath.Join(tmpDir, "root.crt")
|
||||
rootCertFile, err := os.Create(rootCertPath)
|
||||
require.NoError(t, err)
|
||||
defer rootCertFile.Close()
|
||||
err = pem.Encode(rootCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key
|
||||
clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client certificate
|
||||
clientTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Client"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
}
|
||||
|
||||
clientCertDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write client certificate
|
||||
clientCertPath = filepath.Join(tmpDir, "client.crt")
|
||||
clientCertFile, err := os.Create(clientCertPath)
|
||||
require.NoError(t, err)
|
||||
defer clientCertFile.Close()
|
||||
err = pem.Encode(clientCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write client key
|
||||
clientKeyPath = filepath.Join(tmpDir, "client.key")
|
||||
clientKeyFile, err := os.Create(clientKeyPath)
|
||||
require.NoError(t, err)
|
||||
defer clientKeyFile.Close()
|
||||
clientKeyBytes := x509.MarshalPKCS1PrivateKey(clientKey)
|
||||
err = pem.Encode(clientKeyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: clientKeyBytes})
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup = func() {
|
||||
// TempDir cleanup is automatic in Go tests
|
||||
}
|
||||
|
||||
return rootCertPath, clientCertPath, clientKeyPath, cleanup
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user