From 41b30c91d9ea452c02ec541ef7c525fc4f4c6253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 7 Oct 2025 13:41:25 +0800 Subject: [PATCH] Improve HTTPS DNS transport --- common/tls/client.go | 40 ++++++++++++---- dns/transport/https.go | 36 ++------------ dns/transport/https_transport.go | 80 ++++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 41 deletions(-) create mode 100644 dns/transport/https_transport.go diff --git a/common/tls/client.go b/common/tls/client.go index 5e05c990..afdb5a42 100644 --- a/common/tls/client.go +++ b/common/tls/client.go @@ -53,26 +53,48 @@ func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, e return tlsConn, nil } -type Dialer struct { +type Dialer interface { + N.Dialer + DialTLSContext(ctx context.Context, destination M.Socksaddr) (Conn, error) +} + +type defaultDialer struct { dialer N.Dialer config Config } -func NewDialer(dialer N.Dialer, config Config) N.Dialer { - return &Dialer{dialer, config} +func NewDialer(dialer N.Dialer, config Config) Dialer { + return &defaultDialer{dialer, config} } -func (d *Dialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - if network != N.NetworkTCP { +func (d *defaultDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if N.NetworkName(network) != N.NetworkTCP { return nil, os.ErrInvalid } - conn, err := d.dialer.DialContext(ctx, network, destination) + return d.DialTLSContext(ctx, destination) +} + +func (d *defaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return nil, os.ErrInvalid +} + +func (d *defaultDialer) DialTLSContext(ctx context.Context, destination M.Socksaddr) (Conn, error) { + return d.dialContext(ctx, destination) +} + +func (d *defaultDialer) dialContext(ctx context.Context, destination M.Socksaddr) (Conn, error) { + conn, err := d.dialer.DialContext(ctx, N.NetworkTCP, destination) if err != nil { return nil, err } - return ClientHandshake(ctx, conn, d.config) + tlsConn, err := aTLS.ClientHandshake(ctx, conn, d.config) + if err != nil { + conn.Close() + return nil, err + } + return tlsConn, nil } -func (d *Dialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return nil, os.ErrInvalid +func (d *defaultDialer) Upstream() any { + return d.dialer } diff --git a/dns/transport/https.go b/dns/transport/https.go index 7d56f45e..30c2a11f 100644 --- a/dns/transport/https.go +++ b/dns/transport/https.go @@ -25,7 +25,6 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - aTLS "github.com/sagernet/sing/common/tls" sHTTP "github.com/sagernet/sing/protocol/http" mDNS "github.com/miekg/dns" @@ -47,7 +46,7 @@ type HTTPSTransport struct { destination *url.URL headers http.Header transportAccess sync.Mutex - transport *http.Transport + transport *HTTPSTransportWrapper transportResetAt time.Time } @@ -62,11 +61,8 @@ func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options if err != nil { return nil, err } - if common.Error(tlsConfig.Config()) == nil && !common.Contains(tlsConfig.NextProtos(), http2.NextProtoTLS) { - tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), http2.NextProtoTLS)) - } - if !common.Contains(tlsConfig.NextProtos(), "http/1.1") { - tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), "http/1.1")) + if len(tlsConfig.NextProtos()) == 0 { + tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"}) } headers := options.Headers.Build() host := headers.Get("Host") @@ -124,37 +120,13 @@ func NewHTTPSRaw( serverAddr M.Socksaddr, tlsConfig tls.Config, ) *HTTPSTransport { - var transport *http.Transport - if tlsConfig != nil { - transport = &http.Transport{ - ForceAttemptHTTP2: true, - DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - tcpConn, hErr := dialer.DialContext(ctx, network, serverAddr) - if hErr != nil { - return nil, hErr - } - tlsConn, hErr := aTLS.ClientHandshake(ctx, tcpConn, tlsConfig) - if hErr != nil { - tcpConn.Close() - return nil, hErr - } - return tlsConn, nil - }, - } - } else { - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.DialContext(ctx, network, serverAddr) - }, - } - } return &HTTPSTransport{ TransportAdapter: adapter, logger: logger, dialer: dialer, destination: destination, headers: headers, - transport: transport, + transport: NewHTTPSTransportWrapper(tls.NewDialer(dialer, tlsConfig), serverAddr), } } diff --git a/dns/transport/https_transport.go b/dns/transport/https_transport.go new file mode 100644 index 00000000..84cfa17c --- /dev/null +++ b/dns/transport/https_transport.go @@ -0,0 +1,80 @@ +package transport + +import ( + "context" + "errors" + "net" + "net/http" + "sync/atomic" + + "github.com/sagernet/sing-box/common/tls" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "golang.org/x/net/http2" +) + +var errFallback = E.New("fallback to HTTP/1.1") + +type HTTPSTransportWrapper struct { + http2Transport *http2.Transport + httpTransport *http.Transport + fallback *atomic.Bool +} + +func NewHTTPSTransportWrapper(dialer tls.Dialer, serverAddr M.Socksaddr) *HTTPSTransportWrapper { + var fallback atomic.Bool + return &HTTPSTransportWrapper{ + http2Transport: &http2.Transport{ + DialTLSContext: func(ctx context.Context, _, _ string, _ *tls.STDConfig) (net.Conn, error) { + tlsConn, err := dialer.DialTLSContext(ctx, serverAddr) + if err != nil { + return nil, err + } + state := tlsConn.ConnectionState() + if state.NegotiatedProtocol == http2.NextProtoTLS { + return tlsConn, nil + } + tlsConn.Close() + fallback.Store(true) + return nil, errFallback + }, + }, + httpTransport: &http.Transport{ + DialTLSContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialTLSContext(ctx, serverAddr) + }, + }, + fallback: &fallback, + } +} + +func (h *HTTPSTransportWrapper) RoundTrip(request *http.Request) (*http.Response, error) { + if h.fallback.Load() { + return h.httpTransport.RoundTrip(request) + } else { + response, err := h.http2Transport.RoundTrip(request) + if err != nil { + if errors.Is(err, errFallback) { + return h.httpTransport.RoundTrip(request) + } + return nil, err + } + return response, nil + } +} + +func (h *HTTPSTransportWrapper) CloseIdleConnections() { + h.http2Transport.CloseIdleConnections() + h.httpTransport.CloseIdleConnections() +} + +func (h *HTTPSTransportWrapper) Clone() *HTTPSTransportWrapper { + return &HTTPSTransportWrapper{ + httpTransport: h.httpTransport, + http2Transport: &http2.Transport{ + DialTLSContext: h.http2Transport.DialTLSContext, + }, + fallback: h.fallback, + } +}