mirror of
https://github.com/traefik/traefik.git
synced 2026-06-18 19:38:23 +03:00
Fix connection upgrades when backend server is using h2c scheme
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -536,3 +537,64 @@ func (s *WebsocketSuite) TestHeaderAreForwarded() {
|
||||
require.NoError(s.T(), err)
|
||||
assert.Equal(s.T(), "OK", string(msg))
|
||||
}
|
||||
|
||||
func (s *WebsocketSuite) TestSSLh2c() {
|
||||
upgrader := gorillawebsocket.Upgrader{} // use default options
|
||||
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
ts.Config.Protocols = &http.Protocols{}
|
||||
ts.Config.Protocols.SetHTTP1(true)
|
||||
ts.Config.Protocols.SetUnencryptedHTTP2(true)
|
||||
ts.Start()
|
||||
|
||||
url, err := url.Parse(ts.URL)
|
||||
require.NoError(s.T(), err)
|
||||
url.Scheme = "h2c"
|
||||
|
||||
file := s.adaptFile("fixtures/websocket/config_https.toml", struct {
|
||||
WebsocketServer string
|
||||
}{
|
||||
WebsocketServer: url.String(),
|
||||
})
|
||||
|
||||
s.traefikCmd(withConfigFile(file), "--log.level=DEBUG", "--accesslog")
|
||||
|
||||
// wait for traefik
|
||||
err = try.GetRequest("http://127.0.0.1:8080/api/rawdata", 10*time.Second, try.BodyContains("127.0.0.1"))
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Add client self-signed cert
|
||||
roots := x509.NewCertPool()
|
||||
certContent, err := os.ReadFile("./resources/tls/local.cert")
|
||||
require.NoError(s.T(), err)
|
||||
roots.AppendCertsFromPEM(certContent)
|
||||
gorillawebsocket.DefaultDialer.TLSClientConfig = &tls.Config{
|
||||
RootCAs: roots,
|
||||
}
|
||||
conn, _, err := gorillawebsocket.DefaultDialer.Dial("wss://127.0.0.1:8000/echo", nil)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
_, msg, err := conn.ReadMessage()
|
||||
require.NoError(s.T(), err)
|
||||
assert.Equal(s.T(), "OK", string(msg))
|
||||
}
|
||||
|
||||
@@ -1,58 +1,33 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type h2cTransportWrapper struct {
|
||||
*http2.Transport
|
||||
}
|
||||
|
||||
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.URL.Scheme = "http"
|
||||
return t.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (*smartRoundTripper, error) {
|
||||
func newSmartRoundTripper(transport *http.Transport) *smartRoundTripper {
|
||||
// HTTP/1 only transport for requests with a Connection: Upgrade header.
|
||||
transportHTTP1 := transport.Clone()
|
||||
transportHTTP1.Protocols = new(http.Protocols)
|
||||
transportHTTP1.Protocols.SetHTTP1(true)
|
||||
|
||||
transportHTTP2, err := http2.ConfigureTransports(transport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Transport switching automatically to HTTP/2 with TLS ALPN.
|
||||
transportHTTP2 := transport.Clone()
|
||||
transportHTTP2.Protocols = new(http.Protocols)
|
||||
transportHTTP2.Protocols.SetHTTP1(true)
|
||||
transportHTTP2.Protocols.SetHTTP2(true)
|
||||
|
||||
if forwardingTimeouts != nil {
|
||||
transportHTTP2.ReadIdleTimeout = time.Duration(forwardingTimeouts.ReadIdleTimeout)
|
||||
transportHTTP2.PingTimeout = time.Duration(forwardingTimeouts.PingTimeout)
|
||||
}
|
||||
|
||||
transportH2C := &h2cTransportWrapper{
|
||||
Transport: &http2.Transport{
|
||||
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
return net.Dial(network, addr)
|
||||
},
|
||||
AllowHTTP: true,
|
||||
},
|
||||
}
|
||||
|
||||
if forwardingTimeouts != nil {
|
||||
transportH2C.ReadIdleTimeout = time.Duration(forwardingTimeouts.ReadIdleTimeout)
|
||||
transportH2C.PingTimeout = time.Duration(forwardingTimeouts.PingTimeout)
|
||||
}
|
||||
|
||||
transport.RegisterProtocol("h2c", transportH2C)
|
||||
// Transport speaking HTTP/2 with prior knowledge on unencrypted connections.
|
||||
transportH2C := transport.Clone()
|
||||
transportH2C.Protocols = new(http.Protocols)
|
||||
transportH2C.Protocols.SetUnencryptedHTTP2(true)
|
||||
|
||||
return &smartRoundTripper{
|
||||
http2: transport,
|
||||
http2: transportHTTP2,
|
||||
http: transportHTTP1,
|
||||
}, nil
|
||||
h2c: transportH2C,
|
||||
}
|
||||
}
|
||||
|
||||
// smartRoundTripper implements RoundTrip while making sure that HTTP/2 is not used
|
||||
@@ -60,19 +35,31 @@ func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic
|
||||
type smartRoundTripper struct {
|
||||
http2 *http.Transport
|
||||
http *http.Transport
|
||||
h2c *http.Transport
|
||||
}
|
||||
|
||||
func (m *smartRoundTripper) Clone() http.RoundTripper {
|
||||
h := m.http.Clone()
|
||||
h2 := m.http2.Clone()
|
||||
return &smartRoundTripper{http: h, http2: h2}
|
||||
return &smartRoundTripper{
|
||||
http2: m.http2.Clone(),
|
||||
http: m.http.Clone(),
|
||||
h2c: m.h2c.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *smartRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// If we have a connection upgrade, we don't use HTTP/2
|
||||
h2c := req.URL.Scheme == "h2c"
|
||||
if h2c {
|
||||
req.URL.Scheme = "http"
|
||||
}
|
||||
|
||||
// Connection upgrades cannot be carried over HTTP/2, they always use HTTP/1.
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
|
||||
return m.http.RoundTrip(req)
|
||||
}
|
||||
|
||||
if h2c {
|
||||
return m.h2c.RoundTrip(req)
|
||||
}
|
||||
|
||||
return m.http2.RoundTrip(req)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSmartRoundTripper(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = fmt.Fprint(rw, req.Proto)
|
||||
})
|
||||
|
||||
backend := httptest.NewUnstartedServer(handler)
|
||||
backend.Config.Protocols = new(http.Protocols)
|
||||
backend.Config.Protocols.SetHTTP1(true)
|
||||
backend.Config.Protocols.SetUnencryptedHTTP2(true)
|
||||
backend.Start()
|
||||
t.Cleanup(backend.Close)
|
||||
|
||||
tlsBackend := httptest.NewUnstartedServer(handler)
|
||||
tlsBackend.EnableHTTP2 = true
|
||||
tlsBackend.StartTLS()
|
||||
t.Cleanup(tlsBackend.Close)
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
scheme string
|
||||
upgrade bool
|
||||
expectedProto string
|
||||
}{
|
||||
{
|
||||
desc: "h2c uses HTTP/2 with prior knowledge",
|
||||
scheme: "h2c",
|
||||
expectedProto: "HTTP/2.0",
|
||||
},
|
||||
{
|
||||
desc: "h2c with connection upgrade falls back to HTTP/1.1",
|
||||
scheme: "h2c",
|
||||
upgrade: true,
|
||||
expectedProto: "HTTP/1.1",
|
||||
},
|
||||
{
|
||||
desc: "http uses HTTP/1.1",
|
||||
scheme: "http",
|
||||
expectedProto: "HTTP/1.1",
|
||||
},
|
||||
{
|
||||
desc: "http with connection upgrade uses HTTP/1.1",
|
||||
scheme: "http",
|
||||
upgrade: true,
|
||||
expectedProto: "HTTP/1.1",
|
||||
},
|
||||
{
|
||||
desc: "https uses HTTP/2 negotiated with TLS ALPN",
|
||||
scheme: "https",
|
||||
expectedProto: "HTTP/2.0",
|
||||
},
|
||||
{
|
||||
desc: "https with connection upgrade falls back to HTTP/1.1",
|
||||
scheme: "https",
|
||||
upgrade: true,
|
||||
expectedProto: "HTTP/1.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rt := newSmartRoundTripper(&http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
})
|
||||
|
||||
targetURL := backend.URL
|
||||
switch test.scheme {
|
||||
case "https":
|
||||
targetURL = tlsBackend.URL
|
||||
case "h2c":
|
||||
targetURL = strings.Replace(targetURL, "http://", "h2c://", 1)
|
||||
}
|
||||
|
||||
proto := doProtoRequest(t, rt, targetURL, test.upgrade)
|
||||
assert.Equal(t, test.expectedProto, proto)
|
||||
|
||||
// The kerberos round tripper relies on Clone, which must preserve the protocol switching.
|
||||
proto = doProtoRequest(t, rt.Clone(), targetURL, test.upgrade)
|
||||
assert.Equal(t, test.expectedProto, proto)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func doProtoRequest(t *testing.T, rt http.RoundTripper, targetURL string, upgrade bool) string {
|
||||
t.Helper()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
if upgrade {
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
}
|
||||
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() { _ = res.Body.Close() })
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
return string(body)
|
||||
}
|
||||
@@ -224,6 +224,16 @@ func (t *TransportManager) createRoundTripper(cfg *dynamic.ServersTransport, tls
|
||||
if cfg.ForwardingTimeouts != nil {
|
||||
transport.ResponseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout)
|
||||
transport.IdleConnTimeout = time.Duration(cfg.ForwardingTimeouts.IdleConnTimeout)
|
||||
// The forwarding timeout names come from the x/net/http2.Transport fields (ReadIdleTimeout/PingTimeout),
|
||||
// which were used to configure the HTTP/2 health checks before the net/http native support (Go 1.24).
|
||||
// HTTP2Config.SendPingTimeout carries the same semantics as ReadIdleTimeout:
|
||||
// the delay without any frame received on a connection after which a ping health check is sent.
|
||||
// The field was renamed when the HTTP/2 configuration moved to net/http, see https://go.dev/issue/67813.
|
||||
// The HTTP2 config does not enable HTTP2 protocol.
|
||||
transport.HTTP2 = &http.HTTP2Config{
|
||||
SendPingTimeout: time.Duration(cfg.ForwardingTimeouts.ReadIdleTimeout),
|
||||
PingTimeout: time.Duration(cfg.ForwardingTimeouts.PingTimeout),
|
||||
}
|
||||
}
|
||||
|
||||
// Return directly HTTP/1.1 transport when HTTP/2 is disabled
|
||||
@@ -236,10 +246,7 @@ func (t *TransportManager) createRoundTripper(cfg *dynamic.ServersTransport, tls
|
||||
}, nil
|
||||
}
|
||||
|
||||
rt, err := newSmartRoundTripper(transport, cfg.ForwardingTimeouts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rt := newSmartRoundTripper(transport)
|
||||
return &kerberosRoundTripper{
|
||||
OriginalRoundTripper: rt,
|
||||
new: func() http.RoundTripper {
|
||||
|
||||
Reference in New Issue
Block a user