Add proxy support for ICMP echo request

This commit is contained in:
世界
2025-02-17 22:00:57 +08:00
parent 44fafcef73
commit f84129ca79
26 changed files with 676 additions and 163 deletions

View File

@@ -5,6 +5,7 @@ import (
"net/netip"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
@@ -17,6 +18,8 @@ type Device interface {
N.Dialer
Start() error
SetDevice(device *device.Device)
Inet4Address() netip.Addr
Inet6Address() netip.Addr
}
type DeviceOptions struct {
@@ -35,9 +38,14 @@ type DeviceOptions struct {
func NewDevice(options DeviceOptions) (Device, error) {
if !options.System {
return newStackDevice(options)
} else if options.Handler == nil {
} else if !tun.WithGVisor {
return newSystemDevice(options)
} else {
return newSystemStackDevice(options)
}
}
type NatDevice interface {
Device
CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error)
}

View File

@@ -0,0 +1,103 @@
package wireguard
import (
"context"
"sync/atomic"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/ping"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/logger"
)
var _ Device = (*natDeviceWrapper)(nil)
type natDeviceWrapper struct {
Device
ctx context.Context
logger logger.ContextLogger
packetOutbound chan *buf.Buffer
rewriter *ping.Rewriter
buffer [][]byte
}
func NewNATDevice(ctx context.Context, logger logger.ContextLogger, upstream Device) NatDevice {
wrapper := &natDeviceWrapper{
Device: upstream,
ctx: ctx,
logger: logger,
packetOutbound: make(chan *buf.Buffer, 256),
rewriter: ping.NewRewriter(ctx, logger, upstream.Inet4Address(), upstream.Inet6Address()),
}
return wrapper
}
func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
select {
case packet := <-d.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
return 1, nil
default:
}
return d.Device.Read(bufs, sizes, offset)
}
func (d *natDeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
for _, buffer := range bufs {
handled, err := d.rewriter.WriteBack(buffer[offset:])
if handled {
if err != nil {
return 0, err
}
} else {
d.buffer = append(d.buffer, buffer)
}
}
if len(d.buffer) > 0 {
_, err := d.Device.Write(d.buffer, offset)
if err != nil {
return 0, err
}
d.buffer = d.buffer[:0]
}
return 0, nil
}
func (d *natDeviceWrapper) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
ctx := log.ContextWithNewID(d.ctx)
session := tun.DirectRouteSession{
Source: metadata.Source.Addr,
Destination: metadata.Destination.Addr,
}
d.rewriter.CreateSession(session, routeContext)
d.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
return &natDestination{device: d, session: session}, nil
}
var _ tun.DirectRouteDestination = (*natDestination)(nil)
type natDestination struct {
device *natDeviceWrapper
session tun.DirectRouteSession
closed atomic.Bool
}
func (d *natDestination) WritePacket(buffer *buf.Buffer) error {
d.device.rewriter.RewritePacket(buffer.Bytes())
d.device.packetOutbound <- buffer
return nil
}
func (d *natDestination) Close() error {
d.closed.Store(true)
d.device.rewriter.DeleteSession(d.session)
return nil
}
func (d *natDestination) IsClosed() bool {
return d.closed.Load()
}

View File

