Fix TUIC UDP

This commit is contained in:
世界
2023-08-21 18:11:44 +08:00
parent 738c25d818
commit c0bbb3849d
10 changed files with 100 additions and 62 deletions

View File

@@ -41,7 +41,6 @@ type udpMessage struct {
fragmentTotal uint8
fragmentID uint8
destination M.Socksaddr
dataLength uint16
data *buf.Buffer
}
@@ -72,7 +71,7 @@ func (m *udpMessage) pack() *buf.Buffer {
}
func (m *udpMessage) headerSize() int {
return 2 + 10 + addressSerializer.AddrPortLen(m.destination)
return 10 + addressSerializer.AddrPortLen(m.destination)
}
func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
@@ -106,18 +105,19 @@ func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
}
type udpPacketConn struct {
ctx context.Context
cancel common.ContextCancelCauseFunc
sessionID uint16
quicConn quic.Connection
data chan *udpMessage
udpStream bool
udpMTU int
packetId atomic.Uint32
closeOnce sync.Once
isServer bool
defragger *udpDefragger
onDestroy func()
ctx context.Context
cancel common.ContextCancelCauseFunc
sessionID uint16
quicConn quic.Connection
data chan *udpMessage
udpStream bool
udpMTU int
udpMTUTime time.Time
packetId atomic.Uint32
closeOnce sync.Once
isServer bool
defragger *udpDefragger
onDestroy func()
}
func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn {
@@ -186,6 +186,15 @@ func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
}
}
func (c *udpPacketConn) needFragment() bool {
nowTime := time.Now()
if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second {
c.udpMTUTime = nowTime
return true
}
return false
}
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
select {
@@ -211,7 +220,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
}
defer message.releaseMessage()
var err error
if !c.udpStream && c.udpMTU > 0 && buffer.Len() > c.udpMTU {
if !c.udpStream && c.needFragment() && buffer.Len() > c.udpMTU {
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
} else {
err = c.writePacket(message)
@@ -224,6 +233,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
return err
}
c.udpMTU = int(tooLargeErr)
c.udpMTUTime = time.Now()
return c.writePackets(fragUDPMessage(message, c.udpMTU))
}
@@ -265,6 +275,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return
}
c.udpMTU = int(tooLargeErr)
c.udpMTUTime = time.Now()
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
if err == nil {
return len(p), nil
@@ -414,10 +425,14 @@ func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
}
newMessage := udpMessagePool.Get().(*udpMessage)
*newMessage = *item.messages[0]
if m.dataLength > 0 {
newMessage.data = buf.NewSize(int(m.dataLength))
var dataLength uint16
for _, message := range item.messages {
dataLength += uint16(message.data.Len())
}
if dataLength > 0 {
newMessage.data = buf.NewSize(int(dataLength))
for _, message := range item.messages {
newMessage.data.Write(message.data.Bytes())
common.Must1(newMessage.data.Write(message.data.Bytes()))
message.releaseMessage()
}
item.messages = nil
@@ -447,7 +462,8 @@ func readUDPMessage(message *udpMessage, reader io.Reader) error {
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.dataLength)
var dataLength uint16
err = binary.Read(reader, binary.BigEndian, &dataLength)
if err != nil {
return err
}
@@ -455,7 +471,7 @@ func readUDPMessage(message *udpMessage, reader io.Reader) error {
if err != nil {
return err
}
message.data = buf.NewSize(int(message.dataLength))
message.data = buf.NewSize(int(dataLength))
_, err = message.data.ReadFullFrom(reader, message.data.FreeLen())
if err != nil {
return err
@@ -481,7 +497,8 @@ func decodeUDPMessage(message *udpMessage, data []byte) error {
if err != nil {
return err
}
err = binary.Read(reader, binary.BigEndian, &message.dataLength)
var dataLength uint16
err = binary.Read(reader, binary.BigEndian, &dataLength)
if err != nil {
return err
}
@@ -489,7 +506,7 @@ func decodeUDPMessage(message *udpMessage, data []byte) error {
if err != nil {
return err
}
if reader.Len() != int(message.dataLength) {
if reader.Len() != int(dataLength) {
return io.ErrUnexpectedEOF
}
message.data = buf.As(data[len(data)-reader.Len():])