mirror of
https://github.com/goauthentik/authentik.git
synced 2026-06-17 19:09:11 +03:00
internal/outpost: improve PostgreSQL connection options parsing (#19118)
* internal: Outpost's conn options should be base64 json * correctly parse target_session_attrs + tests * fix port handling to use env provided port * add multiple port handling abilities to mirror the python config parser --------- Co-authored-by: Duncan Tasker <tasatree@gmail.com>
This commit is contained in:
@@ -204,6 +204,7 @@ func (c *Config) RefreshPostgreSQLConfig() PostgreSQLConfig {
|
||||
// Map of environment variable suffix to config field pointer
|
||||
envVars := map[string]*string{
|
||||
"HOST": &refreshed.Host,
|
||||
"PORT": &refreshed.Port,
|
||||
"USER": &refreshed.User,
|
||||
"PASSWORD": &refreshed.Password,
|
||||
"NAME": &refreshed.Name,
|
||||
|
||||
@@ -27,7 +27,7 @@ type Config struct {
|
||||
|
||||
type PostgreSQLConfig struct {
|
||||
Host string `yaml:"host" env:"HOST, overwrite"`
|
||||
Port int `yaml:"port" env:"PORT, overwrite"`
|
||||
Port string `yaml:"port" env:"PORT, overwrite"`
|
||||
User string `yaml:"user" env:"USER, overwrite"`
|
||||
Password string `yaml:"password" env:"PASSWORD, overwrite"`
|
||||
Name string `yaml:"name" env:"NAME, overwrite"`
|
||||
|
||||
@@ -4,21 +4,23 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
log "github.com/sirupsen/logrus"
|
||||
_ "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
@@ -65,8 +67,8 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
if cfg.Name == "" {
|
||||
return nil, fmt.Errorf("PostgreSQL database name is required")
|
||||
}
|
||||
if cfg.Port <= 0 {
|
||||
return nil, fmt.Errorf("PostgreSQL port must be positive")
|
||||
if cfg.Port == "" {
|
||||
return nil, fmt.Errorf("PostgreSQL port is required")
|
||||
}
|
||||
|
||||
// Start with a default config
|
||||
@@ -75,9 +77,38 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
return nil, fmt.Errorf("failed to create default config: %w", err)
|
||||
}
|
||||
|
||||
// Set connection parameters
|
||||
connConfig.Host = cfg.Host
|
||||
connConfig.Port = uint16(cfg.Port)
|
||||
// Parse comma-separated hosts and create fallbacks
|
||||
// cfg.Host can be a comma-separated list like "host1,host2,host3"
|
||||
hosts := strings.Split(cfg.Host, ",")
|
||||
for i, host := range hosts {
|
||||
hosts[i] = strings.TrimSpace(host)
|
||||
}
|
||||
|
||||
// Parse and validate comma-separated ports
|
||||
portStrs := strings.Split(cfg.Port, ",")
|
||||
ports := make([]uint16, len(portStrs))
|
||||
for i, portStr := range portStrs {
|
||||
portStr = strings.TrimSpace(portStr)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port value %q: %w", portStr, err)
|
||||
}
|
||||
if port <= 0 {
|
||||
return nil, fmt.Errorf("PostgreSQL port %d must be positive", port)
|
||||
}
|
||||
if port > 65535 {
|
||||
return nil, fmt.Errorf("PostgreSQL port %d is out of valid range", port)
|
||||
}
|
||||
ports[i] = uint16(port)
|
||||
}
|
||||
|
||||
// Get port for primary host
|
||||
primaryHost := hosts[0]
|
||||
primaryPort := ports[0]
|
||||
|
||||
// Set connection parameters for primary host
|
||||
connConfig.Host = primaryHost
|
||||
connConfig.Port = primaryPort
|
||||
connConfig.User = cfg.User
|
||||
connConfig.Password = cfg.Password
|
||||
connConfig.Database = cfg.Name
|
||||
@@ -123,13 +154,35 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
case "verify-full":
|
||||
// Verify the certificate and hostname
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
tlsConfig.ServerName = cfg.Host
|
||||
tlsConfig.ServerName = primaryHost
|
||||
}
|
||||
|
||||
connConfig.TLSConfig = tlsConfig
|
||||
}
|
||||
}
|
||||
|
||||
// Create fallback configurations for additional hosts
|
||||
if len(hosts) > 1 {
|
||||
connConfig.Fallbacks = make([]*pgconn.FallbackConfig, 0, len(hosts)-1)
|
||||
for i, host := range hosts[1:] {
|
||||
port := getPortForIndex(ports, i+1)
|
||||
fallback := &pgconn.FallbackConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
}
|
||||
// Copy TLS config to fallback if present
|
||||
if connConfig.TLSConfig != nil {
|
||||
fallbackTLS := connConfig.TLSConfig.Clone()
|
||||
// Update ServerName for verify-full mode
|
||||
if cfg.SSLMode == "verify-full" {
|
||||
fallbackTLS.ServerName = host
|
||||
}
|
||||
fallback.TLSConfig = fallbackTLS
|
||||
}
|
||||
connConfig.Fallbacks = append(connConfig.Fallbacks, fallback)
|
||||
}
|
||||
}
|
||||
|
||||
// Set runtime params
|
||||
if connConfig.RuntimeParams == nil {
|
||||
connConfig.RuntimeParams = make(map[string]string)
|
||||
@@ -141,23 +194,106 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
|
||||
|
||||
// Parse and apply connection options if specified
|
||||
if cfg.ConnOptions != "" {
|
||||
// Parse key=value pairs from ConnOptions
|
||||
// Format: "key1=value1 key2=value2"
|
||||
pairs := strings.Split(cfg.ConnOptions, " ")
|
||||
for _, pair := range pairs {
|
||||
if pair == "" {
|
||||
continue
|
||||
}
|
||||
kv := strings.SplitN(pair, "=", 2)
|
||||
if len(kv) == 2 {
|
||||
connConfig.RuntimeParams[kv[0]] = kv[1]
|
||||
}
|
||||
connOpts, err := parseConnOptions(cfg.ConnOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse connection options: %w", err)
|
||||
}
|
||||
|
||||
if err := applyConnOptions(connConfig, connOpts); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply connection options: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return connConfig, nil
|
||||
}
|
||||
|
||||
// getPortForIndex returns the port for the given host index.
|
||||
// If there are fewer ports than needed, returns the last port (libpq behavior).
|
||||
func getPortForIndex(ports []uint16, i int) uint16 {
|
||||
if i >= len(ports) {
|
||||
return ports[len(ports)-1]
|
||||
}
|
||||
return ports[i]
|
||||
}
|
||||
|
||||
// parseConnOptions decodes a base64-encoded JSON string into a map of connection options.
|
||||
// This matches the Python behavior in authentik/lib/config.py:get_dict_from_b64_json
|
||||
func parseConnOptions(encoded string) (map[string]string, error) {
|
||||
// Base64 decode
|
||||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
// Parse JSON
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &opts); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
|
||||
// Convert all values to strings
|
||||
result := make(map[string]string)
|
||||
for k, v := range opts {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
result[k] = val
|
||||
case float64:
|
||||
// JSON numbers are float64
|
||||
if val == float64(int(val)) {
|
||||
result[k] = strconv.Itoa(int(val))
|
||||
} else {
|
||||
result[k] = strconv.FormatFloat(val, 'f', -1, 64)
|
||||
}
|
||||
case bool:
|
||||
result[k] = strconv.FormatBool(val)
|
||||
default:
|
||||
result[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// applyConnOptions applies parsed connection options to the pgx.ConnConfig.
|
||||
func applyConnOptions(connConfig *pgx.ConnConfig, opts map[string]string) error {
|
||||
for key, value := range opts {
|
||||
// connect_timeout needs special handling as it's a connection-level timeout
|
||||
if key == "connect_timeout" {
|
||||
timeout, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid connect_timeout value: %w", err)
|
||||
}
|
||||
connConfig.ConnectTimeout = time.Duration(timeout) * time.Second
|
||||
continue
|
||||
}
|
||||
// target_session_attrs needs special handling to set ValidateConnect function
|
||||
if key == "target_session_attrs" {
|
||||
switch value {
|
||||
case "read-write":
|
||||
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
||||
case "read-only":
|
||||
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadOnly
|
||||
case "primary":
|
||||
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPrimary
|
||||
case "standby":
|
||||
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsStandby
|
||||
case "prefer-standby":
|
||||
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPreferStandby
|
||||
case "any":
|
||||
// "any" is the default (no validation needed)
|
||||
connConfig.ValidateConnect = nil
|
||||
default:
|
||||
return fmt.Errorf("unknown target_session_attrs value: %s", value)
|
||||
}
|
||||
// Do not add target_session_attrs to RuntimeParams
|
||||
continue
|
||||
}
|
||||
// All other options go to RuntimeParams
|
||||
connConfig.RuntimeParams[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildDSN constructs a PostgreSQL connection string from a ConnConfig.
|
||||
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
|
||||
connConfig, err := BuildConnConfig(cfg)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
@@ -13,12 +14,15 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
@@ -33,7 +37,7 @@ import (
|
||||
func SetupTestDB(t *testing.T) (*gorm.DB, *RefreshableConnPool) {
|
||||
cfg := config.Get().PostgreSQL
|
||||
|
||||
t.Logf("PostgreSQL config: Host=%s Port=%d User=%s DBName=%s SSLMode=%s",
|
||||
t.Logf("PostgreSQL config: Host=%s Port=%s User=%s DBName=%s SSLMode=%s",
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.Name, cfg.SSLMode)
|
||||
t.Logf("Password length: %d", len(cfg.Password))
|
||||
if cfg.Password == "" {
|
||||
@@ -485,7 +489,7 @@ func TestBuildDSN_Validation(t *testing.T) {
|
||||
{
|
||||
name: "missing host",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
},
|
||||
@@ -496,7 +500,7 @@ func TestBuildDSN_Validation(t *testing.T) {
|
||||
name: "missing user",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
Name: "testdb",
|
||||
},
|
||||
expectError: true,
|
||||
@@ -506,7 +510,7 @@ func TestBuildDSN_Validation(t *testing.T) {
|
||||
name: "missing database name",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
},
|
||||
expectError: true,
|
||||
@@ -516,23 +520,23 @@ func TestBuildDSN_Validation(t *testing.T) {
|
||||
name: "invalid port (zero)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 0,
|
||||
Port: "0",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "PostgreSQL port must be positive",
|
||||
errorMsg: "PostgreSQL port 0 must be positive",
|
||||
},
|
||||
{
|
||||
name: "invalid port (negative)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: -1,
|
||||
Port: "-1",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "PostgreSQL port must be positive",
|
||||
errorMsg: "PostgreSQL port -1 must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -560,7 +564,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "basic configuration",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
},
|
||||
@@ -576,7 +580,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with simple password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: "testpass",
|
||||
Name: "testdb",
|
||||
@@ -589,7 +593,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with password containing spaces",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: "my secure password",
|
||||
Name: "testdb",
|
||||
@@ -602,7 +606,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with password containing single quotes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: "pass'word",
|
||||
Name: "testdb",
|
||||
@@ -615,7 +619,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with password containing backslashes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: `pass\word`,
|
||||
Name: "testdb",
|
||||
@@ -628,7 +632,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with password containing special characters",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: `p@ss w0rd!#$%^&*()`,
|
||||
Name: "testdb",
|
||||
@@ -641,7 +645,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with password containing quotes and backslashes",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: `my'pass\word"here`,
|
||||
Name: "testdb",
|
||||
@@ -654,7 +658,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with passphrase (multiple spaces)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: "the quick brown fox jumps over",
|
||||
Name: "testdb",
|
||||
@@ -667,7 +671,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with sslmode=disable",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "disable",
|
||||
@@ -680,7 +684,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with sslmode=require (no certs)",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
@@ -694,7 +698,7 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with custom schema",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
DefaultSchema: "custom_schema",
|
||||
@@ -707,27 +711,48 @@ func TestBuildConnConfig(t *testing.T) {
|
||||
name: "with connection options",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
ConnOptions: "connect_timeout=10 application_name=authentik",
|
||||
ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik"}`)),
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "10", cc.RuntimeParams["connect_timeout"])
|
||||
assert.Equal(t, 10*time.Second, cc.ConnectTimeout)
|
||||
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with target_session_attrs",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)),
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
// target_session_attrs should NOT be in RuntimeParams
|
||||
_, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams")
|
||||
// It should set ValidateConnect instead
|
||||
assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set for target_session_attrs")
|
||||
// Verify it's the correct validator function
|
||||
expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
||||
assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(),
|
||||
runtime.FuncForPC(reflect.ValueOf(cc.ValidateConnect).Pointer()).Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full configuration with special password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5433,
|
||||
Port: "5433",
|
||||
User: "admin",
|
||||
Password: "my super secret password!@#",
|
||||
Name: "production",
|
||||
SSLMode: "require",
|
||||
DefaultSchema: "app_schema",
|
||||
ConnOptions: "application_name=authentik",
|
||||
ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)),
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "db.example.com", cc.Host)
|
||||
@@ -765,7 +790,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
name: "verify-full with all certificates",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: "my secure password",
|
||||
Name: "testdb",
|
||||
@@ -786,7 +811,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
name: "verify-ca with root cert only",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "verify-ca",
|
||||
@@ -803,7 +828,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
name: "require with client cert",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
@@ -820,7 +845,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
name: "full configuration with SSL and special password",
|
||||
cfg: config.PostgreSQLConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 5433,
|
||||
Port: "5433",
|
||||
User: "admin",
|
||||
Password: "my super secret password!@#",
|
||||
Name: "production",
|
||||
@@ -829,7 +854,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) {
|
||||
SSLCert: clientCertPath,
|
||||
SSLKey: clientKeyPath,
|
||||
DefaultSchema: "app_schema",
|
||||
ConnOptions: "application_name=authentik",
|
||||
ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)),
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "db.example.com", cc.Host)
|
||||
@@ -881,7 +906,7 @@ func TestBuildDSN_WithSpecialPasswords(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Password: tt.password,
|
||||
Name: "testdb",
|
||||
@@ -941,6 +966,221 @@ func TestPostgresStore_ConnectionPoolSettings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseConnOptions tests the base64 JSON parsing of connection options
|
||||
func TestParseConnOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected map[string]string
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "simple key-value",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)),
|
||||
expected: map[string]string{"target_session_attrs": "read-write"},
|
||||
},
|
||||
{
|
||||
name: "multiple options",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik"}`)),
|
||||
expected: map[string]string{"connect_timeout": "10", "application_name": "authentik"},
|
||||
},
|
||||
{
|
||||
name: "numeric value as number",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":10}`)),
|
||||
expected: map[string]string{"connect_timeout": "10"},
|
||||
},
|
||||
{
|
||||
name: "boolean value",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`{"default_transaction_read_only":true}`)),
|
||||
expected: map[string]string{"default_transaction_read_only": "true"},
|
||||
},
|
||||
{
|
||||
name: "empty object",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`{}`)),
|
||||
expected: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "invalid base64",
|
||||
input: "not-valid-base64!!!",
|
||||
expectError: true,
|
||||
errorMsg: "invalid base64 encoding",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`not json`)),
|
||||
expectError: true,
|
||||
errorMsg: "invalid JSON",
|
||||
},
|
||||
{
|
||||
name: "JSON array instead of object",
|
||||
input: base64.StdEncoding.EncodeToString([]byte(`["value1", "value2"]`)),
|
||||
expectError: true,
|
||||
errorMsg: "invalid JSON",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parseConnOptions(tt.input)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyConnOptions tests that connection options are applied correctly to pgx.ConnConfig
|
||||
func TestApplyConnOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opts map[string]string
|
||||
validate func(*testing.T, *pgx.ConnConfig)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "connect_timeout sets ConnectTimeout",
|
||||
opts: map[string]string{"connect_timeout": "30"},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, 30*time.Second, cc.ConnectTimeout)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs sets ValidateConnect",
|
||||
opts: map[string]string{"target_session_attrs": "read-write"},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
// target_session_attrs should NOT be in RuntimeParams
|
||||
_, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not be in RuntimeParams")
|
||||
// It should set ValidateConnect instead
|
||||
assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set")
|
||||
expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
||||
assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(),
|
||||
runtime.FuncForPC(reflect.ValueOf(cc.ValidateConnect).Pointer()).Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "application_name goes to RuntimeParams",
|
||||
opts: map[string]string{"application_name": "my-app"},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "my-app", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "statement_timeout goes to RuntimeParams",
|
||||
opts: map[string]string{"statement_timeout": "5000"},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "5000", cc.RuntimeParams["statement_timeout"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown options go to RuntimeParams",
|
||||
opts: map[string]string{"custom_param": "custom_value"},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, "custom_value", cc.RuntimeParams["custom_param"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple options",
|
||||
opts: map[string]string{
|
||||
"connect_timeout": "10",
|
||||
"target_session_attrs": "read-write",
|
||||
"application_name": "authentik",
|
||||
},
|
||||
validate: func(t *testing.T, cc *pgx.ConnConfig) {
|
||||
assert.Equal(t, 10*time.Second, cc.ConnectTimeout)
|
||||
// target_session_attrs should NOT be in RuntimeParams
|
||||
_, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not be in RuntimeParams")
|
||||
// It should set ValidateConnect instead
|
||||
assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set")
|
||||
assert.Equal(t, "authentik", cc.RuntimeParams["application_name"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid connect_timeout",
|
||||
opts: map[string]string{"connect_timeout": "not-a-number"},
|
||||
expectError: true,
|
||||
errorMsg: "invalid connect_timeout value",
|
||||
},
|
||||
{
|
||||
name: "invalid target_session_attrs",
|
||||
opts: map[string]string{"target_session_attrs": "invalid-mode"},
|
||||
expectError: true,
|
||||
errorMsg: "unknown target_session_attrs value",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a base config
|
||||
connConfig, err := pgx.ParseConfig("")
|
||||
require.NoError(t, err)
|
||||
connConfig.RuntimeParams = make(map[string]string)
|
||||
|
||||
err = applyConnOptions(connConfig, tt.opts)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
tt.validate(t, connConfig)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_Base64JSONConnOptions tests the full integration of base64 JSON connection options
|
||||
func TestBuildConnConfig_Base64JSONConnOptions(t *testing.T) {
|
||||
t.Run("bug report scenario - target_session_attrs", func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "authentik",
|
||||
Name: "authentik",
|
||||
ConnOptions: "eyJ0YXJnZXRfc2Vzc2lvbl9hdHRycyI6InJlYWQtd3JpdGUifQ==",
|
||||
}
|
||||
|
||||
connConfig, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
// target_session_attrs should NOT be in RuntimeParams
|
||||
_, hasTargetSessionAttrs := connConfig.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams")
|
||||
// It should set ValidateConnect instead
|
||||
assert.NotNil(t, connConfig.ValidateConnect, "ValidateConnect should be set")
|
||||
expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
||||
assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(),
|
||||
runtime.FuncForPC(reflect.ValueOf(connConfig.ValidateConnect).Pointer()).Name())
|
||||
})
|
||||
|
||||
t.Run("complex connection options", func(t *testing.T) {
|
||||
// {"connect_timeout":10,"target_session_attrs":"read-write","application_name":"authentik-proxy"}
|
||||
connOpts := base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":10,"target_session_attrs":"read-write","application_name":"authentik-proxy"}`))
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "authentik",
|
||||
Name: "authentik",
|
||||
ConnOptions: connOpts,
|
||||
}
|
||||
|
||||
connConfig, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 10*time.Second, connConfig.ConnectTimeout)
|
||||
// target_session_attrs should NOT be in RuntimeParams
|
||||
_, hasTargetSessionAttrs := connConfig.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams")
|
||||
// It should set ValidateConnect instead
|
||||
assert.NotNil(t, connConfig.ValidateConnect, "ValidateConnect should be set")
|
||||
assert.Equal(t, "authentik-proxy", connConfig.RuntimeParams["application_name"])
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to create session data JSON
|
||||
func createSessionData(t *testing.T, claims map[string]interface{}) string {
|
||||
sessionData := map[string]interface{}{
|
||||
@@ -1036,3 +1276,495 @@ func generateTestCerts(t *testing.T) (rootCertPath, clientCertPath, clientKeyPat
|
||||
|
||||
return rootCertPath, clientCertPath, clientKeyPath, cleanup
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_WithBase64EncodedConnOptions demonstrates that ConnOptions
|
||||
// should be base64-encoded JSON but is currently being parsed as key=value pairs
|
||||
func TestBuildConnConfig_WithBase64EncodedConnOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
connOptions string
|
||||
expected map[string]string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "base64 encoded JSON with single parameter",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10"}`)),
|
||||
expected: map[string]string{
|
||||
// connect_timeout is handled specially and NOT added to RuntimeParams
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base64 encoded JSON with multiple parameters",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik","statement_timeout":"30000"}`)),
|
||||
expected: map[string]string{
|
||||
// connect_timeout is handled specially and NOT added to RuntimeParams
|
||||
"application_name": "authentik",
|
||||
"statement_timeout": "30000",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base64 encoded JSON with special characters in values",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik proxy v2"}`)),
|
||||
expected: map[string]string{
|
||||
"application_name": "authentik proxy v2",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base64 encoded JSON with target_session_attrs",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write","application_name":"authentik"}`)),
|
||||
expected: map[string]string{
|
||||
"application_name": "authentik",
|
||||
// target_session_attrs should NOT appear in RuntimeParams
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
ConnOptions: tt.connOptions,
|
||||
}
|
||||
|
||||
result, err := BuildConnConfig(cfg)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify that all expected parameters are present in RuntimeParams
|
||||
for key, expectedValue := range tt.expected {
|
||||
actualValue, exists := result.RuntimeParams[key]
|
||||
assert.True(t, exists, "Expected runtime parameter %s to exist", key)
|
||||
assert.Equal(t, expectedValue, actualValue, "Runtime parameter %s should have value %s", key, expectedValue)
|
||||
}
|
||||
|
||||
// Verify that connect_timeout is handled specially (sets ConnectTimeout field, not RuntimeParams)
|
||||
if tt.name == "base64 encoded JSON with single parameter" || tt.name == "base64 encoded JSON with multiple parameters" {
|
||||
_, hasConnectTimeout := result.RuntimeParams["connect_timeout"]
|
||||
assert.False(t, hasConnectTimeout, "connect_timeout should not appear in RuntimeParams")
|
||||
assert.Equal(t, 10*time.Second, result.ConnectTimeout, "connect_timeout should be set as ConnectTimeout duration")
|
||||
}
|
||||
|
||||
// Verify that target_session_attrs is NOT in RuntimeParams
|
||||
// (it affects connection behavior, not a runtime param)
|
||||
_, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_TargetSessionAttrs demonstrates how target_session_attrs
|
||||
// should be properly handled using pgx's ValidateConnect callback
|
||||
func TestBuildConnConfig_TargetSessionAttrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
connOptions string
|
||||
targetSessionAttrs string
|
||||
expectedValidator pgconn.ValidateConnectFunc
|
||||
validatorDescription string
|
||||
}{
|
||||
{
|
||||
name: "target_session_attrs=read-write",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)),
|
||||
targetSessionAttrs: "read-write",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||
validatorDescription: "should validate connection is read-write by checking transaction_read_only=off",
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs=read-only",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-only"}`)),
|
||||
targetSessionAttrs: "read-only",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadOnly,
|
||||
validatorDescription: "should validate connection is read-only by checking transaction_read_only=on",
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs=primary",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"primary"}`)),
|
||||
targetSessionAttrs: "primary",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsPrimary,
|
||||
validatorDescription: "should validate connection is to primary by checking in_hot_standby=off",
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs=standby",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"standby"}`)),
|
||||
targetSessionAttrs: "standby",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsStandby,
|
||||
validatorDescription: "should validate connection is to standby by checking in_hot_standby=on",
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs=prefer-standby",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"prefer-standby"}`)),
|
||||
targetSessionAttrs: "prefer-standby",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsPreferStandby,
|
||||
validatorDescription: "should prefer standby connections (affects fallback logic)",
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs=any (default)",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"any"}`)),
|
||||
targetSessionAttrs: "any",
|
||||
expectedValidator: nil,
|
||||
validatorDescription: "should not set validator as any connection is acceptable",
|
||||
},
|
||||
{
|
||||
name: "no target_session_attrs",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)),
|
||||
targetSessionAttrs: "",
|
||||
expectedValidator: nil,
|
||||
validatorDescription: "should not set validator when target_session_attrs is not specified",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
ConnOptions: tt.connOptions,
|
||||
}
|
||||
|
||||
result, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify target_session_attrs is NOT in RuntimeParams
|
||||
_, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs,
|
||||
"target_session_attrs should not appear in RuntimeParams")
|
||||
|
||||
// Verify ValidateConnect callback is set to the correct standard pgx function
|
||||
if tt.expectedValidator != nil {
|
||||
require.NotNil(t, result.ValidateConnect,
|
||||
"ValidateConnect should be set for target_session_attrs=%s: %s",
|
||||
tt.targetSessionAttrs, tt.validatorDescription)
|
||||
|
||||
// Compare function pointers using reflect to check if it's the same function
|
||||
actualFuncPtr := runtime.FuncForPC(reflect.ValueOf(result.ValidateConnect).Pointer())
|
||||
expectedFuncPtr := runtime.FuncForPC(reflect.ValueOf(tt.expectedValidator).Pointer())
|
||||
|
||||
assert.Equal(t, expectedFuncPtr.Name(), actualFuncPtr.Name(),
|
||||
"ValidateConnect should be set to %s for target_session_attrs=%s",
|
||||
expectedFuncPtr.Name(), tt.targetSessionAttrs)
|
||||
|
||||
t.Logf("Expected validator: %s", expectedFuncPtr.Name())
|
||||
t.Logf("Actual validator: %s", actualFuncPtr.Name())
|
||||
} else {
|
||||
assert.Nil(t, result.ValidateConnect,
|
||||
"ValidateConnect should not be set: %s", tt.validatorDescription)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_TargetSessionAttrs_WithMultipleHosts tests that when multiple
|
||||
// hosts are specified, fallbacks are properly configured along with the validator
|
||||
func TestBuildConnConfig_TargetSessionAttrs_WithMultipleHosts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
port string
|
||||
sslMode string
|
||||
connOptions string
|
||||
targetSessionAttrs string
|
||||
expectedValidator pgconn.ValidateConnectFunc
|
||||
expectedPrimaryHost string
|
||||
expectedPrimaryPort uint16
|
||||
expectedFallbacks []*pgconn.FallbackConfig
|
||||
expectTLS bool
|
||||
validatorDescription string
|
||||
}{
|
||||
{
|
||||
name: "multiple hosts with read-write",
|
||||
host: "db1.local,db2.local,db3.local",
|
||||
port: "5432",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)),
|
||||
targetSessionAttrs: "read-write",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||
expectedPrimaryHost: "db1.local",
|
||||
expectedPrimaryPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{
|
||||
{Host: "db2.local", Port: 5432, TLSConfig: nil},
|
||||
{Host: "db3.local", Port: 5432, TLSConfig: nil},
|
||||
},
|
||||
expectTLS: false,
|
||||
validatorDescription: "should set validator and create fallbacks for additional hosts",
|
||||
},
|
||||
{
|
||||
name: "multiple hosts with ports specified",
|
||||
host: "db1.local,db2.local,db3.local",
|
||||
port: "5432,5433,5434",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)),
|
||||
targetSessionAttrs: "read-write",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||
expectedPrimaryHost: "db1.local",
|
||||
expectedPrimaryPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{
|
||||
{Host: "db2.local", Port: 5433, TLSConfig: nil},
|
||||
{Host: "db3.local", Port: 5434, TLSConfig: nil},
|
||||
},
|
||||
expectTLS: false,
|
||||
validatorDescription: "should handle hosts with explicit ports",
|
||||
},
|
||||
{
|
||||
name: "multiple hosts with TLS required",
|
||||
host: "db1.local,db2.local,db3.local",
|
||||
port: "5432",
|
||||
sslMode: "require",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write", "sslmode":"require"}`)),
|
||||
targetSessionAttrs: "read-write",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||
expectedPrimaryHost: "db1.local",
|
||||
expectedPrimaryPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{
|
||||
{Host: "db2.local", Port: 5432}, // TLSConfig should be set (non-nil)
|
||||
{Host: "db3.local", Port: 5432}, // TLSConfig should be set (non-nil)
|
||||
},
|
||||
expectTLS: true,
|
||||
validatorDescription: "should set TLS config for all hosts when sslmode=require",
|
||||
},
|
||||
{
|
||||
name: "multiple hosts with TLS verify-full",
|
||||
host: "db1.local,db2.local,db3.local",
|
||||
port: "5432",
|
||||
sslMode: "require",
|
||||
connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write", "sslmode":"verify-full"}`)),
|
||||
targetSessionAttrs: "read-write",
|
||||
expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||
expectedPrimaryHost: "db1.local",
|
||||
expectedPrimaryPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{
|
||||
{Host: "db2.local", Port: 5432}, // TLSConfig should be set (non-nil)
|
||||
{Host: "db3.local", Port: 5432}, // TLSConfig should be set (non-nil)
|
||||
},
|
||||
expectTLS: true,
|
||||
validatorDescription: "should set TLS config host name for all hosts when sslmode=verify-full",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: tt.host,
|
||||
Port: tt.port,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
SSLMode: tt.sslMode,
|
||||
ConnOptions: tt.connOptions,
|
||||
}
|
||||
|
||||
result, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify target_session_attrs is NOT in RuntimeParams
|
||||
_, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"]
|
||||
assert.False(t, hasTargetSessionAttrs,
|
||||
"target_session_attrs should not appear in RuntimeParams")
|
||||
|
||||
// Verify ValidateConnect is set to the correct function
|
||||
require.NotNil(t, result.ValidateConnect,
|
||||
"ValidateConnect should be set for target_session_attrs=%s with multiple hosts",
|
||||
tt.targetSessionAttrs)
|
||||
|
||||
actualFuncPtr := runtime.FuncForPC(reflect.ValueOf(result.ValidateConnect).Pointer())
|
||||
expectedFuncPtr := runtime.FuncForPC(reflect.ValueOf(tt.expectedValidator).Pointer())
|
||||
|
||||
assert.Equal(t, expectedFuncPtr.Name(), actualFuncPtr.Name(),
|
||||
"ValidateConnect should be %s for target_session_attrs=%s",
|
||||
expectedFuncPtr.Name(), tt.targetSessionAttrs)
|
||||
|
||||
// Verify the primary host and port
|
||||
assert.Equal(t, tt.expectedPrimaryHost, result.Host,
|
||||
"Primary host should be %s", tt.expectedPrimaryHost)
|
||||
assert.Equal(t, tt.expectedPrimaryPort, result.Port,
|
||||
"Primary port should be %d", tt.expectedPrimaryPort)
|
||||
|
||||
// Verify primary TLSConfig based on sslmode
|
||||
if tt.expectTLS {
|
||||
assert.NotNil(t, result.TLSConfig,
|
||||
"Primary connection should have TLSConfig set when sslmode=%s", tt.sslMode)
|
||||
} else {
|
||||
assert.Nil(t, result.TLSConfig,
|
||||
"Primary connection should not have TLSConfig when sslmode is not set")
|
||||
}
|
||||
|
||||
// Verify Fallbacks are configured for the additional hosts
|
||||
require.Len(t, result.Fallbacks, len(tt.expectedFallbacks),
|
||||
"Should have %d fallback configs for the additional hosts", len(tt.expectedFallbacks))
|
||||
|
||||
// Verify each fallback configuration
|
||||
for i, expectedFb := range tt.expectedFallbacks {
|
||||
actualFb := result.Fallbacks[i]
|
||||
|
||||
assert.Equal(t, expectedFb.Host, actualFb.Host,
|
||||
"Fallback %d host should be %s", i+1, expectedFb.Host)
|
||||
assert.Equal(t, expectedFb.Port, actualFb.Port,
|
||||
"Fallback %d port should be %d", i+1, expectedFb.Port)
|
||||
|
||||
// Verify TLSConfig is set appropriately for fallbacks
|
||||
if tt.expectTLS {
|
||||
assert.NotNil(t, actualFb.TLSConfig,
|
||||
"Fallback %d should have TLSConfig set when sslmode=%s", i+1, tt.sslMode)
|
||||
// Verify InsecureSkipVerify for sslmode=require
|
||||
switch tt.sslMode {
|
||||
case "require":
|
||||
assert.True(t, actualFb.TLSConfig.InsecureSkipVerify,
|
||||
"Fallback %d TLSConfig should have InsecureSkipVerify=true for sslmode=require", i+1)
|
||||
case "verify-full":
|
||||
assert.False(t, actualFb.TLSConfig.InsecureSkipVerify,
|
||||
"Fallback %d TLSConfig should have InsecureSkipVerify=false for sslmode=verify-full", i+1)
|
||||
assert.Equal(t, actualFb.Host, actualFb.TLSConfig.ServerName,
|
||||
"Fallback %d TLSConfig ServerName should match host for sslmode=verify-full", i+1)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, actualFb.TLSConfig,
|
||||
"Fallback %d should not have TLSConfig when sslmode is not set", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Log the configuration for debugging
|
||||
t.Logf("Primary host: %s:%d", result.Host, result.Port)
|
||||
t.Logf("Validator: %s", actualFuncPtr.Name())
|
||||
for i, fb := range result.Fallbacks {
|
||||
t.Logf("Fallback %d: %s:%d", i+1, fb.Host, fb.Port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_MultipleHosts_WithoutTargetSessionAttrs tests that multiple hosts
|
||||
// create fallbacks even without target_session_attrs
|
||||
func TestBuildConnConfig_MultipleHosts_WithoutTargetSessionAttrs(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: "db1.local,db2.local,db3.local",
|
||||
Port: "5432",
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
}
|
||||
|
||||
result, err := BuildConnConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify primary host
|
||||
assert.Equal(t, "db1.local", result.Host)
|
||||
assert.Equal(t, uint16(5432), result.Port)
|
||||
|
||||
// Verify fallbacks are created
|
||||
require.Len(t, result.Fallbacks, 2, "Should have 2 fallback configs")
|
||||
assert.Equal(t, "db2.local", result.Fallbacks[0].Host)
|
||||
assert.Equal(t, uint16(5432), result.Fallbacks[0].Port)
|
||||
assert.Equal(t, "db3.local", result.Fallbacks[1].Host)
|
||||
assert.Equal(t, uint16(5432), result.Fallbacks[1].Port)
|
||||
|
||||
// Verify no ValidateConnect is set (no target_session_attrs)
|
||||
assert.Nil(t, result.ValidateConnect)
|
||||
}
|
||||
|
||||
// TestBuildConnConfig_CommaSeparatedPorts_EdgeCases tests edge cases and error scenarios for comma-separated ports
|
||||
func TestBuildConnConfig_CommaSeparatedPorts_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
port string
|
||||
expectError bool
|
||||
errorContains string
|
||||
expectedHost string
|
||||
expectedPort uint16
|
||||
expectedFallbacks []*pgconn.FallbackConfig
|
||||
}{
|
||||
{
|
||||
name: "invalid port in comma-separated list",
|
||||
host: "db1.local,db2.local",
|
||||
port: "5432,abc",
|
||||
expectError: true,
|
||||
errorContains: "invalid port value",
|
||||
},
|
||||
{
|
||||
name: "port out of range (too high)",
|
||||
host: "db1.local,db2.local",
|
||||
port: "5432,99999",
|
||||
expectError: true,
|
||||
errorContains: "PostgreSQL port 99999 is out of valid range",
|
||||
},
|
||||
{
|
||||
name: "port out of range (zero)",
|
||||
host: "db1.local,db2.local",
|
||||
port: "5432,0",
|
||||
expectError: true,
|
||||
errorContains: "PostgreSQL port 0 must be positive",
|
||||
},
|
||||
{
|
||||
name: "empty port string",
|
||||
host: "db1.local",
|
||||
port: "",
|
||||
expectError: true,
|
||||
errorContains: "PostgreSQL port is required",
|
||||
},
|
||||
{
|
||||
name: "port with only whitespace",
|
||||
host: "db1.local",
|
||||
port: " ",
|
||||
expectError: true,
|
||||
errorContains: "invalid port value",
|
||||
},
|
||||
{
|
||||
name: "mismatched number of hosts and ports",
|
||||
host: "db1.local,db2.local",
|
||||
port: "5432",
|
||||
expectError: false,
|
||||
expectedHost: "db1.local",
|
||||
expectedPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{
|
||||
{Host: "db2.local", Port: 5432},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra ports than hosts",
|
||||
host: "db1.local",
|
||||
port: "5432,5433",
|
||||
expectError: false,
|
||||
expectedHost: "db1.local",
|
||||
expectedPort: 5432,
|
||||
expectedFallbacks: []*pgconn.FallbackConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.PostgreSQLConfig{
|
||||
Host: tt.host,
|
||||
Port: tt.port,
|
||||
User: "testuser",
|
||||
Name: "testdb",
|
||||
}
|
||||
|
||||
c, err := BuildConnConfig(cfg)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, c)
|
||||
|
||||
assert.Equal(t, tt.expectedHost, c.Host)
|
||||
assert.Equal(t, tt.expectedPort, c.Port)
|
||||
require.Len(t, c.Fallbacks, len(tt.expectedFallbacks))
|
||||
for i, expectedFb := range tt.expectedFallbacks {
|
||||
actualFb := c.Fallbacks[i]
|
||||
assert.Equal(t, expectedFb.Host, actualFb.Host)
|
||||
assert.Equal(t, expectedFb.Port, actualFb.Port)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user