refactor: WireGuard endpoint
This commit is contained in:
@@ -1,35 +1,260 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
"github.com/sagernet/sing/service"
|
||||
"github.com/sagernet/sing/service/pause"
|
||||
"github.com/sagernet/wireguard-go/conn"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
var _ conn.Endpoint = (*Endpoint)(nil)
|
||||
|
||||
type Endpoint netip.AddrPort
|
||||
|
||||
func (e Endpoint) ClearSrc() {
|
||||
type Endpoint struct {
|
||||
options EndpointOptions
|
||||
peers []peerConfig
|
||||
ipcConf string
|
||||
allowedAddress []netip.Prefix
|
||||
tunDevice Device
|
||||
device *device.Device
|
||||
pauseManager pause.Manager
|
||||
pauseCallback *list.Element[pause.Callback]
|
||||
}
|
||||
|
||||
func (e Endpoint) SrcToString() string {
|
||||
return ""
|
||||
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
|
||||
if options.PrivateKey == "" {
|
||||
return nil, E.New("missing private key")
|
||||
}
|
||||
privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode private key")
|
||||
}
|
||||
privateKey := hex.EncodeToString(privateKeyBytes)
|
||||
ipcConf := "private_key=" + privateKey
|
||||
if options.ListenPort != 0 {
|
||||
ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
|
||||
}
|
||||
var peers []peerConfig
|
||||
for peerIndex, rawPeer := range options.Peers {
|
||||
peer := peerConfig{
|
||||
allowedIPs: rawPeer.AllowedIPs,
|
||||
keepalive: rawPeer.PersistentKeepaliveInterval,
|
||||
}
|
||||
if rawPeer.Endpoint.Addr.IsValid() {
|
||||
peer.endpoint = rawPeer.Endpoint.AddrPort()
|
||||
} else if rawPeer.Endpoint.IsFqdn() {
|
||||
peer.destination = rawPeer.Endpoint
|
||||
}
|
||||
publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
|
||||
}
|
||||
peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
|
||||
if rawPeer.PreSharedKey != "" {
|
||||
preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
|
||||
}
|
||||
peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
|
||||
}
|
||||
if len(rawPeer.AllowedIPs) == 0 {
|
||||
return nil, E.New("missing allowed ips for peer ", peerIndex)
|
||||
}
|
||||
if len(rawPeer.Reserved) > 0 {
|
||||
if len(rawPeer.Reserved) != 3 {
|
||||
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
|
||||
}
|
||||
copy(peer.reserved[:], rawPeer.Reserved[:])
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
var allowedPrefixBuilder netipx.IPSetBuilder
|
||||
for _, peer := range options.Peers {
|
||||
for _, prefix := range peer.AllowedIPs {
|
||||
allowedPrefixBuilder.AddPrefix(prefix)
|
||||
}
|
||||
}
|
||||
allowedIPSet, err := allowedPrefixBuilder.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowedAddresses := allowedIPSet.Prefixes()
|
||||
if options.MTU == 0 {
|
||||
options.MTU = 1408
|
||||
}
|
||||
deviceOptions := DeviceOptions{
|
||||
Context: options.Context,
|
||||
Logger: options.Logger,
|
||||
System: options.System,
|
||||
Handler: options.Handler,
|
||||
UDPTimeout: options.UDPTimeout,
|
||||
CreateDialer: options.CreateDialer,
|
||||
Name: options.Name,
|
||||
MTU: options.MTU,
|
||||
GSO: options.GSO,
|
||||
Address: options.Address,
|
||||
AllowedAddress: allowedAddresses,
|
||||
}
|
||||
tunDevice, err := NewDevice(deviceOptions)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create WireGuard device")
|
||||
}
|
||||
return &Endpoint{
|
||||
options: options,
|
||||
peers: peers,
|
||||
ipcConf: ipcConf,
|
||||
allowedAddress: allowedAddresses,
|
||||
tunDevice: tunDevice,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e Endpoint) DstToString() string {
|
||||
return (netip.AddrPort)(e).String()
|
||||
func (e *Endpoint) Start(resolve bool) error {
|
||||
if common.Any(e.peers, func(peer peerConfig) bool {
|
||||
return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
|
||||
}) {
|
||||
if !resolve {
|
||||
return nil
|
||||
}
|
||||
for peerIndex, peer := range e.peers {
|
||||
if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
|
||||
continue
|
||||
}
|
||||
destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
|
||||
}
|
||||
e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
|
||||
}
|
||||
} else if resolve {
|
||||
return nil
|
||||
}
|
||||
var bind conn.Bind
|
||||
wgListener, isWgListener := e.options.Dialer.(conn.Listener)
|
||||
if isWgListener {
|
||||
bind = conn.NewStdNetBind(wgListener)
|
||||
} else {
|
||||
var (
|
||||
isConnect bool
|
||||
connectAddr netip.AddrPort
|
||||
reserved [3]uint8
|
||||
)
|
||||
if len(e.peers) == 1 {
|
||||
isConnect = true
|
||||
connectAddr = e.peers[0].endpoint
|
||||
reserved = e.peers[0].reserved
|
||||
}
|
||||
bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
|
||||
}
|
||||
if isWgListener || len(e.peers) > 1 {
|
||||
for _, peer := range e.peers {
|
||||
if peer.reserved != [3]uint8{} {
|
||||
bind.SetReservedForEndpoint(peer.endpoint, peer.reserved)
|
||||
}
|
||||
}
|
||||
}
|
||||
err := e.tunDevice.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger := &device.Logger{
|
||||
Verbosef: func(format string, args ...interface{}) {
|
||||
e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
|
||||
},
|
||||
Errorf: func(format string, args ...interface{}) {
|
||||
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
|
||||
},
|
||||
}
|
||||
wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
|
||||
e.tunDevice.SetDevice(wgDevice)
|
||||
ipcConf := e.ipcConf
|
||||
for _, peer := range e.peers {
|
||||
ipcConf += peer.GenerateIpcLines()
|
||||
}
|
||||
err = wgDevice.IpcSet(ipcConf)
|
||||
if err != nil {
|
||||
return E.Cause(err, "setup wireguard: \n", ipcConf)
|
||||
}
|
||||
e.device = wgDevice
|
||||
e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
|
||||
if e.pauseManager != nil {
|
||||
e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e Endpoint) DstToBytes() []byte {
|
||||
b, _ := (netip.AddrPort)(e).MarshalBinary()
|
||||
return b
|
||||
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
return e.tunDevice.DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
func (e Endpoint) DstIP() netip.Addr {
|
||||
return (netip.AddrPort)(e).Addr()
|
||||
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
return e.tunDevice.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
func (e Endpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
func (e *Endpoint) BindUpdate() error {
|
||||
return e.device.BindUpdate()
|
||||
}
|
||||
|
||||
func (e *Endpoint) Close() error {
|
||||
if e.device != nil {
|
||||
e.device.Close()
|
||||
}
|
||||
if e.pauseCallback != nil {
|
||||
e.pauseManager.UnregisterCallback(e.pauseCallback)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Endpoint) onPauseUpdated(event int) {
|
||||
switch event {
|
||||
case pause.EventDevicePaused:
|
||||
e.device.Down()
|
||||
case pause.EventDeviceWake:
|
||||
e.device.Up()
|
||||
}
|
||||
}
|
||||
|
||||
type peerConfig struct {
|
||||
destination M.Socksaddr
|
||||
endpoint netip.AddrPort
|
||||
publicKeyHex string
|
||||
preSharedKeyHex string
|
||||
allowedIPs []netip.Prefix
|
||||
keepalive uint16
|
||||
reserved [3]uint8
|
||||
}
|
||||
|
||||
func (c peerConfig) GenerateIpcLines() string {
|
||||
ipcLines := "\npublic_key=" + c.publicKeyHex
|
||||
if c.endpoint.IsValid() {
|
||||
ipcLines += "\nendpoint=" + c.endpoint.String()
|
||||
}
|
||||
if c.preSharedKeyHex != "" {
|
||||
ipcLines += "\npreshared_key=" + c.preSharedKeyHex
|
||||
}
|
||||
for _, allowedIP := range c.allowedIPs {
|
||||
ipcLines += "\nallowed_ip=" + allowedIP.String()
|
||||
}
|
||||
if c.keepalive > 0 {
|
||||
ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
|
||||
}
|
||||
return ipcLines
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user