diff --git a/internal/outpost/proxyv2/application/application.go b/internal/outpost/proxyv2/application/application.go index 3a428d46b9..ee09266b54 100644 --- a/internal/outpost/proxyv2/application/application.go +++ b/internal/outpost/proxyv2/application/application.go @@ -160,7 +160,7 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server, old a.sessions = sess } mux.Use(web.NewLoggingHandler(muxLogger, func(l *log.Entry, r *http.Request) *log.Entry { - c := a.getClaimsFromSession(r) + c := a.getClaimsFromSession(nil, r) if c == nil { return l } @@ -171,7 +171,7 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server, old })) mux.Use(func(inner http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - c := a.getClaimsFromSession(r) + c := a.getClaimsFromSession(nil, r) user := "" if c != nil { user = c.PreferredUsername diff --git a/internal/outpost/proxyv2/application/auth.go b/internal/outpost/proxyv2/application/auth.go index 2e1e821713..4ce594d838 100644 --- a/internal/outpost/proxyv2/application/auth.go +++ b/internal/outpost/proxyv2/application/auth.go @@ -13,7 +13,7 @@ import ( // checkAuth Get claims which are currently in session // Returns an error if the session can't be loaded or the claims can't be parsed/type-cast func (a *Application) checkAuth(rw http.ResponseWriter, r *http.Request) (*types.Claims, error) { - c := a.getClaimsFromSession(r) + c := a.getClaimsFromSession(rw, r) if c != nil { return c, nil } @@ -50,10 +50,17 @@ func (a *Application) checkAuth(rw http.ResponseWriter, r *http.Request) (*types return nil, fmt.Errorf("failed to get claims from session") } -func (a *Application) getClaimsFromSession(r *http.Request) *types.Claims { +func (a *Application) getClaimsFromSession(rw http.ResponseWriter, r *http.Request) *types.Claims { s, err := a.sessions.Get(r, a.SessionName()) if err != nil { - // err == user has no session/session is not valid, reject + // err == user has no session/session is not valid + // Delete the stale session cookie if it exists + if rw != nil { + s.Options.MaxAge = -1 + if saveErr := s.Save(r, rw); saveErr != nil { + a.log.WithError(saveErr).Warning("failed to delete stale session cookie") + } + } return nil } claims, ok := s.Values[constants.SessionClaims] diff --git a/internal/outpost/proxyv2/application/auth_test.go b/internal/outpost/proxyv2/application/auth_test.go index 3e20b44942..89876e281b 100644 --- a/internal/outpost/proxyv2/application/auth_test.go +++ b/internal/outpost/proxyv2/application/auth_test.go @@ -207,7 +207,7 @@ func TestGetClaimsFromSession_Success(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) // Test getClaimsFromSession - claims := app.getClaimsFromSession(req) + claims := app.getClaimsFromSession(nil, req) require.NotNil(t, claims) assert.Equal(t, "user-id-123", claims.Sub) assert.Equal(t, 1234567890, claims.Exp) @@ -250,7 +250,7 @@ func TestGetClaimsFromSession_NoSession(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) - claims := app.getClaimsFromSession(req) + claims := app.getClaimsFromSession(nil, req) assert.Nil(t, claims) } @@ -266,7 +266,7 @@ func TestGetClaimsFromSession_NoClaims(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) - claims := app.getClaimsFromSession(req) + claims := app.getClaimsFromSession(nil, req) assert.Nil(t, claims) } @@ -280,7 +280,7 @@ func TestGetClaimsFromSession_InvalidClaimsType(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) - claims := app.getClaimsFromSession(req) + claims := app.getClaimsFromSession(nil, req) assert.Nil(t, claims) } diff --git a/internal/outpost/proxyv2/application/mode_proxy.go b/internal/outpost/proxyv2/application/mode_proxy.go index 4555dc28ac..79126d0017 100644 --- a/internal/outpost/proxyv2/application/mode_proxy.go +++ b/internal/outpost/proxyv2/application/mode_proxy.go @@ -73,7 +73,7 @@ func (a *Application) proxyModifyRequest(ou *url.URL) func(req *http.Request) { r.Header.Set("X-Forwarded-Host", r.Host) r.URL.Scheme = ou.Scheme r.URL.Host = ou.Host - claims := a.getClaimsFromSession(r) + claims := a.getClaimsFromSession(nil, r) if claims != nil && claims.Proxy != nil { if claims.Proxy.BackendOverride != "" { u, err := url.Parse(claims.Proxy.BackendOverride) diff --git a/internal/outpost/proxyv2/application/oauth.go b/internal/outpost/proxyv2/application/oauth.go index e2e3c17dd3..0ce1c7d789 100644 --- a/internal/outpost/proxyv2/application/oauth.go +++ b/internal/outpost/proxyv2/application/oauth.go @@ -19,6 +19,7 @@ func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request, f state, err := a.createState(r, rw, fwd) if err != nil { a.log.WithError(err).Warning("failed to create state") + rw.WriteHeader(400) return } http.Redirect(rw, r, a.oauthConfig.AuthCodeURL(state), http.StatusFound) diff --git a/internal/outpost/proxyv2/application/oauth_callback.go b/internal/outpost/proxyv2/application/oauth_callback.go index 8c3d8f210d..781eb67c5e 100644 --- a/internal/outpost/proxyv2/application/oauth_callback.go +++ b/internal/outpost/proxyv2/application/oauth_callback.go @@ -13,7 +13,7 @@ import ( ) func (a *Application) handleAuthCallback(rw http.ResponseWriter, r *http.Request) { - state := a.stateFromRequest(r) + state := a.stateFromRequest(rw, r) if state == nil { a.log.Warning("invalid state") a.redirect(rw, r) diff --git a/internal/outpost/proxyv2/application/oauth_state.go b/internal/outpost/proxyv2/application/oauth_state.go index c6666d4497..0da7a53464 100644 --- a/internal/outpost/proxyv2/application/oauth_state.go +++ b/internal/outpost/proxyv2/application/oauth_state.go @@ -96,7 +96,7 @@ func (a *Application) createState(r *http.Request, w http.ResponseWriter, fwd st return tokenString, nil } -func (a *Application) stateFromRequest(r *http.Request) *OAuthState { +func (a *Application) stateFromRequest(rw http.ResponseWriter, r *http.Request) *OAuthState { stateJwt := r.URL.Query().Get("state") token, err := jwt.Parse(stateJwt, func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: @@ -127,6 +127,13 @@ func (a *Application) stateFromRequest(r *http.Request) *OAuthState { s, err := a.sessions.Get(r, a.SessionName()) if err != nil { a.log.WithError(err).Warning("failed to get session") + // Delete the stale session cookie if it exists + if rw != nil { + s.Options.MaxAge = -1 + if saveErr := s.Save(r, rw); saveErr != nil { + a.log.WithError(saveErr).Warning("failed to delete stale session cookie") + } + } return nil } if claims.SessionID != s.ID { diff --git a/internal/outpost/proxyv2/application/session_test.go b/internal/outpost/proxyv2/application/session_test.go index 0bc9e6563c..e2b323e580 100644 --- a/internal/outpost/proxyv2/application/session_test.go +++ b/internal/outpost/proxyv2/application/session_test.go @@ -76,3 +76,81 @@ func TestLogout(t *testing.T) { _, err = os.Stat(s2Name) assert.True(t, errors.Is(err, os.ErrNotExist)) } + +func TestStaleCookieDeletion(t *testing.T) { + a := newTestApplication() + _ = a.configureProxy() + + // Create a request with a session cookie that references a non-existent session file + req, _ := http.NewRequest("GET", "https://ext.t.goauthentik.io/foo", nil) + + // Set a cookie for a session that doesn't exist (simulates pod restart) + nonExistentSessionID := uuid.New().String() + req.AddCookie(&http.Cookie{ + Name: a.SessionName(), + Value: "encoded_session_data_" + nonExistentSessionID, + Path: "/", + }) + + rr := httptest.NewRecorder() + + // Call getClaimsFromSession which should delete the stale cookie + claims := a.getClaimsFromSession(rr, req) + + // Verify no claims were returned (session doesn't exist) + assert.Nil(t, claims) + + // Verify the response includes a Set-Cookie header to delete the stale cookie + cookies := rr.Result().Cookies() + var foundDeleteCookie bool + for _, cookie := range cookies { + if cookie.Name == a.SessionName() && cookie.MaxAge < 0 { + foundDeleteCookie = true + break + } + } + assert.True(t, foundDeleteCookie, "Expected stale session cookie to be deleted") +} + +func TestStateFromRequestDeletesStaleCookie(t *testing.T) { + a := newTestApplication() + _ = a.configureProxy() + + // Create a valid state JWT (from createState) + req, _ := http.NewRequest("GET", "https://ext.t.goauthentik.io/foo", nil) + rr := httptest.NewRecorder() + + state, err := a.createState(req, rr, "/redirect") + assert.NoError(t, err) + + // Create a new request with the state but a stale session cookie + req2, _ := http.NewRequest("GET", "https://ext.t.goauthentik.io/callback?state="+state, nil) + + // Add a cookie for a non-existent session + nonExistentSessionID := uuid.New().String() + req2.AddCookie(&http.Cookie{ + Name: a.SessionName(), + Value: "encoded_session_data_" + nonExistentSessionID, + Path: "/", + }) + + rr2 := httptest.NewRecorder() + + // Call stateFromRequest which should fail due to missing session + // but should also delete the stale cookie + claims := a.stateFromRequest(rr2, req2) + + // Verify no claims were returned + assert.Nil(t, claims) + + // Verify the response includes a Set-Cookie header to delete the stale cookie + cookies := rr2.Result().Cookies() + var foundDeleteCookie bool + for _, cookie := range cookies { + if cookie.Name == a.SessionName() && cookie.MaxAge < 0 { + foundDeleteCookie = true + break + } + } + assert.True(t, foundDeleteCookie, "Expected stale session cookie to be deleted") +} diff --git a/internal/outpost/proxyv2/application/test.go b/internal/outpost/proxyv2/application/test.go index ecc095a404..a3e948453a 100644 --- a/internal/outpost/proxyv2/application/test.go +++ b/internal/outpost/proxyv2/application/test.go @@ -87,7 +87,7 @@ func (a *Application) assertState(t *testing.T, req *http.Request, response *htt nrq.Set("state", state) nr.URL.RawQuery = nrq.Encode() // parse state - parsed := a.stateFromRequest(nr) + parsed := a.stateFromRequest(nil, nr) if parsed == nil { panic("Could not parse state") } diff --git a/internal/outpost/proxyv2/application/utils.go b/internal/outpost/proxyv2/application/utils.go index 248d7ddf09..4fae458fc3 100644 --- a/internal/outpost/proxyv2/application/utils.go +++ b/internal/outpost/proxyv2/application/utils.go @@ -16,7 +16,7 @@ func urlJoin(originalUrl string, newPath string) string { func (a *Application) redirect(rw http.ResponseWriter, r *http.Request) { fallbackRedirect := a.proxyConfig.ExternalHost - state := a.stateFromRequest(r) + state := a.stateFromRequest(rw, r) if state == nil { rw.WriteHeader(http.StatusBadRequest) return diff --git a/internal/outpost/proxyv2/postgresstore/connpool.go b/internal/outpost/proxyv2/postgresstore/connpool.go index aa4650e639..d96372966b 100644 --- a/internal/outpost/proxyv2/postgresstore/connpool.go +++ b/internal/outpost/proxyv2/postgresstore/connpool.go @@ -21,7 +21,6 @@ import ( type RefreshableConnPool struct { mu sync.RWMutex db *sql.DB - dsnBuilder func(config.PostgreSQLConfig) (string, error) log *log.Entry currentDSN string gormConfig *gorm.Config @@ -49,7 +48,6 @@ func NewRefreshableConnPool(initialDSN string, gormConfig *gorm.Config, maxIdleC pool := &RefreshableConnPool{ db: db, - dsnBuilder: BuildDSN, log: log.WithField("logger", "authentik.outpost.proxyv2.postgresstore.connpool"), currentDSN: initialDSN, gormConfig: gormConfig, @@ -86,7 +84,7 @@ func (p *RefreshableConnPool) refreshCredentials(ctx context.Context) error { // Get fresh config cfg := config.Get().RefreshPostgreSQLConfig() - newDSN, err := p.dsnBuilder(cfg) + newDSN, err := BuildDSN(cfg) if err != nil { p.log.WithError(err).Warn("Failed to build DSN with refreshed credentials") return err diff --git a/internal/outpost/proxyv2/postgresstore/postgresstore.go b/internal/outpost/proxyv2/postgresstore/postgresstore.go index 7328bf518a..893404fe21 100644 --- a/internal/outpost/proxyv2/postgresstore/postgresstore.go +++ b/internal/outpost/proxyv2/postgresstore/postgresstore.go @@ -2,16 +2,20 @@ package postgresstore import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" "net/http" + "os" "strings" "time" "github.com/google/uuid" "github.com/gorilla/sessions" - _ "github.com/jackc/pgx/v5/stdlib" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" "github.com/mitchellh/mapstructure" log "github.com/sirupsen/logrus" _ "gorm.io/driver/postgres" @@ -49,60 +53,121 @@ func (ProxySession) TableName() string { return "authentik_providers_proxy_proxysession" } -// BuildDSN constructs a PostgreSQL connection string -func BuildDSN(cfg config.PostgreSQLConfig) (string, error) { +// BuildConnConfig constructs a pgx.ConnConfig from PostgreSQL configuration. +func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) { // Validate required fields if cfg.Host == "" { - return "", fmt.Errorf("PostgreSQL host is required") + return nil, fmt.Errorf("PostgreSQL host is required") } if cfg.User == "" { - return "", fmt.Errorf("PostgreSQL user is required") + return nil, fmt.Errorf("PostgreSQL user is required") } if cfg.Name == "" { - return "", fmt.Errorf("PostgreSQL database name is required") + return nil, fmt.Errorf("PostgreSQL database name is required") } if cfg.Port <= 0 { - return "", fmt.Errorf("PostgreSQL port must be positive") + return nil, fmt.Errorf("PostgreSQL port must be positive") } - // Build DSN string with all parameters - dsnParts := []string{ - "host=" + cfg.Host, - fmt.Sprintf("port=%d", cfg.Port), - "user=" + cfg.User, - "dbname=" + cfg.Name, + // Start with a default config + connConfig, err := pgx.ParseConfig("") + if err != nil { + return nil, fmt.Errorf("failed to create default config: %w", err) } - if cfg.Password != "" { - dsnParts = append(dsnParts, "password="+cfg.Password) - } + // Set connection parameters + connConfig.Host = cfg.Host + connConfig.Port = uint16(cfg.Port) + connConfig.User = cfg.User + connConfig.Password = cfg.Password + connConfig.Database = cfg.Name - // Add SSL mode + // Configure TLS/SSL if cfg.SSLMode != "" { - dsnParts = append(dsnParts, "sslmode="+cfg.SSLMode) + switch cfg.SSLMode { + case "disable": + connConfig.TLSConfig = nil + case "require", "verify-ca", "verify-full": + tlsConfig := &tls.Config{} + + // Load root CA certificate if provided + if cfg.SSLRootCert != "" { + caCert, err := os.ReadFile(cfg.SSLRootCert) + if err != nil { + return nil, fmt.Errorf("failed to read SSL root certificate: %w", err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse SSL root certificate") + } + tlsConfig.RootCAs = caCertPool + } + + // Load client certificate and key if provided + if cfg.SSLCert != "" && cfg.SSLKey != "" { + cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey) + if err != nil { + return nil, fmt.Errorf("failed to load SSL client certificate: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + // Set verification mode + switch cfg.SSLMode { + case "require": + // Don't verify the server certificate (just encrypt) + tlsConfig.InsecureSkipVerify = true + case "verify-ca": + // Verify the certificate is signed by a trusted CA + tlsConfig.InsecureSkipVerify = false + case "verify-full": + // Verify the certificate and hostname + tlsConfig.InsecureSkipVerify = false + tlsConfig.ServerName = cfg.Host + } + + connConfig.TLSConfig = tlsConfig + } } - // Add SSL certificates if provided - if cfg.SSLRootCert != "" { - dsnParts = append(dsnParts, "sslrootcert="+cfg.SSLRootCert) - } - if cfg.SSLCert != "" { - dsnParts = append(dsnParts, "sslcert="+cfg.SSLCert) - } - if cfg.SSLKey != "" { - dsnParts = append(dsnParts, "sslkey="+cfg.SSLKey) + // Set runtime params + if connConfig.RuntimeParams == nil { + connConfig.RuntimeParams = make(map[string]string) } + if cfg.DefaultSchema != "" { - dsnParts = append(dsnParts, "search_path="+cfg.DefaultSchema) + connConfig.RuntimeParams["search_path"] = cfg.DefaultSchema } - // Add connection options if specified + // Parse and apply connection options if specified if cfg.ConnOptions != "" { - dsnParts = append(dsnParts, cfg.ConnOptions) + // Parse key=value pairs from ConnOptions + // Format: "key1=value1 key2=value2" + pairs := strings.Split(cfg.ConnOptions, " ") + for _, pair := range pairs { + if pair == "" { + continue + } + kv := strings.SplitN(pair, "=", 2) + if len(kv) == 2 { + connConfig.RuntimeParams[kv[0]] = kv[1] + } + } } - // Join parts with spaces - return strings.Join(dsnParts, " "), nil + return connConfig, nil +} + +// BuildDSN constructs a PostgreSQL connection string from a ConnConfig. +func BuildDSN(cfg config.PostgreSQLConfig) (string, error) { + connConfig, err := BuildConnConfig(cfg) + if err != nil { + return "", err + } + + // Register the config and get a connection string + // (This approach lets pgx handle all the escaping internally which is quite convenient for say spaces in the password) + return stdlib.RegisterConnConfig(connConfig), nil } // SetupGORMWithRefreshablePool creates a GORM DB with a refreshable connection pool. diff --git a/internal/outpost/proxyv2/postgresstore/postgresstore_test.go b/internal/outpost/proxyv2/postgresstore/postgresstore_test.go index 4cfe4bda9f..092803c7b2 100644 --- a/internal/outpost/proxyv2/postgresstore/postgresstore_test.go +++ b/internal/outpost/proxyv2/postgresstore/postgresstore_test.go @@ -2,14 +2,23 @@ package postgresstore import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" "encoding/json" + "encoding/pem" "fmt" + "math/big" "net/http/httptest" + "os" + "path/filepath" "testing" "time" "github.com/google/uuid" "github.com/gorilla/sessions" + "github.com/jackc/pgx/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" @@ -541,11 +550,11 @@ func TestBuildDSN_Validation(t *testing.T) { } } -func TestBuildDSN(t *testing.T) { +func TestBuildConnConfig(t *testing.T) { tests := []struct { name string cfg config.PostgreSQLConfig - expected string + validate func(*testing.T, *pgx.ConnConfig) }{ { name: "basic configuration", @@ -555,10 +564,16 @@ func TestBuildDSN(t *testing.T) { User: "testuser", Name: "testdb", }, - expected: "host=localhost port=5432 user=testuser dbname=testdb", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "localhost", cc.Host) + assert.Equal(t, uint16(5432), cc.Port) + assert.Equal(t, "testuser", cc.User) + assert.Equal(t, "testdb", cc.Database) + assert.Equal(t, "", cc.Password) + }, }, { - name: "with password", + name: "with simple password", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: 5432, @@ -566,7 +581,87 @@ func TestBuildDSN(t *testing.T) { Password: "testpass", Name: "testdb", }, - expected: "host=localhost port=5432 user=testuser dbname=testdb password=testpass", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "testpass", cc.Password) + }, + }, + { + name: "with password containing spaces", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: "my secure password", + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "my secure password", cc.Password) + }, + }, + { + name: "with password containing single quotes", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: "pass'word", + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "pass'word", cc.Password) + }, + }, + { + name: "with password containing backslashes", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: `pass\word`, + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, `pass\word`, cc.Password) + }, + }, + { + name: "with password containing special characters", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: `p@ss w0rd!#$%^&*()`, + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, `p@ss w0rd!#$%^&*()`, cc.Password) + }, + }, + { + name: "with password containing quotes and backslashes", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: `my'pass\word"here`, + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, `my'pass\word"here`, cc.Password) + }, + }, + { + name: "with passphrase (multiple spaces)", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: "the quick brown fox jumps over", + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "the quick brown fox jumps over", cc.Password) + }, }, { name: "with sslmode=disable", @@ -577,10 +672,12 @@ func TestBuildDSN(t *testing.T) { Name: "testdb", SSLMode: "disable", }, - expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=disable", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Nil(t, cc.TLSConfig) + }, }, { - name: "with sslmode=require", + name: "with sslmode=require (no certs)", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: 5432, @@ -588,32 +685,10 @@ func TestBuildDSN(t *testing.T) { Name: "testdb", SSLMode: "require", }, - expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=require", - }, - { - name: "with sslmode=prefer", - cfg: config.PostgreSQLConfig{ - Host: "localhost", - Port: 5432, - User: "testuser", - Name: "testdb", - SSLMode: "prefer", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.NotNil(t, cc.TLSConfig) + assert.True(t, cc.TLSConfig.InsecureSkipVerify) }, - expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=prefer", - }, - { - name: "with SSL certificates", - cfg: config.PostgreSQLConfig{ - Host: "localhost", - Port: 5432, - User: "testuser", - Name: "testdb", - SSLMode: "verify-full", - SSLRootCert: "/path/to/root.crt", - SSLCert: "/path/to/client.crt", - SSLKey: "/path/to/client.key", - }, - expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=verify-full sslrootcert=/path/to/root.crt sslcert=/path/to/client.crt sslkey=/path/to/client.key", }, { name: "with custom schema", @@ -624,7 +699,9 @@ func TestBuildDSN(t *testing.T) { Name: "testdb", DefaultSchema: "custom_schema", }, - expected: "host=localhost port=5432 user=testuser dbname=testdb search_path=custom_schema", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "custom_schema", cc.RuntimeParams["search_path"]) + }, }, { name: "with connection options", @@ -633,34 +710,192 @@ func TestBuildDSN(t *testing.T) { Port: 5432, User: "testuser", Name: "testdb", - ConnOptions: "connect_timeout=10", + ConnOptions: "connect_timeout=10 application_name=authentik", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "10", cc.RuntimeParams["connect_timeout"]) + assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) }, - expected: "host=localhost port=5432 user=testuser dbname=testdb connect_timeout=10", }, { - name: "full configuration", + name: "full configuration with special password", cfg: config.PostgreSQLConfig{ Host: "db.example.com", Port: 5433, User: "admin", - Password: "secret", + Password: "my super secret password!@#", Name: "production", - SSLMode: "verify-full", - SSLRootCert: "/certs/root.crt", - SSLCert: "/certs/client.crt", - SSLKey: "/certs/client.key", + SSLMode: "require", DefaultSchema: "app_schema", ConnOptions: "application_name=authentik", }, - expected: "host=db.example.com port=5433 user=admin dbname=production password=secret sslmode=verify-full sslrootcert=/certs/root.crt sslcert=/certs/client.crt sslkey=/certs/client.key search_path=app_schema application_name=authentik", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "db.example.com", cc.Host) + assert.Equal(t, uint16(5433), cc.Port) + assert.Equal(t, "admin", cc.User) + assert.Equal(t, "my super secret password!@#", cc.Password) + assert.Equal(t, "production", cc.Database) + assert.Equal(t, "app_schema", cc.RuntimeParams["search_path"]) + assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := BuildDSN(tt.cfg) + result, err := BuildConnConfig(tt.cfg) require.NoError(t, err) - assert.Equal(t, tt.expected, result) + require.NotNil(t, result) + tt.validate(t, result) + }) + } +} + +// TestBuildConnConfig_WithSSLCertificates tests SSL certificate configuration +func TestBuildConnConfig_WithSSLCertificates(t *testing.T) { + rootCertPath, clientCertPath, clientKeyPath, cleanup := generateTestCerts(t) + defer cleanup() + + tests := []struct { + name string + cfg config.PostgreSQLConfig + validate func(*testing.T, *pgx.ConnConfig) + }{ + { + name: "verify-full with all certificates", + cfg: config.PostgreSQLConfig{ + Host: "db.example.com", + Port: 5432, + User: "testuser", + Password: "my secure password", + Name: "testdb", + SSLMode: "verify-full", + SSLRootCert: rootCertPath, + SSLCert: clientCertPath, + SSLKey: clientKeyPath, + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + require.NotNil(t, cc.TLSConfig) + assert.False(t, cc.TLSConfig.InsecureSkipVerify) + assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName) + assert.NotNil(t, cc.TLSConfig.RootCAs) + assert.Len(t, cc.TLSConfig.Certificates, 1) + }, + }, + { + name: "verify-ca with root cert only", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Name: "testdb", + SSLMode: "verify-ca", + SSLRootCert: rootCertPath, + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + require.NotNil(t, cc.TLSConfig) + assert.False(t, cc.TLSConfig.InsecureSkipVerify) + assert.NotNil(t, cc.TLSConfig.RootCAs) + assert.Empty(t, cc.TLSConfig.Certificates) + }, + }, + { + name: "require with client cert", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Name: "testdb", + SSLMode: "require", + SSLCert: clientCertPath, + SSLKey: clientKeyPath, + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + require.NotNil(t, cc.TLSConfig) + assert.True(t, cc.TLSConfig.InsecureSkipVerify) + assert.Len(t, cc.TLSConfig.Certificates, 1) + }, + }, + { + name: "full configuration with SSL and special password", + cfg: config.PostgreSQLConfig{ + Host: "db.example.com", + Port: 5433, + User: "admin", + Password: "my super secret password!@#", + Name: "production", + SSLMode: "verify-full", + SSLRootCert: rootCertPath, + SSLCert: clientCertPath, + SSLKey: clientKeyPath, + DefaultSchema: "app_schema", + ConnOptions: "application_name=authentik", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "db.example.com", cc.Host) + assert.Equal(t, uint16(5433), cc.Port) + assert.Equal(t, "admin", cc.User) + assert.Equal(t, "my super secret password!@#", cc.Password) + assert.Equal(t, "production", cc.Database) + require.NotNil(t, cc.TLSConfig) + assert.False(t, cc.TLSConfig.InsecureSkipVerify) + assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName) + assert.NotNil(t, cc.TLSConfig.RootCAs) + assert.Len(t, cc.TLSConfig.Certificates, 1) + assert.Equal(t, "app_schema", cc.RuntimeParams["search_path"]) + assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := BuildConnConfig(tt.cfg) + require.NoError(t, err) + require.NotNil(t, result) + tt.validate(t, result) + }) + } +} + +// TestBuildDSN_WithSpecialPasswords tests that BuildDSN can handle passwords with special characters +// by verifying the DSN can actually be used to connect to a database +func TestBuildDSN_WithSpecialPasswords(t *testing.T) { + tests := []struct { + name string + password string + }{ + {"space in password", "my password"}, + {"multiple spaces", "the quick brown fox"}, + {"single quote", "pass'word"}, + {"backslash", `pass\word`}, + {"double quote", `pass"word`}, + {"special chars", `p@ss!#$%^&*()`}, + {"mixed special", `my'pass\word"here`}, + {"unicode", "pässwörd"}, + {"leading/trailing spaces", " password "}, + {"tab character", "pass\tword"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: tt.password, + Name: "testdb", + } + + // Test that BuildDSN doesn't error + dsn, err := BuildDSN(cfg) + require.NoError(t, err) + require.NotEmpty(t, dsn) + + // Test that BuildConnConfig preserves the password exactly + connConfig, err := BuildConnConfig(cfg) + require.NoError(t, err) + assert.Equal(t, tt.password, connConfig.Password, "Password should be preserved exactly") }) } } @@ -715,3 +950,89 @@ func createSessionData(t *testing.T, claims map[string]interface{}) string { require.NoError(t, err) return string(sessionDataJSON) } + +// generateTestCerts creates temporary SSL certificates for testing +func generateTestCerts(t *testing.T) (rootCertPath, clientCertPath, clientKeyPath string, cleanup func()) { + tmpDir := t.TempDir() + + // Generate CA certificate + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test CA"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + require.NoError(t, err) + + // Write CA certificate + rootCertPath = filepath.Join(tmpDir, "root.crt") + rootCertFile, err := os.Create(rootCertPath) + require.NoError(t, err) + defer func() { + if closeErr := rootCertFile.Close(); closeErr != nil { + t.Logf("failed to close root cert file: %v", closeErr) + } + }() + err = pem.Encode(rootCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: caCertDER}) + require.NoError(t, err) + + // Generate client key + clientKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + // Generate client certificate + clientTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"Test Client"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + clientCertDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey) + require.NoError(t, err) + + // Write client certificate + clientCertPath = filepath.Join(tmpDir, "client.crt") + clientCertFile, err := os.Create(clientCertPath) + require.NoError(t, err) + defer func() { + if closeErr := clientCertFile.Close(); closeErr != nil { + t.Logf("failed to close client cert file: %v", closeErr) + } + }() + err = pem.Encode(clientCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER}) + require.NoError(t, err) + + // Write client key + clientKeyPath = filepath.Join(tmpDir, "client.key") + clientKeyFile, err := os.Create(clientKeyPath) + require.NoError(t, err) + defer func() { + if closeErr := clientKeyFile.Close(); closeErr != nil { + t.Logf("failed to close client key file: %v", closeErr) + } + }() + clientKeyBytes := x509.MarshalPKCS1PrivateKey(clientKey) + err = pem.Encode(clientKeyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: clientKeyBytes}) + require.NoError(t, err) + + cleanup = func() { + // TempDir cleanup is automatic in Go tests + } + + return rootCertPath, clientCertPath, clientKeyPath, cleanup +}