Improve smux write

This commit is contained in:
世界
2022-08-12 16:49:25 +08:00
parent f8d13d79c7
commit 51bbf93ff2
7 changed files with 49 additions and 50 deletions

View File

@@ -146,7 +146,12 @@ func (c *Client) offerNew() (abstractSession, error) {
if err != nil {
return nil, err
}
session, err := c.protocol.newClient(&protocolConn{Conn: conn, protocol: c.protocol})
if vectorisedWriter, isVectorised := bufio.CreateVectorisedWriter(conn); isVectorised {
conn = &vectorisedProtocolConn{protocolConn{Conn: conn, protocol: c.protocol}, vectorisedWriter}
} else {
conn = &protocolConn{Conn: conn, protocol: c.protocol}
}
session, err := c.protocol.newClient(conn)
if err != nil {
return nil, err
}

View File

@@ -43,7 +43,7 @@ func ParseProtocol(name string) (Protocol, error) {
func (p Protocol) newServer(conn net.Conn) (abstractSession, error) {
switch p {
case ProtocolSMux:
session, err := smux.Server(wrapSMuxConn(conn), nil)
session, err := smux.Server(conn, nil)
if err != nil {
return nil, err
}
@@ -58,7 +58,7 @@ func (p Protocol) newServer(conn net.Conn) (abstractSession, error) {
func (p Protocol) newClient(conn net.Conn) (abstractSession, error) {
switch p {
case ProtocolSMux:
session, err := smux.Client(wrapSMuxConn(conn), nil)
session, err := smux.Client(conn, nil)
if err != nil {
return nil, err
}
@@ -201,31 +201,6 @@ func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
return &response, nil
}
type smuxTCPConn struct {
*net.TCPConn
}
func wrapSMuxConn(originConn net.Conn) net.Conn {
switch conn := originConn.(type) {
case *net.TCPConn:
return &smuxTCPConn{conn}
}
return originConn
}
func (w *smuxTCPConn) WriteBuffers(v [][]byte) (n int, err error) {
buffers := net.Buffers(v)
writeN, err := buffers.WriteTo(w.TCPConn)
if err != nil {
return
}
return int(writeN), nil
}
func (w *smuxTCPConn) Upstream() any {
return w.TCPConn
}
type wrapStream struct {
net.Conn
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/smux"
)
@@ -68,3 +69,23 @@ func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) {
func (c *protocolConn) Upstream() any {
return c.Conn
}
type vectorisedProtocolConn struct {
protocolConn
N.VectorisedWriter
}
func (c *vectorisedProtocolConn) WriteVectorised(buffers []*buf.Buffer) error {
if c.protocolWritten {
return c.VectorisedWriter.WriteVectorised(buffers)
}
c.protocolWritten = true
_buffer := buf.StackNewSize(2)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
EncodeRequest(buffer, Request{
Protocol: c.protocol,
})
return c.VectorisedWriter.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...))
}