@@ -5,7 +5,9 @@ package wireguard
import (
"context"
"net"
"net/netip"
"os"
"time"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
@@ -14,9 +16,14 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/ping"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
@@ -24,30 +31,40 @@ import (
wgTun "github.com/sagernet/wireguard-go/tun"
)
var _ Device = (*stackDevice)(nil)
var _ NatDevice = (*stackDevice)(nil)
type stackDevice struct {
stack *stack.Stack
mtu uint32
events chan wgTun.Event
outbound chan *stack.PacketBuffer
done chan struct{}
dispatcher stack.NetworkDispatcher
addr4 tcpip.Address
addr6 tcpip.Address
ctx context.Context
logger log.ContextLogger
stack *stack.Stack
mtu uint32
events chan wgTun.Event
outbound chan *stack.PacketBuffer
packetOutbound chan *buf.Buffer
done chan struct{}
dispatcher stack.NetworkDispatcher
inet4Address netip.Addr
inet6Address netip.Addr
}
func newStackDevice(options DeviceOptions) (*stackDevice, error) {
tunDevice := &stackDevice{
mtu: options.MTU,
events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
done: make(chan struct{}),
ctx: options.Context,
logger: options.Logger,
mtu: options.MTU,
events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
packetOutbound: make(chan *buf.Buffer, 256),
done: make(chan struct{}),
}
ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
ipStack, err := tun.NewGVisorStackWithOptions((*wireEndpoint)(tunDevice), stack.NICOptions{}, true)
if err != nil {
return nil, err
}
var (
inet4Address netip.Addr
inet6Address netip.Addr
)
for _, prefix := range options.Address {
addr := tun.AddressFromAddr(prefix.Addr())
protoAddr := tcpip.ProtocolAddress{
@@ -57,10 +74,12 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) {
},
}
if prefix.Addr().Is4() {
tunDevice.addr4 = addr
inet4Address = prefix.Addr()
tunDevice.inet4Address = inet4Address
protoAddr.Protocol = ipv4.ProtocolNumber
} else {
tunDevice.addr6 = addr
inet6Address = prefix.Addr()
tunDevice.inet6Address = inet6Address
protoAddr.Protocol = ipv6.ProtocolNumber
}
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
@@ -72,6 +91,10 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) {
if options.Handler != nil {
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout)
icmpForwarder.SetLocalAddresses(inet4Address, inet6Address)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
}
return tunDevice, nil
}
@@ -87,11 +110,17 @@ func (w *stackDevice) DialContext(ctx context.Context, network string, destinati
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
if !w.inet4Address.IsValid() {
return nil, E.New("missing IPv4 local address")
}
networkProtocol = header.IPv4ProtocolNumber
bind.Addr = w.addr4
bind.Addr = tun.AddressFromAddr(w.inet4Address)
} else {
if !w.inet6Address.IsValid() {
return nil, E.New("missing IPv6 local address")
}
networkProtocol = header.IPv6ProtocolNumber
bind.Addr = w.addr6
bind.Addr = tun.AddressFromAddr(w.inet6Address)
}
switch N.NetworkName(network) {
case N.NetworkTCP:
@@ -118,10 +147,10 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
networkProtocol = header.IPv4ProtocolNumber
bind.Addr = w.addr4
bind.Addr = tun.AddressFromAddr(w.inet4Address)
} else {
networkProtocol = header.IPv6ProtocolNumber
bind.Addr = w.addr6
bind.Addr = tun.AddressFromAddr(w.inet4Address)
}
udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
if err != nil {
@@ -130,6 +159,14 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
return udpConn, nil
}
func (w *stackDevice) Inet4Address() netip.Addr {
return w.inet4Address
}
func (w *stackDevice) Inet6Address() netip.Addr {
return w.inet6Address
}
func (w *stackDevice) SetDevice(device *device.Device) {
}
@@ -144,20 +181,24 @@ func (w *stackDevice) File() *os.File {
func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
select {
case packetBuffer, ok := <-w.outbound:
case packet, ok := <-w.outbound:
if !ok {
return 0, os.ErrClosed
}
defer packetBuffer.DecRef()
p := bufs[0]
p = p[offset:]
n := 0
for _, slice := range packetBuffer.AsSlices() {
n += copy(p[n:], slice)
defer packet.DecRef()
var copyN int
/*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) {
copyN += copy(bufs[0][offset+copyN:], view.AsSlice())
})*/
for _, view := range packet.AsSlices() {
copyN += copy(bufs[0][offset+copyN:], view)
}
sizes[0] = n
count = 1
return
sizes[0] = copyN
return 1, nil
case packet := <-w.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
return 1, nil
case <-w.done:
return 0, os.ErrClosed
}
@@ -217,6 +258,23 @@ func (w *stackDevice) BatchSize() int {
return 1
}
func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
ctx := log.ContextWithNewID(w.ctx)
destination, err := ping.ConnectGVisor(
ctx, w.logger,
metadata.Source.Addr, metadata.Destination.Addr,
routeContext,
w.stack,
w.inet4Address, w.inet6Address,
timeout,
)
if err != nil {
return nil, err
}
w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
return destination, nil
}
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
type wireEndpoint stackDevice

View File

@@ -22,22 +22,42 @@ import (
var _ Device = (*systemDevice)(nil)
type systemDevice struct {
options DeviceOptions
dialer N.Dialer
device tun.Tun
batchDevice tun.LinuxTUN
events chan wgTun.Event
closeOnce sync.Once
options DeviceOptions
dialer N.Dialer
device tun.Tun
batchDevice tun.LinuxTUN
events chan wgTun.Event
closeOnce sync.Once
inet4Address netip.Addr
inet6Address netip.Addr
}
func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
if options.Name == "" {
options.Name = tun.CalculateInterfaceName("wg")
}
var inet4Address netip.Addr
var inet6Address netip.Addr
if len(options.Address) > 0 {
if prefix := common.Find(options.Address, func(it netip.Prefix) bool {
return it.Addr().Is4()
}); prefix.IsValid() {
inet4Address = prefix.Addr()
}
}
if len(options.Address) > 0 {
if prefix := common.Find(options.Address, func(it netip.Prefix) bool {
return it.Addr().Is6()
}); prefix.IsValid() {
inet6Address = prefix.Addr()
}
}
return &systemDevice{
options: options,
dialer: options.CreateDialer(options.Name),
events: make(chan wgTun.Event, 1),
options: options,
dialer: options.CreateDialer(options.Name),
events: make(chan wgTun.Event, 1),
inet4Address: inet4Address,
inet6Address: inet6Address,
}, nil
}
@@ -49,6 +69,14 @@ func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr
return w.dialer.ListenPacket(ctx, destination)
}
func (w *systemDevice) Inet4Address() netip.Addr {
return w.inet4Address
}
func (w *systemDevice) Inet6Address() netip.Addr {
return w.inet6Address
}
func (w *systemDevice) SetDevice(device *device.Device) {
}

