From 0889ddd001188e9112555a2f1701c0b35d5ad06d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 15 Mar 2026 17:48:10 +0800 Subject: [PATCH] Fix connector canceled dial cleanup --- dns/transport/connector.go | 94 +++++++++++++------- dns/transport/connector_test.go | 146 +++++++++++++++++++++++++++++++- 2 files changed, 209 insertions(+), 31 deletions(-) diff --git a/dns/transport/connector.go b/dns/transport/connector.go index 769232f4..3a87456d 100644 --- a/dns/transport/connector.go +++ b/dns/transport/connector.go @@ -55,6 +55,12 @@ type contextKeyConnecting struct{} var errRecursiveConnectorDial = E.New("recursive connector dial") +type connectorDialResult[T any] struct { + connection T + cancel context.CancelFunc + err error +} + func (c *Connector[T]) Get(ctx context.Context) (T, error) { var zero T for { @@ -100,41 +106,37 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) { return zero, err } - c.connecting = make(chan struct{}) + connecting := make(chan struct{}) + c.connecting = connecting + dialContext := context.WithValue(ctx, contextKeyConnecting{}, c) + dialResult := make(chan connectorDialResult[T], 1) c.access.Unlock() - dialContext := context.WithValue(ctx, contextKeyConnecting{}, c) - connection, cancel, err := c.dialWithCancellation(dialContext) + go func() { + connection, cancel, err := c.dialWithCancellation(dialContext) + dialResult <- connectorDialResult[T]{ + connection: connection, + cancel: cancel, + err: err, + } + }() - c.access.Lock() - close(c.connecting) - c.connecting = nil - - if err != nil { - c.access.Unlock() - return zero, err - } - - if c.closed { - cancel() - c.callbacks.Close(connection) - c.access.Unlock() + select { + case result := <-dialResult: + return c.completeDial(ctx, connecting, result) + case <-ctx.Done(): + go func() { + result := <-dialResult + _, _ = c.completeDial(ctx, connecting, result) + }() + return zero, ctx.Err() + case <-c.closeCtx.Done(): + go func() { + result := <-dialResult + _, _ = c.completeDial(ctx, connecting, result) + }() return zero, ErrTransportClosed } - if err = ctx.Err(); err != nil { - cancel() - c.callbacks.Close(connection) - c.access.Unlock() - return zero, err - } - - c.connection = connection - c.hasConnection = true - c.connectionCancel = cancel - result := c.connection - c.access.Unlock() - - return result, nil } } @@ -143,6 +145,38 @@ func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T return loaded && dialConnector == connector } +func (c *Connector[T]) completeDial(ctx context.Context, connecting chan struct{}, result connectorDialResult[T]) (T, error) { + var zero T + + c.access.Lock() + defer c.access.Unlock() + defer func() { + if c.connecting == connecting { + c.connecting = nil + } + close(connecting) + }() + + if result.err != nil { + return zero, result.err + } + if c.closed || c.closeCtx.Err() != nil { + result.cancel() + c.callbacks.Close(result.connection) + return zero, ErrTransportClosed + } + if err := ctx.Err(); err != nil { + result.cancel() + c.callbacks.Close(result.connection) + return zero, err + } + + c.connection = result.connection + c.hasConnection = true + c.connectionCancel = result.cancel + return c.connection, nil +} + func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) { var zero T if err := ctx.Err(); err != nil { diff --git a/dns/transport/connector_test.go b/dns/transport/connector_test.go index 280e5da6..309b28c8 100644 --- a/dns/transport/connector_test.go +++ b/dns/transport/connector_test.go @@ -188,13 +188,157 @@ func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) { err := <-result require.ErrorIs(t, err, context.Canceled) require.EqualValues(t, 1, dialCount.Load()) - require.EqualValues(t, 1, closeCount.Load()) + require.Eventually(t, func() bool { + return closeCount.Load() == 1 + }, time.Second, 10*time.Millisecond) _, err = connector.Get(context.Background()) require.NoError(t, err) require.EqualValues(t, 2, dialCount.Load()) } +func TestConnectorCanceledRequestReturnsBeforeIgnoredDialCompletes(t *testing.T) { + t.Parallel() + + var ( + dialCount atomic.Int32 + closeCount atomic.Int32 + ) + dialStarted := make(chan struct{}, 1) + releaseDial := make(chan struct{}) + + connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { + dialCount.Add(1) + select { + case dialStarted <- struct{}{}: + default: + } + <-releaseDial + return &testConnectorConnection{}, nil + }, ConnectorCallbacks[*testConnectorConnection]{ + IsClosed: func(connection *testConnectorConnection) bool { + return false + }, + Close: func(connection *testConnectorConnection) { + closeCount.Add(1) + }, + Reset: func(connection *testConnectorConnection) {}, + }) + + requestContext, cancel := context.WithCancel(context.Background()) + result := make(chan error, 1) + go func() { + _, err := connector.Get(requestContext) + result <- err + }() + + <-dialStarted + cancel() + + select { + case err := <-result: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(time.Second): + t.Fatal("Get did not return after request cancel") + } + + require.EqualValues(t, 1, dialCount.Load()) + require.EqualValues(t, 0, closeCount.Load()) + + close(releaseDial) + + require.Eventually(t, func() bool { + return closeCount.Load() == 1 + }, time.Second, 10*time.Millisecond) + + _, err := connector.Get(context.Background()) + require.NoError(t, err) + require.EqualValues(t, 2, dialCount.Load()) +} + +func TestConnectorWaiterDoesNotStartNewDialBeforeCanceledDialCompletes(t *testing.T) { + t.Parallel() + + var ( + dialCount atomic.Int32 + closeCount atomic.Int32 + ) + firstDialStarted := make(chan struct{}, 1) + secondDialStarted := make(chan struct{}, 1) + releaseFirstDial := make(chan struct{}) + + connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { + attempt := dialCount.Add(1) + switch attempt { + case 1: + select { + case firstDialStarted <- struct{}{}: + default: + } + <-releaseFirstDial + case 2: + select { + case secondDialStarted <- struct{}{}: + default: + } + } + return &testConnectorConnection{}, nil + }, ConnectorCallbacks[*testConnectorConnection]{ + IsClosed: func(connection *testConnectorConnection) bool { + return false + }, + Close: func(connection *testConnectorConnection) { + closeCount.Add(1) + }, + Reset: func(connection *testConnectorConnection) {}, + }) + + requestContext, cancel := context.WithCancel(context.Background()) + firstResult := make(chan error, 1) + go func() { + _, err := connector.Get(requestContext) + firstResult <- err + }() + + <-firstDialStarted + cancel() + + secondResult := make(chan error, 1) + go func() { + _, err := connector.Get(context.Background()) + secondResult <- err + }() + + select { + case <-secondDialStarted: + t.Fatal("second dial started before first dial completed") + case <-time.After(100 * time.Millisecond): + } + + select { + case err := <-firstResult: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(time.Second): + t.Fatal("first Get did not return after request cancel") + } + + close(releaseFirstDial) + + require.Eventually(t, func() bool { + return closeCount.Load() == 1 + }, time.Second, 10*time.Millisecond) + + select { + case <-secondDialStarted: + case <-time.After(time.Second): + t.Fatal("second dial did not start after first dial completed") + } + + err := <-secondResult + require.NoError(t, err) + require.EqualValues(t, 2, dialCount.Load()) +} + func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) { t.Parallel()