Fix DNS exchange
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/sagernet/fswatch"
|
"github.com/sagernet/fswatch"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
var _ adapter.CertificateStore = (*Store)(nil)
|
var _ adapter.CertificateStore = (*Store)(nil)
|
||||||
|
|
||||||
type Store struct {
|
type Store struct {
|
||||||
|
access sync.RWMutex
|
||||||
systemPool *x509.CertPool
|
systemPool *x509.CertPool
|
||||||
currentPool *x509.CertPool
|
currentPool *x509.CertPool
|
||||||
certificate string
|
certificate string
|
||||||
@@ -115,10 +117,14 @@ func (s *Store) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) Pool() *x509.CertPool {
|
func (s *Store) Pool() *x509.CertPool {
|
||||||
|
s.access.RLock()
|
||||||
|
defer s.access.RUnlock()
|
||||||
return s.currentPool
|
return s.currentPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) update() error {
|
func (s *Store) update() error {
|
||||||
|
s.access.Lock()
|
||||||
|
defer s.access.Unlock()
|
||||||
var currentPool *x509.CertPool
|
var currentPool *x509.CertPool
|
||||||
if s.systemPool == nil {
|
if s.systemPool == nil {
|
||||||
currentPool = x509.NewCertPool()
|
currentPool = x509.NewCertPool()
|
||||||
|
|||||||
@@ -69,11 +69,7 @@ func parseECHServerConfig(ctx context.Context, options option.InboundTLSOptions,
|
|||||||
} else {
|
} else {
|
||||||
return E.New("missing ECH keys")
|
return E.New("missing ECH keys")
|
||||||
}
|
}
|
||||||
block, rest := pem.Decode(echKey)
|
echKeys, err := parseECHKeys(echKey)
|
||||||
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
|
|
||||||
return E.New("invalid ECH keys pem")
|
|
||||||
}
|
|
||||||
echKeys, err := UnmarshalECHKeys(block.Bytes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "parse ECH keys")
|
return E.Cause(err, "parse ECH keys")
|
||||||
}
|
}
|
||||||
@@ -85,21 +81,16 @@ func parseECHServerConfig(ctx context.Context, options option.InboundTLSOptions,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func reloadECHKeys(echKeyPath string, tlsConfig *tls.Config) error {
|
func parseECHKeys(echKey []byte) ([]tls.EncryptedClientHelloKey, error) {
|
||||||
echKey, err := os.ReadFile(echKeyPath)
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "reload ECH keys from ", echKeyPath)
|
|
||||||
}
|
|
||||||
block, _ := pem.Decode(echKey)
|
block, _ := pem.Decode(echKey)
|
||||||
if block == nil || block.Type != "ECH KEYS" {
|
if block == nil || block.Type != "ECH KEYS" {
|
||||||
return E.New("invalid ECH keys pem")
|
return nil, E.New("invalid ECH keys pem")
|
||||||
}
|
}
|
||||||
echKeys, err := UnmarshalECHKeys(block.Bytes)
|
echKeys, err := UnmarshalECHKeys(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "parse ECH keys")
|
return nil, E.Cause(err, "parse ECH keys")
|
||||||
}
|
}
|
||||||
tlsConfig.EncryptedClientHelloKeys = echKeys
|
return echKeys, nil
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ECHClientConfig struct {
|
type ECHClientConfig struct {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/fswatch"
|
"github.com/sagernet/fswatch"
|
||||||
@@ -20,6 +21,7 @@ import (
|
|||||||
var errInsecureUnused = E.New("tls: insecure unused")
|
var errInsecureUnused = E.New("tls: insecure unused")
|
||||||
|
|
||||||
type STDServerConfig struct {
|
type STDServerConfig struct {
|
||||||
|
access sync.RWMutex
|
||||||
config *tls.Config
|
config *tls.Config
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
acmeService adapter.SimpleLifecycle
|
acmeService adapter.SimpleLifecycle
|
||||||
@@ -32,14 +34,22 @@ type STDServerConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) ServerName() string {
|
func (c *STDServerConfig) ServerName() string {
|
||||||
|
c.access.RLock()
|
||||||
|
defer c.access.RUnlock()
|
||||||
return c.config.ServerName
|
return c.config.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) SetServerName(serverName string) {
|
func (c *STDServerConfig) SetServerName(serverName string) {
|
||||||
c.config.ServerName = serverName
|
c.access.Lock()
|
||||||
|
defer c.access.Unlock()
|
||||||
|
config := c.config.Clone()
|
||||||
|
config.ServerName = serverName
|
||||||
|
c.config = config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) NextProtos() []string {
|
func (c *STDServerConfig) NextProtos() []string {
|
||||||
|
c.access.RLock()
|
||||||
|
defer c.access.RUnlock()
|
||||||
if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
|
if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
|
||||||
return c.config.NextProtos[1:]
|
return c.config.NextProtos[1:]
|
||||||
} else {
|
} else {
|
||||||
@@ -48,11 +58,15 @@ func (c *STDServerConfig) NextProtos() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) SetNextProtos(nextProto []string) {
|
func (c *STDServerConfig) SetNextProtos(nextProto []string) {
|
||||||
|
c.access.Lock()
|
||||||
|
defer c.access.Unlock()
|
||||||
|
config := c.config.Clone()
|
||||||
if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
|
if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
|
||||||
c.config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
|
config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
|
||||||
} else {
|
} else {
|
||||||
c.config.NextProtos = nextProto
|
config.NextProtos = nextProto
|
||||||
}
|
}
|
||||||
|
c.config = config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) Config() (*STDConfig, error) {
|
func (c *STDServerConfig) Config() (*STDConfig, error) {
|
||||||
@@ -77,9 +91,6 @@ func (c *STDServerConfig) Start() error {
|
|||||||
if c.acmeService != nil {
|
if c.acmeService != nil {
|
||||||
return c.acmeService.Start()
|
return c.acmeService.Start()
|
||||||
} else {
|
} else {
|
||||||
if c.certificatePath == "" && c.keyPath == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err := c.startWatcher()
|
err := c.startWatcher()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Warn("create fsnotify watcher: ", err)
|
c.logger.Warn("create fsnotify watcher: ", err)
|
||||||
@@ -99,6 +110,9 @@ func (c *STDServerConfig) startWatcher() error {
|
|||||||
if c.echKeyPath != "" {
|
if c.echKeyPath != "" {
|
||||||
watchPath = append(watchPath, c.echKeyPath)
|
watchPath = append(watchPath, c.echKeyPath)
|
||||||
}
|
}
|
||||||
|
if len(watchPath) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||||
Path: watchPath,
|
Path: watchPath,
|
||||||
Callback: func(path string) {
|
Callback: func(path string) {
|
||||||
@@ -138,13 +152,26 @@ func (c *STDServerConfig) certificateUpdated(path string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "reload key pair")
|
return E.Cause(err, "reload key pair")
|
||||||
}
|
}
|
||||||
c.config.Certificates = []tls.Certificate{keyPair}
|
c.access.Lock()
|
||||||
|
config := c.config.Clone()
|
||||||
|
config.Certificates = []tls.Certificate{keyPair}
|
||||||
|
c.config = config
|
||||||
|
c.access.Unlock()
|
||||||
c.logger.Info("reloaded TLS certificate")
|
c.logger.Info("reloaded TLS certificate")
|
||||||
} else if path == c.echKeyPath {
|
} else if path == c.echKeyPath {
|
||||||
err := reloadECHKeys(c.echKeyPath, c.config)
|
echKey, err := os.ReadFile(c.echKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "reload ECH keys from ", c.echKeyPath)
|
||||||
|
}
|
||||||
|
echKeys, err := parseECHKeys(echKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
c.access.Lock()
|
||||||
|
config := c.config.Clone()
|
||||||
|
config.EncryptedClientHelloKeys = echKeys
|
||||||
|
c.config = config
|
||||||
|
c.access.Unlock()
|
||||||
c.logger.Info("reloaded ECH keys")
|
c.logger.Info("reloaded ECH keys")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -262,7 +289,7 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &STDServerConfig{
|
serverConfig := &STDServerConfig{
|
||||||
config: tlsConfig,
|
config: tlsConfig,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
acmeService: acmeService,
|
acmeService: acmeService,
|
||||||
@@ -271,5 +298,11 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
|
|||||||
certificatePath: options.CertificatePath,
|
certificatePath: options.CertificatePath,
|
||||||
keyPath: options.KeyPath,
|
keyPath: options.KeyPath,
|
||||||
echKeyPath: echKeyPath,
|
echKeyPath: echKeyPath,
|
||||||
}, nil
|
}
|
||||||
|
serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
serverConfig.access.Lock()
|
||||||
|
defer serverConfig.access.Unlock()
|
||||||
|
return serverConfig.config, nil
|
||||||
|
}
|
||||||
|
return serverConfig, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -280,7 +280,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
logExchangedResponse(c.logger, ctx, response, timeToLive)
|
logExchangedResponse(c.logger, ctx, response, timeToLive)
|
||||||
return response, err
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
|
func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
RcodeSuccess RcodeError = mDNS.RcodeSuccess
|
||||||
RcodeFormatError RcodeError = mDNS.RcodeFormatError
|
RcodeFormatError RcodeError = mDNS.RcodeFormatError
|
||||||
RcodeNameError RcodeError = mDNS.RcodeNameError
|
RcodeNameError RcodeError = mDNS.RcodeNameError
|
||||||
RcodeRefused RcodeError = mDNS.RcodeRefused
|
RcodeRefused RcodeError = mDNS.RcodeRefused
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (t *Transport) exchangeParallel(ctx context.Context, servers []M.Socksaddr,
|
|||||||
if response.Rcode != mDNS.RcodeSuccess {
|
if response.Rcode != mDNS.RcodeSuccess {
|
||||||
err = dns.RcodeError(response.Rcode)
|
err = dns.RcodeError(response.Rcode)
|
||||||
} else if len(dns.MessageToAddresses(response)) == 0 {
|
} else if len(dns.MessageToAddresses(response)) == 0 {
|
||||||
err = E.New(fqdn, ": empty result")
|
err = dns.RcodeSuccess
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfi
|
|||||||
if response.Rcode != mDNS.RcodeSuccess {
|
if response.Rcode != mDNS.RcodeSuccess {
|
||||||
err = dns.RcodeError(response.Rcode)
|
err = dns.RcodeError(response.Rcode)
|
||||||
} else if len(dns.MessageToAddresses(response)) == 0 {
|
} else if len(dns.MessageToAddresses(response)) == 0 {
|
||||||
err = E.New(fqdn, ": empty result")
|
err = dns.RcodeSuccess
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
|
|||||||
Reference in New Issue
Block a user