Reject invalid connection
This commit is contained in:
@@ -204,10 +204,13 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr
|
||||
common.Must1(header.Write(key[:]))
|
||||
common.Must1(header.Write(CRLF))
|
||||
common.Must(header.WriteByte(CommandTCP))
|
||||
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.Must1(header.Write(CRLF))
|
||||
common.Must1(header.Write(payload))
|
||||
_, err := conn.Write(header.Bytes())
|
||||
_, err = conn.Write(header.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "write request")
|
||||
}
|
||||
@@ -219,10 +222,13 @@ func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Soc
|
||||
common.Must1(header.Write(key[:]))
|
||||
common.Must1(header.Write(CRLF))
|
||||
common.Must(header.WriteByte(CommandTCP))
|
||||
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.Must1(header.Write(CRLF))
|
||||
|
||||
_, err := conn.Write(payload.Bytes())
|
||||
_, err = conn.Write(payload.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "write request")
|
||||
}
|
||||
@@ -244,7 +250,10 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Soc
|
||||
common.Must1(header.Write(key[:]))
|
||||
common.Must1(header.Write(CRLF))
|
||||
common.Must(header.WriteByte(CommandUDP))
|
||||
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.Must1(header.Write(CRLF))
|
||||
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
|
||||
common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
|
||||
@@ -257,7 +266,7 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Soc
|
||||
}
|
||||
}
|
||||
|
||||
_, err := conn.Write(payload.Bytes())
|
||||
_, err = conn.Write(payload.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "write payload")
|
||||
}
|
||||
@@ -289,10 +298,13 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) err
|
||||
defer buffer.Release()
|
||||
bufferLen := buffer.Len()
|
||||
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4))
|
||||
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
|
||||
common.Must1(header.Write(CRLF))
|
||||
_, err := conn.Write(buffer.Bytes())
|
||||
_, err = conn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "write packet")
|
||||
}
|
||||
|
||||
@@ -271,9 +271,13 @@ func (c *clientConn) Read(b []byte) (n int, err error) {
|
||||
func (c *clientConn) Write(b []byte) (n int, err error) {
|
||||
if !c.requestWritten {
|
||||
request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b))
|
||||
defer request.Release()
|
||||
request.WriteByte(Version)
|
||||
request.WriteByte(CommandConnect)
|
||||
addressSerializer.WriteAddrPort(request, c.destination)
|
||||
err = addressSerializer.WriteAddrPort(request, c.destination)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
request.Write(b)
|
||||
_, err = c.stream.Write(request.Bytes())
|
||||
if err != nil {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/cache"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
@@ -205,6 +206,9 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
|
||||
if buffer.Len() > 0xffff {
|
||||
return quic.ErrMessageTooLarge(0xffff)
|
||||
}
|
||||
if !destination.IsValid() {
|
||||
return E.New("invalid destination address")
|
||||
}
|
||||
packetId := c.packetId.Add(1)
|
||||
if packetId > math.MaxUint16 {
|
||||
c.packetId.Store(0)
|
||||
@@ -246,6 +250,10 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if len(p) > 0xffff {
|
||||
return 0, quic.ErrMessageTooLarge(0xffff)
|
||||
}
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
if !destination.IsValid() {
|
||||
return 0, E.New("invalid destination address")
|
||||
}
|
||||
packetId := c.packetId.Add(1)
|
||||
if packetId > math.MaxUint16 {
|
||||
c.packetId.Store(0)
|
||||
@@ -256,7 +264,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
sessionID: c.sessionID,
|
||||
packetID: uint16(packetId),
|
||||
fragmentTotal: 1,
|
||||
destination: M.SocksaddrFromNet(addr),
|
||||
destination: destination,
|
||||
data: buf.As(p),
|
||||
}
|
||||
if !c.udpStream && c.needFragment() && len(p) > c.udpMTU {
|
||||
|
||||
@@ -150,7 +150,10 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
|
||||
func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
if !c.requestWritten {
|
||||
EncodeRequest(c.request, buf.With(buffer.ExtendHeader(RequestLen(c.request))))
|
||||
err := EncodeRequest(c.request, buf.With(buffer.ExtendHeader(RequestLen(c.request))))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.requestWritten = true
|
||||
}
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
@@ -159,7 +162,11 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
func (c *Conn) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
if !c.requestWritten {
|
||||
buffer := buf.NewSize(RequestLen(c.request))
|
||||
EncodeRequest(c.request, buffer)
|
||||
err := EncodeRequest(c.request, buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
}
|
||||
c.requestWritten = true
|
||||
return c.writer.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...))
|
||||
}
|
||||
|
||||
@@ -156,14 +156,17 @@ func WriteRequest(writer io.Writer, request Request, payload []byte) error {
|
||||
)
|
||||
|
||||
if request.Command != vmess.CommandMux {
|
||||
common.Must(vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination))
|
||||
err := vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
common.Must1(buffer.Write(payload))
|
||||
return common.Error(writer.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func EncodeRequest(request Request, buffer *buf.Buffer) {
|
||||
func EncodeRequest(request Request, buffer *buf.Buffer) error {
|
||||
var requestLen int
|
||||
requestLen += 1 // version
|
||||
requestLen += 16 // uuid
|
||||
@@ -195,8 +198,12 @@ func EncodeRequest(request Request, buffer *buf.Buffer) {
|
||||
)
|
||||
|
||||
if request.Command != vmess.CommandMux {
|
||||
common.Must(vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination))
|
||||
err := vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RequestLen(request Request) int {
|
||||
@@ -251,10 +258,12 @@ func WritePacketRequest(writer io.Writer, request Request, payload []byte) error
|
||||
common.Must(common.Error(buffer.WriteString(request.Flow)))
|
||||
}
|
||||
|
||||
common.Must(
|
||||
buffer.WriteByte(vmess.CommandUDP),
|
||||
vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination),
|
||||
)
|
||||
common.Must(buffer.WriteByte(vmess.CommandUDP))
|
||||
|
||||
err := vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(payload) > 0 {
|
||||
common.Must(
|
||||
|
||||
Reference in New Issue
Block a user