diff --git a/common/certificate/store.go b/common/certificate/store.go index 34f20019..ee127278 100644 --- a/common/certificate/store.go +++ b/common/certificate/store.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "sync" "github.com/sagernet/fswatch" "github.com/sagernet/sing-box/adapter" @@ -21,6 +22,7 @@ import ( var _ adapter.CertificateStore = (*Store)(nil) type Store struct { + access sync.RWMutex systemPool *x509.CertPool currentPool *x509.CertPool certificate string @@ -115,10 +117,14 @@ func (s *Store) Close() error { } func (s *Store) Pool() *x509.CertPool { + s.access.RLock() + defer s.access.RUnlock() return s.currentPool } func (s *Store) update() error { + s.access.Lock() + defer s.access.Unlock() var currentPool *x509.CertPool if s.systemPool == nil { currentPool = x509.NewCertPool() diff --git a/common/tls/ech.go b/common/tls/ech.go index 830a8d08..d264de9b 100644 --- a/common/tls/ech.go +++ b/common/tls/ech.go @@ -69,11 +69,7 @@ func parseECHServerConfig(ctx context.Context, options option.InboundTLSOptions, } else { return E.New("missing ECH keys") } - block, rest := pem.Decode(echKey) - if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 { - return E.New("invalid ECH keys pem") - } - echKeys, err := UnmarshalECHKeys(block.Bytes) + echKeys, err := parseECHKeys(echKey) if err != nil { return E.Cause(err, "parse ECH keys") } @@ -85,21 +81,16 @@ func parseECHServerConfig(ctx context.Context, options option.InboundTLSOptions, return nil } -func reloadECHKeys(echKeyPath string, tlsConfig *tls.Config) error { - echKey, err := os.ReadFile(echKeyPath) - if err != nil { - return E.Cause(err, "reload ECH keys from ", echKeyPath) - } +func parseECHKeys(echKey []byte) ([]tls.EncryptedClientHelloKey, error) { block, _ := pem.Decode(echKey) if block == nil || block.Type != "ECH KEYS" { - return E.New("invalid ECH keys pem") + return nil, E.New("invalid ECH keys pem") } echKeys, err := UnmarshalECHKeys(block.Bytes) if err != nil { - return E.Cause(err, "parse ECH keys") + return nil, E.Cause(err, "parse ECH keys") } - tlsConfig.EncryptedClientHelloKeys = echKeys - return nil + return echKeys, nil } type ECHClientConfig struct { diff --git a/common/tls/std_server.go b/common/tls/std_server.go index 82ba71ed..541818fa 100644 --- a/common/tls/std_server.go +++ b/common/tls/std_server.go @@ -6,6 +6,7 @@ import ( "net" "os" "strings" + "sync" "time" "github.com/sagernet/fswatch" @@ -20,6 +21,7 @@ import ( var errInsecureUnused = E.New("tls: insecure unused") type STDServerConfig struct { + access sync.RWMutex config *tls.Config logger log.Logger acmeService adapter.SimpleLifecycle @@ -32,14 +34,22 @@ type STDServerConfig struct { } func (c *STDServerConfig) ServerName() string { + c.access.RLock() + defer c.access.RUnlock() return c.config.ServerName } func (c *STDServerConfig) SetServerName(serverName string) { - c.config.ServerName = serverName + c.access.Lock() + defer c.access.Unlock() + config := c.config.Clone() + config.ServerName = serverName + c.config = config } func (c *STDServerConfig) NextProtos() []string { + c.access.RLock() + defer c.access.RUnlock() if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol { return c.config.NextProtos[1:] } else { @@ -48,11 +58,15 @@ func (c *STDServerConfig) NextProtos() []string { } func (c *STDServerConfig) SetNextProtos(nextProto []string) { + c.access.Lock() + defer c.access.Unlock() + config := c.config.Clone() if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol { - c.config.NextProtos = append(c.config.NextProtos[:1], nextProto...) + config.NextProtos = append(c.config.NextProtos[:1], nextProto...) } else { - c.config.NextProtos = nextProto + config.NextProtos = nextProto } + c.config = config } func (c *STDServerConfig) Config() (*STDConfig, error) { @@ -77,9 +91,6 @@ func (c *STDServerConfig) Start() error { if c.acmeService != nil { return c.acmeService.Start() } else { - if c.certificatePath == "" && c.keyPath == "" { - return nil - } err := c.startWatcher() if err != nil { c.logger.Warn("create fsnotify watcher: ", err) @@ -99,6 +110,9 @@ func (c *STDServerConfig) startWatcher() error { if c.echKeyPath != "" { watchPath = append(watchPath, c.echKeyPath) } + if len(watchPath) == 0 { + return nil + } watcher, err := fswatch.NewWatcher(fswatch.Options{ Path: watchPath, Callback: func(path string) { @@ -138,13 +152,26 @@ func (c *STDServerConfig) certificateUpdated(path string) error { if err != nil { return E.Cause(err, "reload key pair") } - c.config.Certificates = []tls.Certificate{keyPair} + c.access.Lock() + config := c.config.Clone() + config.Certificates = []tls.Certificate{keyPair} + c.config = config + c.access.Unlock() c.logger.Info("reloaded TLS certificate") } else if path == c.echKeyPath { - err := reloadECHKeys(c.echKeyPath, c.config) + echKey, err := os.ReadFile(c.echKeyPath) + if err != nil { + return E.Cause(err, "reload ECH keys from ", c.echKeyPath) + } + echKeys, err := parseECHKeys(echKey) if err != nil { return err } + c.access.Lock() + config := c.config.Clone() + config.EncryptedClientHelloKeys = echKeys + c.config = config + c.access.Unlock() c.logger.Info("reloaded ECH keys") } return nil @@ -262,7 +289,7 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound return nil, err } } - return &STDServerConfig{ + serverConfig := &STDServerConfig{ config: tlsConfig, logger: logger, acmeService: acmeService, @@ -271,5 +298,11 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound certificatePath: options.CertificatePath, keyPath: options.KeyPath, echKeyPath: echKeyPath, - }, nil + } + serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + serverConfig.access.Lock() + defer serverConfig.access.Unlock() + return serverConfig.config, nil + } + return serverConfig, nil } diff --git a/dns/client.go b/dns/client.go index 0d0e712b..6063b1c6 100644 --- a/dns/client.go +++ b/dns/client.go @@ -280,7 +280,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m } } logExchangedResponse(c.logger, ctx, response, timeToLive) - return response, err + return response, nil } func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { diff --git a/dns/rcode.go b/dns/rcode.go index 08545474..59c564b6 100644 --- a/dns/rcode.go +++ b/dns/rcode.go @@ -5,6 +5,7 @@ import ( ) const ( + RcodeSuccess RcodeError = mDNS.RcodeSuccess RcodeFormatError RcodeError = mDNS.RcodeFormatError RcodeNameError RcodeError = mDNS.RcodeNameError RcodeRefused RcodeError = mDNS.RcodeRefused diff --git a/dns/transport/dhcp/dhcp_shared.go b/dns/transport/dhcp/dhcp_shared.go index 3d8d5512..086d1d68 100644 --- a/dns/transport/dhcp/dhcp_shared.go +++ b/dns/transport/dhcp/dhcp_shared.go @@ -43,7 +43,7 @@ func (t *Transport) exchangeParallel(ctx context.Context, servers []M.Socksaddr, if response.Rcode != mDNS.RcodeSuccess { err = dns.RcodeError(response.Rcode) } else if len(dns.MessageToAddresses(response)) == 0 { - err = E.New(fqdn, ": empty result") + err = dns.RcodeSuccess } } select { diff --git a/dns/transport/local/local.go b/dns/transport/local/local.go index 7c26b936..8187681a 100644 --- a/dns/transport/local/local.go +++ b/dns/transport/local/local.go @@ -95,7 +95,7 @@ func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfi if response.Rcode != mDNS.RcodeSuccess { err = dns.RcodeError(response.Rcode) } else if len(dns.MessageToAddresses(response)) == 0 { - err = E.New(fqdn, ": empty result") + err = dns.RcodeSuccess } } select {