Fix atomic pointer usages

This commit is contained in:
世界
2025-08-14 01:48:42 +08:00
parent 8752b631bd
commit fdc181106d
7 changed files with 100 additions and 78 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"net"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing-box/adapter"
@@ -29,7 +30,7 @@ type Client struct {
serverAddr string
serviceName string
dialOptions []grpc.DialOption
conn *grpc.ClientConn
conn atomic.Pointer[grpc.ClientConn]
connAccess sync.Mutex
}
@@ -74,13 +75,13 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
}
func (c *Client) connect() (*grpc.ClientConn, error) {
conn := c.conn
conn := c.conn.Load()
if conn != nil && conn.GetState() != connectivity.Shutdown {
return conn, nil
}
c.connAccess.Lock()
defer c.connAccess.Unlock()
conn = c.conn
conn = c.conn.Load()
if conn != nil && conn.GetState() != connectivity.Shutdown {
return conn, nil
}
@@ -89,7 +90,7 @@ func (c *Client) connect() (*grpc.ClientConn, error) {
if err != nil {
return nil, err
}
c.conn = conn
c.conn.Store(conn)
return conn, nil
}
@@ -109,11 +110,9 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
}
func (c *Client) Close() error {
c.connAccess.Lock()
defer c.connAccess.Unlock()
if c.conn != nil {
c.conn.Close()
c.conn = nil
conn := c.conn.Swap(nil)
if conn != nil {
conn.Close()
}
return nil
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-quic"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
@@ -29,7 +30,7 @@ type Client struct {
tlsConfig tls.Config
quicConfig *quic.Config
connAccess sync.Mutex
conn quic.Connection
conn atomic.TypedValue[quic.Connection]
rawConn net.Conn
}
@@ -50,13 +51,13 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
}
func (c *Client) offer() (quic.Connection, error) {
conn := c.conn
conn := c.conn.Load()
if conn != nil && !common.Done(conn.Context()) {
return conn, nil
}
c.connAccess.Lock()
defer c.connAccess.Unlock()
conn = c.conn
conn = c.conn.Load()
if conn != nil && !common.Done(conn.Context()) {
return conn, nil
}
@@ -78,7 +79,7 @@ func (c *Client) offerNew() (quic.Connection, error) {
packetConn.Close()
return nil, err
}
c.conn = quicConn
c.conn.Store(quicConn)
c.rawConn = udpConn
return quicConn, nil
}
@@ -98,13 +99,13 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
func (c *Client) Close() error {
c.connAccess.Lock()
defer c.connAccess.Unlock()
if c.conn != nil {
c.conn.CloseWithError(0, "")
conn := c.conn.Swap(nil)
if conn != nil {
conn.CloseWithError(0, "")
}
if c.rawConn != nil {
c.rawConn.Close()
}
c.conn = nil
c.rawConn = nil
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"net"
"os"
"sync"
"sync/atomic"
"time"
C "github.com/sagernet/sing-box/constant"
@@ -135,20 +136,22 @@ func (c *WebsocketConn) Upstream() any {
type EarlyWebsocketConn struct {
*Client
ctx context.Context
conn *WebsocketConn
conn atomic.Pointer[WebsocketConn]
access sync.Mutex
create chan struct{}
err error
}
func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
if c.conn == nil {
conn := c.conn.Load()
if conn == nil {
<-c.create
if c.err != nil {
return 0, c.err
}
conn = c.conn.Load()
}
return wrapWsError0(c.conn.Read(b))
return wrapWsError0(conn.Read(b))
}
func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
@@ -187,21 +190,23 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
return err
}
}
c.conn = conn
c.conn.Store(conn)
return nil
}
func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
if c.conn != nil {
return wrapWsError0(c.conn.Write(b))
conn := c.conn.Load()
if conn != nil {
return wrapWsError0(conn.Write(b))
}
c.access.Lock()
defer c.access.Unlock()
conn = c.conn.Load()
if c.err != nil {
return 0, c.err
}
if c.conn != nil {
return wrapWsError0(c.conn.Write(b))
if conn != nil {
return wrapWsError0(conn.Write(b))
}
err = c.writeRequest(b)
c.err = err
@@ -213,17 +218,19 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
}
func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
if c.conn != nil {
return wrapWsError(c.conn.WriteBuffer(buffer))
conn := c.conn.Load()
if conn != nil {
return wrapWsError(conn.WriteBuffer(buffer))
}
c.access.Lock()
defer c.access.Unlock()
if c.conn != nil {
return wrapWsError(c.conn.WriteBuffer(buffer))
}
if c.err != nil {
return c.err
}
conn = c.conn.Load()
if conn != nil {
return wrapWsError(conn.WriteBuffer(buffer))
}
err := c.writeRequest(buffer.Bytes())
c.err = err
close(c.create)
@@ -231,24 +238,27 @@ func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
}
func (c *EarlyWebsocketConn) Close() error {
if c.conn == nil {
conn := c.conn.Load()
if conn == nil {
return nil
}
return c.conn.Close()
return conn.Close()
}
func (c *EarlyWebsocketConn) LocalAddr() net.Addr {
if c.conn == nil {
conn := c.conn.Load()
if conn == nil {
return M.Socksaddr{}
}
return c.conn.LocalAddr()
return conn.LocalAddr()
}
func (c *EarlyWebsocketConn) RemoteAddr() net.Addr {
if c.conn == nil {
conn := c.conn.Load()
if conn == nil {
return M.Socksaddr{}
}
return c.conn.RemoteAddr()
return conn.RemoteAddr()
}
func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error {
@@ -268,11 +278,11 @@ func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool {
}
func (c *EarlyWebsocketConn) Upstream() any {
return common.PtrOrNil(c.conn)
return common.PtrOrNil(c.conn.Load())
}
func (c *EarlyWebsocketConn) LazyHeadroom() bool {
return c.conn == nil
return c.conn.Load() == nil
}
func wrapWsError(err error) error {