Fix connection upgrades when backend server is using h2c scheme

This commit is contained in:
stffabi
2026-06-16 09:32:07 +02:00
committed by GitHub
parent 6336f6e9a8
commit 9f7bd55ddf
4 changed files with 225 additions and 49 deletions
+62
View File
@@ -7,6 +7,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"
@@ -536,3 +537,64 @@ func (s *WebsocketSuite) TestHeaderAreForwarded() {
require.NoError(s.T(), err)
assert.Equal(s.T(), "OK", string(msg))
}
func (s *WebsocketSuite) TestSSLh2c() {
upgrader := gorillawebsocket.Upgrader{} // use default options
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
ts.Config.Protocols = &http.Protocols{}
ts.Config.Protocols.SetHTTP1(true)
ts.Config.Protocols.SetUnencryptedHTTP2(true)
ts.Start()
url, err := url.Parse(ts.URL)
require.NoError(s.T(), err)
url.Scheme = "h2c"
file := s.adaptFile("fixtures/websocket/config_https.toml", struct {
WebsocketServer string
}{
WebsocketServer: url.String(),
})
s.traefikCmd(withConfigFile(file), "--log.level=DEBUG", "--accesslog")
// wait for traefik
err = try.GetRequest("http://127.0.0.1:8080/api/rawdata", 10*time.Second, try.BodyContains("127.0.0.1"))
require.NoError(s.T(), err)
// Add client self-signed cert
roots := x509.NewCertPool()
certContent, err := os.ReadFile("./resources/tls/local.cert")
require.NoError(s.T(), err)
roots.AppendCertsFromPEM(certContent)
gorillawebsocket.DefaultDialer.TLSClientConfig = &tls.Config{
RootCAs: roots,
}
conn, _, err := gorillawebsocket.DefaultDialer.Dial("wss://127.0.0.1:8000/echo", nil)
require.NoError(s.T(), err)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(s.T(), err)
_, msg, err := conn.ReadMessage()
require.NoError(s.T(), err)
assert.Equal(s.T(), "OK", string(msg))
}
+32 -45
View File
@@ -1,58 +1,33 @@
package service
import (
"crypto/tls"
"net"
"net/http"
"time"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2"
)
type h2cTransportWrapper struct {
*http2.Transport
}
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "http"
return t.Transport.RoundTrip(req)
}
func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (*smartRoundTripper, error) {
func newSmartRoundTripper(transport *http.Transport) *smartRoundTripper {
// HTTP/1 only transport for requests with a Connection: Upgrade header.
transportHTTP1 := transport.Clone()
transportHTTP1.Protocols = new(http.Protocols)
transportHTTP1.Protocols.SetHTTP1(true)
transportHTTP2, err := http2.ConfigureTransports(transport)
if err != nil {
return nil, err
}
// Transport switching automatically to HTTP/2 with TLS ALPN.
transportHTTP2 := transport.Clone()
transportHTTP2.Protocols = new(http.Protocols)
transportHTTP2.Protocols.SetHTTP1(true)
transportHTTP2.Protocols.SetHTTP2(true)
if forwardingTimeouts != nil {
transportHTTP2.ReadIdleTimeout = time.Duration(forwardingTimeouts.ReadIdleTimeout)
transportHTTP2.PingTimeout = time.Duration(forwardingTimeouts.PingTimeout)
}
transportH2C := &h2cTransportWrapper{
Transport: &http2.Transport{
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
return net.Dial(network, addr)
},
AllowHTTP: true,
},
}
if forwardingTimeouts != nil {
transportH2C.ReadIdleTimeout = time.Duration(forwardingTimeouts.ReadIdleTimeout)
transportH2C.PingTimeout = time.Duration(forwardingTimeouts.PingTimeout)
}
transport.RegisterProtocol("h2c", transportH2C)
// Transport speaking HTTP/2 with prior knowledge on unencrypted connections.
transportH2C := transport.Clone()
transportH2C.Protocols = new(http.Protocols)
transportH2C.Protocols.SetUnencryptedHTTP2(true)
return &smartRoundTripper{
http2: transport,
http2: transportHTTP2,
http: transportHTTP1,
}, nil
h2c: transportH2C,
}
}
// smartRoundTripper implements RoundTrip while making sure that HTTP/2 is not used
@@ -60,19 +35,31 @@ func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic
type smartRoundTripper struct {
http2 *http.Transport
http *http.Transport
h2c *http.Transport
}
func (m *smartRoundTripper) Clone() http.RoundTripper {
h := m.http.Clone()
h2 := m.http2.Clone()
return &smartRoundTripper{http: h, http2: h2}
return &smartRoundTripper{
http2: m.http2.Clone(),
http: m.http.Clone(),
h2c: m.h2c.Clone(),
}
}
func (m *smartRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// If we have a connection upgrade, we don't use HTTP/2
h2c := req.URL.Scheme == "h2c"
if h2c {
req.URL.Scheme = "http"
}
// Connection upgrades cannot be carried over HTTP/2, they always use HTTP/1.
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
return m.http.RoundTrip(req)
}
if h2c {
return m.h2c.RoundTrip(req)
}
return m.http2.RoundTrip(req)
}
@@ -0,0 +1,120 @@
package service
import (
"crypto/tls"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSmartRoundTripper(t *testing.T) {
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = fmt.Fprint(rw, req.Proto)
})
backend := httptest.NewUnstartedServer(handler)
backend.Config.Protocols = new(http.Protocols)
backend.Config.Protocols.SetHTTP1(true)
backend.Config.Protocols.SetUnencryptedHTTP2(true)
backend.Start()
t.Cleanup(backend.Close)
tlsBackend := httptest.NewUnstartedServer(handler)
tlsBackend.EnableHTTP2 = true
tlsBackend.StartTLS()
t.Cleanup(tlsBackend.Close)
testCases := []struct {
desc string
scheme string
upgrade bool
expectedProto string
}{
{
desc: "h2c uses HTTP/2 with prior knowledge",
scheme: "h2c",
expectedProto: "HTTP/2.0",
},
{
desc: "h2c with connection upgrade falls back to HTTP/1.1",
scheme: "h2c",
upgrade: true,
expectedProto: "HTTP/1.1",
},
{
desc: "http uses HTTP/1.1",
scheme: "http",
expectedProto: "HTTP/1.1",
},
{
desc: "http with connection upgrade uses HTTP/1.1",
scheme: "http",
upgrade: true,
expectedProto: "HTTP/1.1",
},
{
desc: "https uses HTTP/2 negotiated with TLS ALPN",
scheme: "https",
expectedProto: "HTTP/2.0",
},
{
desc: "https with connection upgrade falls back to HTTP/1.1",
scheme: "https",
upgrade: true,
expectedProto: "HTTP/1.1",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rt := newSmartRoundTripper(&http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
})
targetURL := backend.URL
switch test.scheme {
case "https":
targetURL = tlsBackend.URL
case "h2c":
targetURL = strings.Replace(targetURL, "http://", "h2c://", 1)
}
proto := doProtoRequest(t, rt, targetURL, test.upgrade)
assert.Equal(t, test.expectedProto, proto)
// The kerberos round tripper relies on Clone, which must preserve the protocol switching.
proto = doProtoRequest(t, rt.Clone(), targetURL, test.upgrade)
assert.Equal(t, test.expectedProto, proto)
})
}
}
func doProtoRequest(t *testing.T, rt http.RoundTripper, targetURL string, upgrade bool) string {
t.Helper()
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL, http.NoBody)
require.NoError(t, err)
if upgrade {
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
}
res, err := rt.RoundTrip(req)
require.NoError(t, err)
t.Cleanup(func() { _ = res.Body.Close() })
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
return string(body)
}
+11 -4
View File
@@ -224,6 +224,16 @@ func (t *TransportManager) createRoundTripper(cfg *dynamic.ServersTransport, tls
if cfg.ForwardingTimeouts != nil {
transport.ResponseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout)
transport.IdleConnTimeout = time.Duration(cfg.ForwardingTimeouts.IdleConnTimeout)
// The forwarding timeout names come from the x/net/http2.Transport fields (ReadIdleTimeout/PingTimeout),
// which were used to configure the HTTP/2 health checks before the net/http native support (Go 1.24).
// HTTP2Config.SendPingTimeout carries the same semantics as ReadIdleTimeout:
// the delay without any frame received on a connection after which a ping health check is sent.
// The field was renamed when the HTTP/2 configuration moved to net/http, see https://go.dev/issue/67813.
// The HTTP2 config does not enable HTTP2 protocol.
transport.HTTP2 = &http.HTTP2Config{
SendPingTimeout: time.Duration(cfg.ForwardingTimeouts.ReadIdleTimeout),
PingTimeout: time.Duration(cfg.ForwardingTimeouts.PingTimeout),
}
}
// Return directly HTTP/1.1 transport when HTTP/2 is disabled
@@ -236,10 +246,7 @@ func (t *TransportManager) createRoundTripper(cfg *dynamic.ServersTransport, tls
}, nil
}
rt, err := newSmartRoundTripper(transport, cfg.ForwardingTimeouts)
if err != nil {
return nil, err
}
rt := newSmartRoundTripper(transport)
return &kerberosRoundTripper{
OriginalRoundTripper: rt,
new: func() http.RoundTripper {