View File

@@ -3,16 +3,26 @@
package wireguard
import (
"context"
"net/netip"
"time"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/ping"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/wireguard-go/device"
)
@@ -20,6 +30,8 @@ var _ Device = (*systemStackDevice)(nil)
type systemStackDevice struct {
*systemDevice
ctx context.Context
logger logger.ContextLogger
stack *stack.Stack
endpoint *deviceEndpoint
writeBufs [][]byte
@@ -34,13 +46,45 @@ func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
mtu: options.MTU,
done: make(chan struct{}),
}
ipStack, err := tun.NewGVisorStack(endpoint)
ipStack, err := tun.NewGVisorStackWithOptions(endpoint, stack.NICOptions{}, true)
if err != nil {
return nil, err
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
var (
inet4Address netip.Addr
inet6Address netip.Addr
)
for _, prefix := range options.Address {
addr := tun.AddressFromAddr(prefix.Addr())
protoAddr := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: prefix.Bits(),
},
}
if prefix.Addr().Is4() {
inet4Address = prefix.Addr()
protoAddr.Protocol = ipv4.ProtocolNumber
} else {
inet6Address = prefix.Addr()
protoAddr.Protocol = ipv6.ProtocolNumber
}
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
if gErr != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
}
}
if options.Handler != nil {
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout)
icmpForwarder.SetLocalAddresses(inet4Address, inet6Address)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
}
return &systemStackDevice{
ctx: options.Context,
logger: options.Logger,
systemDevice: system,
stack: ipStack,
endpoint: endpoint,
@@ -116,6 +160,23 @@ func (w *systemStackDevice) writeStack(packet []byte) bool {
return true
}
func (w *systemStackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
ctx := log.ContextWithNewID(w.ctx)
destination, err := ping.ConnectGVisor(
ctx, w.logger,
metadata.Source.Addr, metadata.Destination.Addr,
routeContext,
w.stack,
w.inet4Address, w.inet6Address,
timeout,
)
if err != nil {
return nil, err
}
w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
return destination, nil
}
type deviceEndpoint struct {
mtu uint32
done chan struct{}

View File

@@ -10,8 +10,11 @@ import (
"os"
"reflect"
"strings"
"time"
"unsafe"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
@@ -31,6 +34,7 @@ type Endpoint struct {
ipcConf string
allowedAddress []netip.Prefix
tunDevice Device
natDevice NatDevice
device *device.Device
allowedIPs *device.AllowedIPs
pause pause.Manager
@@ -114,12 +118,17 @@ func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
natDevice, isNatDevice := tunDevice.(NatDevice)
if !isNatDevice {
natDevice = NewNATDevice(options.Context, options.Logger, tunDevice)
}
return &Endpoint{
options: options,
peers: peers,
ipcConf: ipcConf,
allowedAddress: allowedAddresses,
tunDevice: tunDevice,
natDevice: natDevice,
}, nil
}
@@ -179,7 +188,13 @@ func (e *Endpoint) Start(resolve bool) error {
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
},
}
wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
var deviceInput Device
if e.natDevice != nil {
deviceInput = e.natDevice
} else {
deviceInput = e.tunDevice
}
wgDevice := device.NewDevice(e.options.Context, deviceInput, bind, logger, e.options.Workers)
e.tunDevice.SetDevice(wgDevice)
ipcConf := e.ipcConf
for _, peer := range e.peers {
@@ -229,6 +244,13 @@ func (e *Endpoint) Lookup(address netip.Addr) *device.Peer {
return e.allowedIPs.Lookup(address.AsSlice())
}
func (e *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
if e.natDevice == nil {
return nil, os.ErrInvalid
}
return e.natDevice.CreateDestination(metadata, routeContext, timeout)
}
func (e *Endpoint) onPauseUpdated(event int) {
switch event {
case pause.EventDevicePaused, pause.EventNetworkPause: