From e11dbf3a8ee6b25c1219ee53610c534bcff451d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 3 Feb 2026 02:51:05 +0800 Subject: [PATCH] bufio: Refactor copy --- go.mod | 2 +- go.sum | 4 +- route/conn.go | 151 ++++++++++++++++---------------------------------- 3 files changed, 51 insertions(+), 106 deletions(-) diff --git a/go.mod b/go.mod index 5ebde511..d7697105 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/sagernet/gomobile v0.1.11 github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1 github.com/sagernet/quic-go v0.59.0-sing-box-mod.2 - github.com/sagernet/sing v0.8.0-beta.15 + github.com/sagernet/sing v0.8.0-beta.15.0.20260202162209-7c477e13f41e github.com/sagernet/sing-mux v0.3.4 github.com/sagernet/sing-quic v0.6.0-beta.11 github.com/sagernet/sing-shadowsocks v0.2.8 diff --git a/go.sum b/go.sum index bddfd235..375d6038 100644 --- a/go.sum +++ b/go.sum @@ -210,8 +210,8 @@ github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNen github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= github.com/sagernet/quic-go v0.59.0-sing-box-mod.2 h1:hJUL+HtxEOjxsa0CsucbBVqI/AMS4k52NwNU637zmdw= github.com/sagernet/quic-go v0.59.0-sing-box-mod.2/go.mod h1:OqILvS182CyOol5zNNo6bguvOGgXzV459+chpRaUC+4= -github.com/sagernet/sing v0.8.0-beta.15 h1:lP6XnzeQvVBfuTkByo5YnG4Oy/AVkDC2ZljghSfHzKQ= -github.com/sagernet/sing v0.8.0-beta.15/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.8.0-beta.15.0.20260202162209-7c477e13f41e h1:H8izpW6d9l8Ub5UFSV/Q2WCehss2KAlmnDiABa4BHp0= +github.com/sagernet/sing v0.8.0-beta.15.0.20260202162209-7c477e13f41e/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-mux v0.3.4 h1:ZQplKl8MNXutjzbMVtWvWG31fohhgOfCuUZR4dVQ8+s= github.com/sagernet/sing-mux v0.3.4/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk= github.com/sagernet/sing-quic v0.6.0-beta.11 h1:eUusxITKKRedhWC2ScUYFUvD96h/QfbKLaS3N6/7in4= diff --git a/route/conn.go b/route/conn.go index 16654704..3e9c831c 100644 --- a/route/conn.go +++ b/route/conn.go @@ -2,7 +2,6 @@ package route import ( "context" - "errors" "io" "net" "net/netip" @@ -102,8 +101,12 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co m.connections.Remove(element) }) var done atomic.Bool - m.preConnectionCopy(ctx, conn, remoteConn, false, &done, onClose) - m.preConnectionCopy(ctx, remoteConn, conn, true, &done, onClose) + if m.kickWriteHandshake(ctx, conn, remoteConn, false, &done, onClose) { + return + } + if m.kickWriteHandshake(ctx, remoteConn, conn, true, &done, onClose) { + return + } go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose) go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose) } @@ -226,75 +229,8 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose) } -func (m *ConnectionManager) preConnectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { - readHandshake := N.NeedHandshakeForRead(source) - writeHandshake := N.NeedHandshakeForWrite(destination) - if readHandshake || writeHandshake { - var err error - for { - err = m.connectionCopyEarlyWrite(source, destination, readHandshake, writeHandshake) - if err == nil && N.NeedHandshakeForRead(source) { - continue - } else if E.IsMulti(err, os.ErrInvalid, context.DeadlineExceeded, io.EOF) { - err = nil - } - break - } - if err != nil { - if done.Swap(true) { - onClose(err) - } - common.Close(source, destination) - if !direction { - m.logger.ErrorContext(ctx, "connection upload handshake: ", err) - } else { - m.logger.ErrorContext(ctx, "connection download handshake: ", err) - } - return - } - } -} - func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { - var ( - sourceReader io.Reader = source - destinationWriter io.Writer = destination - ) - var readCounters, writeCounters []N.CountFunc - for { - sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters) - destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters) - if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached { - cachedBuffer := cachedSrc.ReadCached() - if cachedBuffer != nil { - dataLen := cachedBuffer.Len() - _, err := destination.Write(cachedBuffer.Bytes()) - cachedBuffer.Release() - if err != nil { - if done.Swap(true) { - onClose(err) - } - common.Close(source, destination) - if !direction { - m.logger.ErrorContext(ctx, "connection upload payload: ", err) - } else { - m.logger.ErrorContext(ctx, "connection download payload: ", err) - } - return - } - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - } - continue - } - break - } - - _, err := bufio.CopyWithCounters(destinationWriter, sourceReader, source, readCounters, writeCounters, bufio.DefaultIncreaseBufferAfter, bufio.DefaultBatchSize) + _, err := bufio.CopyWithIncreateBuffer(destination, source, bufio.DefaultIncreaseBufferAfter, bufio.DefaultBatchSize) if err != nil { common.Close(source, destination) } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex { @@ -328,45 +264,54 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, } } -func (m *ConnectionManager) connectionCopyEarlyWrite(source net.Conn, destination io.Writer, readHandshake bool, writeHandshake bool) error { - payload := buf.NewPacket() - defer payload.Release() - err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout)) - if err != nil { - if err == os.ErrInvalid { - if writeHandshake { - return common.Error(destination.Write(nil)) - } - } - return err +func (m *ConnectionManager) kickWriteHandshake(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) bool { + if !N.NeedHandshakeForWrite(destination) { + return false } var ( - isTimeout bool - isEOF bool + cachedBuffer *buf.Buffer + wrotePayload bool ) - _, err = payload.ReadOnceFrom(source) - if err != nil { - if E.IsTimeout(err) { - isTimeout = true - } else if errors.Is(err, io.EOF) { - isEOF = true - } else { - return E.Cause(err, "read payload") + sourceReader, readCounters := N.UnwrapCountReader(source, nil) + destinationWriter, writeCounters := N.UnwrapCountWriter(destination, nil) + if cachedReader, ok := sourceReader.(N.CachedReader); ok { + cachedBuffer = cachedReader.ReadCached() + } + var err error + if cachedBuffer != nil { + wrotePayload = true + dataLen := cachedBuffer.Len() + _, err = destinationWriter.Write(cachedBuffer.Bytes()) + cachedBuffer.Release() + if err == nil { + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } } + } else { + _ = destination.SetWriteDeadline(time.Now().Add(C.ReadPayloadTimeout)) + _, err = destinationWriter.Write(nil) + _ = destination.SetWriteDeadline(time.Time{}) } - _ = source.SetReadDeadline(time.Time{}) - if !payload.IsEmpty() || writeHandshake { - _, err = destination.Write(payload.Bytes()) - if err != nil { - return E.Cause(err, "write payload") - } + if err == nil { + return false } - if isTimeout { - return context.DeadlineExceeded - } else if isEOF { - return io.EOF + if !wrotePayload && (E.IsMulti(err, os.ErrInvalid, context.DeadlineExceeded, io.EOF) || E.IsTimeout(err)) { + return false } - return nil + if !done.Swap(true) { + onClose(err) + } + common.Close(source, destination) + if !direction { + m.logger.ErrorContext(ctx, "connection upload handshake: ", err) + } else { + m.logger.ErrorContext(ctx, "connection download handshake: ", err) + } + return true } func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {