Reject invalid connection

This commit is contained in:
世界
2023-09-07 09:09:03 +08:00
parent 4ea2d460f4
commit 6b943caf37
9 changed files with 93 additions and 60 deletions

View File

@@ -204,10 +204,13 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr
common.Must1(header.Write(key[:]))
common.Must1(header.Write(CRLF))
common.Must(header.WriteByte(CommandTCP))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
common.Must1(header.Write(CRLF))
common.Must1(header.Write(payload))
_, err := conn.Write(header.Bytes())
_, err = conn.Write(header.Bytes())
if err != nil {
return E.Cause(err, "write request")
}
@@ -219,10 +222,13 @@ func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Soc
common.Must1(header.Write(key[:]))
common.Must1(header.Write(CRLF))
common.Must(header.WriteByte(CommandTCP))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
common.Must1(header.Write(CRLF))
_, err := conn.Write(payload.Bytes())
_, err = conn.Write(payload.Bytes())
if err != nil {
return E.Cause(err, "write request")
}
@@ -244,7 +250,10 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Soc
common.Must1(header.Write(key[:]))
common.Must1(header.Write(CRLF))
common.Must(header.WriteByte(CommandUDP))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
common.Must1(header.Write(CRLF))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
@@ -257,7 +266,7 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Soc
}
}
_, err := conn.Write(payload.Bytes())
_, err = conn.Write(payload.Bytes())
if err != nil {
return E.Cause(err, "write payload")
}
@@ -289,10 +298,13 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) err
defer buffer.Release()
bufferLen := buffer.Len()
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
common.Must1(header.Write(CRLF))
_, err := conn.Write(buffer.Bytes())
_, err = conn.Write(buffer.Bytes())
if err != nil {
return E.Cause(err, "write packet")
}

View File

@@ -271,9 +271,13 @@ func (c *clientConn) Read(b []byte) (n int, err error) {
func (c *clientConn) Write(b []byte) (n int, err error) {
if !c.requestWritten {
request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b))
defer request.Release()
request.WriteByte(Version)
request.WriteByte(CommandConnect)
addressSerializer.WriteAddrPort(request, c.destination)
err = addressSerializer.WriteAddrPort(request, c.destination)
if err != nil {
return
}
request.Write(b)
_, err = c.stream.Write(request.Bytes())
if err != nil {

View File

@@ -17,6 +17,7 @@ import (
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
@@ -205,6 +206,9 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
if buffer.Len() > 0xffff {
return quic.ErrMessageTooLarge(0xffff)
}
if !destination.IsValid() {
return E.New("invalid destination address")
}
packetId := c.packetId.Add(1)
if packetId > math.MaxUint16 {
c.packetId.Store(0)
@@ -246,6 +250,10 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if len(p) > 0xffff {
return 0, quic.ErrMessageTooLarge(0xffff)
}
destination := M.SocksaddrFromNet(addr)
if !destination.IsValid() {
return 0, E.New("invalid destination address")
}
packetId := c.packetId.Add(1)
if packetId > math.MaxUint16 {
c.packetId.Store(0)
@@ -256,7 +264,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
sessionID: c.sessionID,
packetID: uint16(packetId),
fragmentTotal: 1,
destination: M.SocksaddrFromNet(addr),
destination: destination,
data: buf.As(p),
}
if !c.udpStream && c.needFragment() && len(p) > c.udpMTU {

View File

@@ -150,7 +150,10 @@ func (c *Conn) Write(b []byte) (n int, err error) {
func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
if !c.requestWritten {
EncodeRequest(c.request, buf.With(buffer.ExtendHeader(RequestLen(c.request))))
err := EncodeRequest(c.request, buf.With(buffer.ExtendHeader(RequestLen(c.request))))
if err != nil {
return err
}
c.requestWritten = true
}
return c.ExtendedConn.WriteBuffer(buffer)
@@ -159,7 +162,11 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
func (c *Conn) WriteVectorised(buffers []*buf.Buffer) error {
if !c.requestWritten {
buffer := buf.NewSize(RequestLen(c.request))
EncodeRequest(c.request, buffer)
err := EncodeRequest(c.request, buffer)
if err != nil {
buffer.Release()
return err
}
c.requestWritten = true
return c.writer.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...))
}

View File

@@ -156,14 +156,17 @@ func WriteRequest(writer io.Writer, request Request, payload []byte) error {
)
if request.Command != vmess.CommandMux {
common.Must(vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination))
err := vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination)
if err != nil {
return err
}
}
common.Must1(buffer.Write(payload))
return common.Error(writer.Write(buffer.Bytes()))
}
func EncodeRequest(request Request, buffer *buf.Buffer) {
func EncodeRequest(request Request, buffer *buf.Buffer) error {
var requestLen int
requestLen += 1 // version
requestLen += 16 // uuid
@@ -195,8 +198,12 @@ func EncodeRequest(request Request, buffer *buf.Buffer) {
)
if request.Command != vmess.CommandMux {
common.Must(vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination))
err := vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination)
if err != nil {
return err
}
}
return nil
}
func RequestLen(request Request) int {
@@ -251,10 +258,12 @@ func WritePacketRequest(writer io.Writer, request Request, payload []byte) error
common.Must(common.Error(buffer.WriteString(request.Flow)))
}
common.Must(
buffer.WriteByte(vmess.CommandUDP),
vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination),
)
common.Must(buffer.WriteByte(vmess.CommandUDP))
err := vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination)
if err != nil {
return err
}
if len(payload) > 0 {
common.Must(