Add multi-peer support for wireguard outbound
This commit is contained in:
@@ -3,9 +3,12 @@ package wireguard
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/wireguard-go/conn"
|
||||
@@ -14,24 +17,32 @@ import (
|
||||
var _ conn.Bind = (*ClientBind)(nil)
|
||||
|
||||
type ClientBind struct {
|
||||
ctx context.Context
|
||||
dialer N.Dialer
|
||||
peerAddr M.Socksaddr
|
||||
reserved [3]uint8
|
||||
connAccess sync.Mutex
|
||||
conn *wireConn
|
||||
done chan struct{}
|
||||
ctx context.Context
|
||||
dialer N.Dialer
|
||||
reservedForEndpoint map[M.Socksaddr][3]uint8
|
||||
connAccess sync.Mutex
|
||||
conn *wireConn
|
||||
done chan struct{}
|
||||
isConnect bool
|
||||
connectAddr M.Socksaddr
|
||||
reserved [3]uint8
|
||||
}
|
||||
|
||||
func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
|
||||
func NewClientBind(ctx context.Context, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
|
||||
return &ClientBind{
|
||||
ctx: ctx,
|
||||
dialer: dialer,
|
||||
peerAddr: peerAddr,
|
||||
reserved: reserved,
|
||||
ctx: ctx,
|
||||
dialer: dialer,
|
||||
reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
|
||||
isConnect: isConnect,
|
||||
connectAddr: connectAddr,
|
||||
reserved: reserved,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientBind) SetReservedForEndpoint(destination M.Socksaddr, reserved [3]byte) {
|
||||
c.reservedForEndpoint[destination] = reserved
|
||||
}
|
||||
|
||||
func (c *ClientBind) connect() (*wireConn, error) {
|
||||
serverConn := c.conn
|
||||
if serverConn != nil {
|
||||
@@ -53,13 +64,27 @@ func (c *ClientBind) connect() (*wireConn, error) {
|
||||
return serverConn, nil
|
||||
}
|
||||
}
|
||||
udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.peerAddr)
|
||||
if err != nil {
|
||||
return nil, &wireError{err}
|
||||
}
|
||||
c.conn = &wireConn{
|
||||
Conn: udpConn,
|
||||
done: make(chan struct{}),
|
||||
if c.isConnect {
|
||||
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
|
||||
if err != nil {
|
||||
return nil, &wireError{err}
|
||||
}
|
||||
c.conn = &wireConn{
|
||||
NetPacketConn: &bufio.UnbindPacketConn{
|
||||
ExtendedConn: bufio.NewExtendedConn(udpConn),
|
||||
Addr: c.connectAddr,
|
||||
},
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
} else {
|
||||
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
|
||||
if err != nil {
|
||||
return nil, &wireError{err}
|
||||
}
|
||||
c.conn = &wireConn{
|
||||
NetPacketConn: bufio.NewPacketConn(udpConn),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
return c.conn, nil
|
||||
}
|
||||
@@ -80,7 +105,8 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
err = &wireError{err}
|
||||
return
|
||||
}
|
||||
n, err = udpConn.Read(b)
|
||||
buffer := buf.With(b)
|
||||
destination, err := udpConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
select {
|
||||
@@ -90,12 +116,16 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
}
|
||||
return
|
||||
}
|
||||
n = buffer.Len()
|
||||
if buffer.Start() > 0 {
|
||||
copy(b, buffer.Bytes())
|
||||
}
|
||||
if n > 3 {
|
||||
b[1] = 0
|
||||
b[2] = 0
|
||||
b[3] = 0
|
||||
}
|
||||
ep = Endpoint(c.peerAddr)
|
||||
ep = Endpoint(destination)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -127,12 +157,17 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destination := M.Socksaddr(ep.(Endpoint))
|
||||
if len(b) > 3 {
|
||||
b[1] = c.reserved[0]
|
||||
b[2] = c.reserved[1]
|
||||
b[3] = c.reserved[2]
|
||||
reserved, loaded := c.reservedForEndpoint[destination]
|
||||
if !loaded {
|
||||
reserved = c.reserved
|
||||
}
|
||||
b[1] = reserved[0]
|
||||
b[2] = reserved[1]
|
||||
b[3] = reserved[2]
|
||||
}
|
||||
_, err = udpConn.Write(b)
|
||||
err = udpConn.WritePacket(buf.As(b), destination)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
}
|
||||
@@ -140,15 +175,11 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
|
||||
}
|
||||
|
||||
func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
return Endpoint(c.peerAddr), nil
|
||||
}
|
||||
|
||||
func (c *ClientBind) Endpoint() conn.Endpoint {
|
||||
return Endpoint(c.peerAddr)
|
||||
return Endpoint(M.ParseSocksaddr(s)), nil
|
||||
}
|
||||
|
||||
type wireConn struct {
|
||||
net.Conn
|
||||
N.NetPacketConn
|
||||
access sync.Mutex
|
||||
done chan struct{}
|
||||
}
|
||||
@@ -161,7 +192,7 @@ func (w *wireConn) Close() error {
|
||||
return net.ErrClosed
|
||||
default:
|
||||
}
|
||||
w.Conn.Close()
|
||||
w.NetPacketConn.Close()
|
||||
close(w.done)
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user