bufio: Refactor copy

This commit is contained in:
世界
2026-02-03 02:51:05 +08:00
parent baa9f29f0d
commit e11dbf3a8e
3 changed files with 51 additions and 106 deletions

View File

@@ -2,7 +2,6 @@ package route
import (
"context"
"errors"
"io"
"net"
"net/netip"
@@ -102,8 +101,12 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
m.connections.Remove(element)
})
var done atomic.Bool
m.preConnectionCopy(ctx, conn, remoteConn, false, &done, onClose)
m.preConnectionCopy(ctx, remoteConn, conn, true, &done, onClose)
if m.kickWriteHandshake(ctx, conn, remoteConn, false, &done, onClose) {
return
}
if m.kickWriteHandshake(ctx, remoteConn, conn, true, &done, onClose) {
return
}
go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
}
@@ -226,75 +229,8 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
}
func (m *ConnectionManager) preConnectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
readHandshake := N.NeedHandshakeForRead(source)
writeHandshake := N.NeedHandshakeForWrite(destination)
if readHandshake || writeHandshake {
var err error
for {
err = m.connectionCopyEarlyWrite(source, destination, readHandshake, writeHandshake)
if err == nil && N.NeedHandshakeForRead(source) {
continue
} else if E.IsMulti(err, os.ErrInvalid, context.DeadlineExceeded, io.EOF) {
err = nil
}
break
}
if err != nil {
if done.Swap(true) {
onClose(err)
}
common.Close(source, destination)
if !direction {
m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
} else {
m.logger.ErrorContext(ctx, "connection download handshake: ", err)
}
return
}
}
}
func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
var (
sourceReader io.Reader = source
destinationWriter io.Writer = destination
)
var readCounters, writeCounters []N.CountFunc
for {
sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
dataLen := cachedBuffer.Len()
_, err := destination.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err != nil {
if done.Swap(true) {
onClose(err)
}
common.Close(source, destination)
if !direction {
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
} else {
m.logger.ErrorContext(ctx, "connection download payload: ", err)
}
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
}
continue
}
break
}
_, err := bufio.CopyWithCounters(destinationWriter, sourceReader, source, readCounters, writeCounters, bufio.DefaultIncreaseBufferAfter, bufio.DefaultBatchSize)
_, err := bufio.CopyWithIncreateBuffer(destination, source, bufio.DefaultIncreaseBufferAfter, bufio.DefaultBatchSize)
if err != nil {
common.Close(source, destination)
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
@@ -328,45 +264,54 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn,
}
}
func (m *ConnectionManager) connectionCopyEarlyWrite(source net.Conn, destination io.Writer, readHandshake bool, writeHandshake bool) error {
payload := buf.NewPacket()
defer payload.Release()
err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
if err != nil {
if err == os.ErrInvalid {
if writeHandshake {
return common.Error(destination.Write(nil))
}
}
return err
func (m *ConnectionManager) kickWriteHandshake(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) bool {
if !N.NeedHandshakeForWrite(destination) {
return false
}
var (
isTimeout bool
isEOF bool
cachedBuffer *buf.Buffer
wrotePayload bool
)
_, err = payload.ReadOnceFrom(source)
if err != nil {
if E.IsTimeout(err) {
isTimeout = true
} else if errors.Is(err, io.EOF) {
isEOF = true
} else {
return E.Cause(err, "read payload")
sourceReader, readCounters := N.UnwrapCountReader(source, nil)
destinationWriter, writeCounters := N.UnwrapCountWriter(destination, nil)
if cachedReader, ok := sourceReader.(N.CachedReader); ok {
cachedBuffer = cachedReader.ReadCached()
}
var err error
if cachedBuffer != nil {
wrotePayload = true
dataLen := cachedBuffer.Len()
_, err = destinationWriter.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err == nil {
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
}
} else {
_ = destination.SetWriteDeadline(time.Now().Add(C.ReadPayloadTimeout))
_, err = destinationWriter.Write(nil)
_ = destination.SetWriteDeadline(time.Time{})
}
_ = source.SetReadDeadline(time.Time{})
if !payload.IsEmpty() || writeHandshake {
_, err = destination.Write(payload.Bytes())
if err != nil {
return E.Cause(err, "write payload")
}
if err == nil {
return false
}
if isTimeout {
return context.DeadlineExceeded
} else if isEOF {
return io.EOF
if !wrotePayload && (E.IsMulti(err, os.ErrInvalid, context.DeadlineExceeded, io.EOF) || E.IsTimeout(err)) {
return false
}
return nil
if !done.Swap(true) {
onClose(err)
}
common.Close(source, destination)
if !direction {
m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
} else {
m.logger.ErrorContext(ctx, "connection download handshake: ", err)
}
return true
}
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {