From 27c5b0b1aff2cd77ac367982e7f308377eab42b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 6 Mar 2026 14:53:03 +0800 Subject: [PATCH] Fix DNS exchange failure and recursion deadlock in connector Co-authored-by: everyx --- dns/transport/connector.go | 110 +++++++++++-- dns/transport/connector_test.go | 263 ++++++++++++++++++++++++++++++++ transport/v2raygrpc/client.go | 2 +- transport/v2raygrpc/conn.go | 14 +- transport/v2raygrpc/server.go | 2 +- 5 files changed, 373 insertions(+), 18 deletions(-) create mode 100644 dns/transport/connector_test.go diff --git a/dns/transport/connector.go b/dns/transport/connector.go index 18fad0a5..769232f4 100644 --- a/dns/transport/connector.go +++ b/dns/transport/connector.go @@ -4,6 +4,9 @@ import ( "context" "net" "sync" + "time" + + E "github.com/sagernet/sing/common/exceptions" ) type ConnectorCallbacks[T any] struct { @@ -16,10 +19,11 @@ type Connector[T any] struct { dial func(ctx context.Context) (T, error) callbacks ConnectorCallbacks[T] - access sync.Mutex - connection T - hasConnection bool - connecting chan struct{} + access sync.Mutex + connection T + hasConnection bool + connectionCancel context.CancelFunc + connecting chan struct{} closeCtx context.Context closed bool @@ -47,6 +51,10 @@ func NewSingleflightConnector(closeCtx context.Context, dial func(context.Contex }) } +type contextKeyConnecting struct{} + +var errRecursiveConnectorDial = E.New("recursive connector dial") + func (c *Connector[T]) Get(ctx context.Context) (T, error) { var zero T for { @@ -64,6 +72,14 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) { } c.hasConnection = false + if c.connectionCancel != nil { + c.connectionCancel() + c.connectionCancel = nil + } + if isRecursiveConnectorDial(ctx, c) { + c.access.Unlock() + return zero, errRecursiveConnectorDial + } if c.connecting != nil { connecting := c.connecting @@ -79,10 +95,16 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) { } } + if err := ctx.Err(); err != nil { + c.access.Unlock() + return zero, err + } + c.connecting = make(chan struct{}) c.access.Unlock() - connection, err := c.dialWithCancellation(ctx) + dialContext := context.WithValue(ctx, contextKeyConnecting{}, c) + connection, cancel, err := c.dialWithCancellation(dialContext) c.access.Lock() close(c.connecting) @@ -94,13 +116,21 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) { } if c.closed { + cancel() c.callbacks.Close(connection) c.access.Unlock() return zero, ErrTransportClosed } + if err = ctx.Err(); err != nil { + cancel() + c.callbacks.Close(connection) + c.access.Unlock() + return zero, err + } c.connection = connection c.hasConnection = true + c.connectionCancel = cancel result := c.connection c.access.Unlock() @@ -108,19 +138,63 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) { } } -func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, error) { - dialCtx, cancel := context.WithCancel(ctx) - defer cancel() +func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T]) bool { + dialConnector, loaded := ctx.Value(contextKeyConnecting{}).(*Connector[T]) + return loaded && dialConnector == connector +} - go func() { - select { - case <-c.closeCtx.Done(): +func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) { + var zero T + if err := ctx.Err(); err != nil { + return zero, nil, err + } + connCtx, cancel := context.WithCancel(c.closeCtx) + + var ( + stateAccess sync.Mutex + dialComplete bool + ) + stopCancel := context.AfterFunc(ctx, func() { + stateAccess.Lock() + if !dialComplete { cancel() - case <-dialCtx.Done(): } - }() + stateAccess.Unlock() + }) + select { + case <-ctx.Done(): + stateAccess.Lock() + dialComplete = true + stateAccess.Unlock() + stopCancel() + cancel() + return zero, nil, ctx.Err() + default: + } - return c.dial(dialCtx) + connection, err := c.dial(valueContext{connCtx, ctx}) + stateAccess.Lock() + dialComplete = true + stateAccess.Unlock() + stopCancel() + if err != nil { + cancel() + return zero, nil, err + } + return connection, cancel, nil +} + +type valueContext struct { + context.Context + parent context.Context +} + +func (v valueContext) Value(key any) any { + return v.parent.Value(key) +} + +func (v valueContext) Deadline() (time.Time, bool) { + return v.parent.Deadline() } func (c *Connector[T]) Close() error { @@ -132,6 +206,10 @@ func (c *Connector[T]) Close() error { } c.closed = true + if c.connectionCancel != nil { + c.connectionCancel() + c.connectionCancel = nil + } if c.hasConnection { c.callbacks.Close(c.connection) c.hasConnection = false @@ -144,6 +222,10 @@ func (c *Connector[T]) Reset() { c.access.Lock() defer c.access.Unlock() + if c.connectionCancel != nil { + c.connectionCancel() + c.connectionCancel = nil + } if c.hasConnection { c.callbacks.Reset(c.connection) c.hasConnection = false diff --git a/dns/transport/connector_test.go b/dns/transport/connector_test.go new file mode 100644 index 00000000..280e5da6 --- /dev/null +++ b/dns/transport/connector_test.go @@ -0,0 +1,263 @@ +package transport + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type testConnectorConnection struct{} + +func TestConnectorRecursiveGetFailsFast(t *testing.T) { + t.Parallel() + + var ( + dialCount atomic.Int32 + closeCount atomic.Int32 + connector *Connector[*testConnectorConnection] + ) + + dial := func(ctx context.Context) (*testConnectorConnection, error) { + dialCount.Add(1) + _, err := connector.Get(ctx) + if err != nil { + return nil, err + } + return &testConnectorConnection{}, nil + } + + connector = NewConnector(context.Background(), dial, ConnectorCallbacks[*testConnectorConnection]{ + IsClosed: func(connection *testConnectorConnection) bool { + return false + }, + Close: func(connection *testConnectorConnection) { + closeCount.Add(1) + }, + Reset: func(connection *testConnectorConnection) { + closeCount.Add(1) + }, + }) + + _, err := connector.Get(context.Background()) + require.ErrorIs(t, err, errRecursiveConnectorDial) + require.EqualValues(t, 1, dialCount.Load()) + require.EqualValues(t, 0, closeCount.Load()) +} + +func TestConnectorRecursiveGetAcrossConnectorsAllowed(t *testing.T) { + t.Parallel() + + var ( + outerDialCount atomic.Int32 + innerDialCount atomic.Int32 + outerConnector *Connector[*testConnectorConnection] + innerConnector *Connector[*testConnectorConnection] + ) + + innerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { + innerDialCount.Add(1) + return &testConnectorConnection{}, nil + }, ConnectorCallbacks[*testConnectorConnection]{ + IsClosed: func(connection *testConnectorConnection) bool { + return false + }, + Close: func(connection *testConnectorConnection) {}, + Reset: func(connection *testConnectorConnection) {}, + }) + + outerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { + outerDialCount.Add(1) + _, err := innerConnector.Get(ctx) + if err != nil { + return nil, err + } + return &testConnectorConnection{}, nil + }, ConnectorCallbacks[*testConnectorConnection]{ + IsClosed: func(connection *testConnectorConnection) bool { + return false + }, + Close: func(connection *testConnectorConnection) {}, + Reset: func(connection *testConnectorConnection) {}, + }) + + _, err := outerConnector.Get(context.Background()) + require.NoError(t, err) + require.EqualValues(t, 1, outerDialCount.Load()) + require.EqualValues(t, 1, innerDialCount.Load()) +} + +func TestConnectorDialContextPreservesValueAndDeadline(t *testing.T) { + t.Parallel() + + type contextKey struct{} + + var ( + dialValue any + dialDeadline time.Time + dialHasDeadline bool + ) + + connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { + dialValue = ctx.Value(contextKey{}) + dialDeadline, dialHasDeadline = ctx.Deadline() + return &testConnectorConnection{}, nil + }, ConnectorCallbacks[*testConnectorConnection]{ + IsClosed: func(connection *testConnectorConnection) bool { + return false + }, + Close: func(connection *testConnectorConnection) {}, + Reset: func(connection *testConnectorConnection) {}, + }) + + deadline := time.Now().Add(time.Minute) + requestContext, cancel := context.WithDeadline(context.WithValue(context.Background(), contextKey{}, "test-value"), deadline) + defer cancel() + + _, err := connector.Get(requestContext) + require.NoError(t, err) + require.Equal(t, "test-value", dialValue) + require.True(t, dialHasDeadline) + require.WithinDuration(t, deadline, dialDeadline, time.Second) +} + +func TestConnectorDialSkipsCanceledRequest(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { + dialCount.Add(1) + return &testConnectorConnection{}, nil + }, ConnectorCallbacks[*testConnectorConnection]{ + IsClosed: func(connection *testConnectorConnection) bool { + return false + }, + Close: func(connection *testConnectorConnection) {}, + Reset: func(connection *testConnectorConnection) {}, + }) + + requestContext, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := connector.Get(requestContext) + require.ErrorIs(t, err, context.Canceled) + require.EqualValues(t, 0, dialCount.Load()) +} + +func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) { + t.Parallel() + + var ( + dialCount atomic.Int32 + closeCount atomic.Int32 + ) + dialStarted := make(chan struct{}, 1) + releaseDial := make(chan struct{}) + + connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { + dialCount.Add(1) + select { + case dialStarted <- struct{}{}: + default: + } + <-releaseDial + return &testConnectorConnection{}, nil + }, ConnectorCallbacks[*testConnectorConnection]{ + IsClosed: func(connection *testConnectorConnection) bool { + return false + }, + Close: func(connection *testConnectorConnection) { + closeCount.Add(1) + }, + Reset: func(connection *testConnectorConnection) {}, + }) + + requestContext, cancel := context.WithCancel(context.Background()) + result := make(chan error, 1) + go func() { + _, err := connector.Get(requestContext) + result <- err + }() + + <-dialStarted + cancel() + close(releaseDial) + + err := <-result + require.ErrorIs(t, err, context.Canceled) + require.EqualValues(t, 1, dialCount.Load()) + require.EqualValues(t, 1, closeCount.Load()) + + _, err = connector.Get(context.Background()) + require.NoError(t, err) + require.EqualValues(t, 2, dialCount.Load()) +} + +func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) { + t.Parallel() + + var dialContext context.Context + connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { + dialContext = ctx + return &testConnectorConnection{}, nil + }, ConnectorCallbacks[*testConnectorConnection]{ + IsClosed: func(connection *testConnectorConnection) bool { + return false + }, + Close: func(connection *testConnectorConnection) {}, + Reset: func(connection *testConnectorConnection) {}, + }) + + requestContext, cancel := context.WithCancel(context.Background()) + _, err := connector.Get(requestContext) + require.NoError(t, err) + require.NotNil(t, dialContext) + + cancel() + + select { + case <-dialContext.Done(): + t.Fatal("dial context canceled by request context after successful dial") + case <-time.After(100 * time.Millisecond): + } + + err = connector.Close() + require.NoError(t, err) +} + +func TestConnectorDialContextCanceledOnClose(t *testing.T) { + t.Parallel() + + var dialContext context.Context + connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { + dialContext = ctx + return &testConnectorConnection{}, nil + }, ConnectorCallbacks[*testConnectorConnection]{ + IsClosed: func(connection *testConnectorConnection) bool { + return false + }, + Close: func(connection *testConnectorConnection) {}, + Reset: func(connection *testConnectorConnection) {}, + }) + + _, err := connector.Get(context.Background()) + require.NoError(t, err) + require.NotNil(t, dialContext) + + select { + case <-dialContext.Done(): + t.Fatal("dial context canceled before connector close") + default: + } + + err = connector.Close() + require.NoError(t, err) + + select { + case <-dialContext.Done(): + case <-time.After(time.Second): + t.Fatal("dial context not canceled after connector close") + } +} diff --git a/transport/v2raygrpc/client.go b/transport/v2raygrpc/client.go index 2bbaa627..5af53856 100644 --- a/transport/v2raygrpc/client.go +++ b/transport/v2raygrpc/client.go @@ -106,7 +106,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { cancel(err) return nil, err } - return NewGRPCConn(stream), nil + return NewGRPCConn(stream, cancel), nil } func (c *Client) Close() error { diff --git a/transport/v2raygrpc/conn.go b/transport/v2raygrpc/conn.go index c29da4f9..87be9661 100644 --- a/transport/v2raygrpc/conn.go +++ b/transport/v2raygrpc/conn.go @@ -1,8 +1,10 @@ package v2raygrpc import ( + "context" "net" "os" + "sync" "time" "github.com/sagernet/sing/common/baderror" @@ -14,16 +16,19 @@ var _ net.Conn = (*GRPCConn)(nil) type GRPCConn struct { GunService - cache []byte + cache []byte + cancel context.CancelCauseFunc + closeOnce sync.Once } -func NewGRPCConn(service GunService) *GRPCConn { +func NewGRPCConn(service GunService, cancel context.CancelCauseFunc) *GRPCConn { //nolint:staticcheck if client, isClient := service.(GunService_TunClient); isClient { service = &clientConnWrapper{client} } return &GRPCConn{ GunService: service, + cancel: cancel, } } @@ -54,6 +59,11 @@ func (c *GRPCConn) Write(b []byte) (n int, err error) { } func (c *GRPCConn) Close() error { + c.closeOnce.Do(func() { + if c.cancel != nil { + c.cancel(nil) + } + }) return nil } diff --git a/transport/v2raygrpc/server.go b/transport/v2raygrpc/server.go index b6b13f82..4d426aa1 100644 --- a/transport/v2raygrpc/server.go +++ b/transport/v2raygrpc/server.go @@ -52,7 +52,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option. } func (s *Server) Tun(server GunService_TunServer) error { - conn := NewGRPCConn(server) + conn := NewGRPCConn(server, nil) var source M.Socksaddr if remotePeer, loaded := peer.FromContext(server.Context()); loaded { source = M.SocksaddrFromNet(remotePeer.Addr)