From 3d7a770c2283d356c8877cc4e84e5f797640946d Mon Sep 17 00:00:00 2001 From: Romain Date: Wed, 17 Jun 2026 17:30:08 +0200 Subject: [PATCH] Fix nondeterministic TLS certificate selection on shared SAN --- pkg/tls/certificate_store.go | 86 +++++++++++++++---------------- pkg/tls/certificate_store_test.go | 24 +++++++++ 2 files changed, 66 insertions(+), 44 deletions(-) diff --git a/pkg/tls/certificate_store.go b/pkg/tls/certificate_store.go index 5321f98aa..9516869ec 100644 --- a/pkg/tls/certificate_store.go +++ b/pkg/tls/certificate_store.go @@ -3,7 +3,9 @@ package tls import ( "crypto/tls" "fmt" + "maps" "net" + "slices" "sort" "strings" "time" @@ -76,6 +78,7 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo) if c == nil { return nil } + serverName := strings.ToLower(strings.TrimSpace(clientHello.ServerName)) if len(serverName) == 0 { // If no ServerName is provided, Check for local IP address matches @@ -99,40 +102,33 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo) return certificateData.Certificate } - matchedCerts := map[string]*CertificateData{} if c.DynamicCerts != nil && c.DynamicCerts.Get() != nil { - for domains, cert := range c.DynamicCerts.Get().(map[string]*CertificateData) { - for certDomain := range strings.SplitSeq(domains, ",") { - if matchDomain(serverName, certDomain) { - matchedCerts[certDomain] = cert + certs := c.DynamicCerts.Get().(map[string]*CertificateData) + // sorted cert sans identifiers + sorted := slices.SortedFunc(maps.Keys(certs), func(certKey string, certKey2 string) int { + // reverse sort. + return strings.Compare(certKey2, certKey) + }) + + for _, certDomains := range sorted { + if matchDomain(serverName, certDomains) { + // cache best match + certificateData := certs[certDomains] + c.CertCache.SetDefault(serverName, certificateData) + + if c.ocspStapler != nil && certificateData.Hash != "" { + if staple, ok := c.ocspStapler.GetStaple(certificateData.Hash); ok { + // We are updating the OCSPStaple of the certificate without any synchronization + // as this should not cause any issue. + certificateData.Certificate.OCSPStaple = staple + } } + + return certificateData.Certificate } } } - if len(matchedCerts) > 0 { - // sort map by keys - keys := make([]string, 0, len(matchedCerts)) - for k := range matchedCerts { - keys = append(keys, k) - } - sort.Strings(keys) - - // cache best match - certificateData := matchedCerts[keys[len(keys)-1]] - c.CertCache.SetDefault(serverName, certificateData) - - if c.ocspStapler != nil && certificateData.Hash != "" { - if staple, ok := c.ocspStapler.GetStaple(certificateData.Hash); ok { - // We are updating the OCSPStaple of the certificate without any synchronization - // as this should not cause any issue. - certificateData.Certificate.OCSPStaple = staple - } - } - - return certificateData.Certificate - } - return nil } @@ -267,26 +263,28 @@ func parseCertificate(cert *Certificate) (tls.Certificate, []string, error) { return tlsCert, SANs, err } -// matchDomain returns whether the server name matches the cert domain. +// matchDomain returns whether the server name matches the cert domains. // The server name, from TLS SNI, must not have trailing dots (https://datatracker.ietf.org/doc/html/rfc6066#section-3). // This is enforced by https://github.com/golang/go/blob/d3d7998756c33f69706488cade1cd2b9b10a4c7f/src/crypto/tls/handshake_messages.go#L423-L427. -func matchDomain(serverName, certDomain string) bool { - // TODO: assert equality after removing the trailing dots? - if serverName == certDomain { - return true - } - - for len(certDomain) > 0 && certDomain[len(certDomain)-1] == '.' { - certDomain = certDomain[:len(certDomain)-1] - } - - labels := strings.Split(serverName, ".") - for i := range labels { - labels[i] = "*" - candidate := strings.Join(labels, ".") - if certDomain == candidate { +func matchDomain(serverName, certDomains string) bool { + for certDomain := range strings.SplitSeq(certDomains, ",") { + // TODO: assert equality after removing the trailing dots? + if serverName == certDomain { return true } + + for len(certDomain) > 0 && certDomain[len(certDomain)-1] == '.' { + certDomain = certDomain[:len(certDomain)-1] + } + + labels := strings.Split(serverName, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if certDomain == candidate { + return true + } + } } return false } diff --git a/pkg/tls/certificate_store_test.go b/pkg/tls/certificate_store_test.go index ef4fd3885..6e8398fcd 100644 --- a/pkg/tls/certificate_store_test.go +++ b/pkg/tls/certificate_store_test.go @@ -88,6 +88,30 @@ func TestGetBestCertificate(t *testing.T) { } } +// TestGetBestCertificate_SharedSAN ensures the selection stays deterministic +// when distinct certificates share a SAN matching the server name (https://github.com/traefik/traefik/issues/13286). +func TestGetBestCertificate_SharedSAN(t *testing.T) { + wildcardCert := &CertificateData{Certificate: &tls.Certificate{}} + exactCert := &CertificateData{Certificate: &tls.Certificate{}} + + // Both certificates have a SAN matching app.example.test, but the exact-only + // certificate must always win as its identifier sorts last. + for range 100 { + dynamicMap := map[string]*CertificateData{ + "*.app.example.test,app.example.test": wildcardCert, + "app.example.test": exactCert, + } + + store := &CertificateStore{ + DynamicCerts: safe.New(dynamicMap), + CertCache: cache.New(1*time.Hour, 10*time.Minute), + } + + clientHello := &tls.ClientHelloInfo{ServerName: "app.example.test"} + assert.Same(t, exactCert.Certificate, store.GetBestCertificate(clientHello)) + } +} + func loadTestCert(certName string, uppercase bool) (*tls.Certificate, error) { replacement := "wildcard" if uppercase {