From 95b664a772fbc8869d4afbfdd30a1c90d08ab7f8 Mon Sep 17 00:00:00 2001 From: CN-JS-HuiBai Date: Thu, 16 Apr 2026 21:40:39 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DDNS=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- adapter/dns.go | 31 +- adapter/dns.go.834007539997830043 | 122 +++ adapter/dns_test.go | 137 ++++ dns/router.go | 8 +- dns/router.go.7853130561294400935 | 464 +++++++++++ dns/transport_dialer.go | 5 +- dns/transport_dialer.go.6124520021629164267 | 108 +++ route/route.go | 10 +- route/route.go.3955832358233176917 | 821 ++++++++++++++++++++ 9 files changed, 1700 insertions(+), 6 deletions(-) create mode 100644 adapter/dns.go.834007539997830043 create mode 100644 adapter/dns_test.go create mode 100644 dns/router.go.7853130561294400935 create mode 100644 dns/transport_dialer.go.6124520021629164267 create mode 100644 route/route.go.3955832358233176917 diff --git a/adapter/dns.go b/adapter/dns.go index 8f065e2e..d827aaba 100644 --- a/adapter/dns.go +++ b/adapter/dns.go @@ -39,13 +39,42 @@ type DNSQueryOptions struct { ClientSubnet netip.Prefix } +func LookupDNSTransport(manager DNSTransportManager, reference string) (DNSTransport, bool, bool) { + transport, loaded := manager.Transport(reference) + if loaded { + return transport, true, false + } + switch reference { + case C.DNSTypeLocal, C.DNSTypeFakeIP: + default: + return nil, false, false + } + var matchedTransport DNSTransport + for _, transport := range manager.Transports() { + if transport.Type() != reference { + continue + } + if matchedTransport != nil { + return nil, false, true + } + matchedTransport = transport + } + if matchedTransport != nil { + return matchedTransport, true, false + } + return nil, false, false +} + func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptions) (*DNSQueryOptions, error) { if options == nil { return &DNSQueryOptions{}, nil } transportManager := service.FromContext[DNSTransportManager](ctx) - transport, loaded := transportManager.Transport(options.Server) + transport, loaded, ambiguous := LookupDNSTransport(transportManager, options.Server) if !loaded { + if ambiguous { + return nil, E.New("domain resolver is ambiguous: " + options.Server) + } return nil, E.New("domain resolver not found: " + options.Server) } return &DNSQueryOptions{ diff --git a/adapter/dns.go.834007539997830043 b/adapter/dns.go.834007539997830043 new file mode 100644 index 00000000..d827aaba --- /dev/null +++ b/adapter/dns.go.834007539997830043 @@ -0,0 +1,122 @@ +package adapter + +import ( + "context" + "net/netip" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/service" + + "github.com/miekg/dns" +) + +type DNSRouter interface { + Lifecycle + Exchange(ctx context.Context, message *dns.Msg, options DNSQueryOptions) (*dns.Msg, error) + Lookup(ctx context.Context, domain string, options DNSQueryOptions) ([]netip.Addr, error) + ClearCache() + LookupReverseMapping(ip netip.Addr) (string, bool) + ResetNetwork() +} + +type DNSClient interface { + Start() + Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) + Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) + ClearCache() +} + +type DNSQueryOptions struct { + Transport DNSTransport + Strategy C.DomainStrategy + LookupStrategy C.DomainStrategy + DisableCache bool + RewriteTTL *uint32 + ClientSubnet netip.Prefix +} + +func LookupDNSTransport(manager DNSTransportManager, reference string) (DNSTransport, bool, bool) { + transport, loaded := manager.Transport(reference) + if loaded { + return transport, true, false + } + switch reference { + case C.DNSTypeLocal, C.DNSTypeFakeIP: + default: + return nil, false, false + } + var matchedTransport DNSTransport + for _, transport := range manager.Transports() { + if transport.Type() != reference { + continue + } + if matchedTransport != nil { + return nil, false, true + } + matchedTransport = transport + } + if matchedTransport != nil { + return matchedTransport, true, false + } + return nil, false, false +} + +func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptions) (*DNSQueryOptions, error) { + if options == nil { + return &DNSQueryOptions{}, nil + } + transportManager := service.FromContext[DNSTransportManager](ctx) + transport, loaded, ambiguous := LookupDNSTransport(transportManager, options.Server) + if !loaded { + if ambiguous { + return nil, E.New("domain resolver is ambiguous: " + options.Server) + } + return nil, E.New("domain resolver not found: " + options.Server) + } + return &DNSQueryOptions{ + Transport: transport, + Strategy: C.DomainStrategy(options.Strategy), + DisableCache: options.DisableCache, + RewriteTTL: options.RewriteTTL, + ClientSubnet: options.ClientSubnet.Build(netip.Prefix{}), + }, nil +} + +type RDRCStore interface { + LoadRDRC(transportName string, qName string, qType uint16) (rejected bool) + SaveRDRC(transportName string, qName string, qType uint16) error + SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger) +} + +type DNSTransport interface { + Lifecycle + Type() string + Tag() string + Dependencies() []string + Reset() + Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) +} + +type LegacyDNSTransport interface { + LegacyStrategy() C.DomainStrategy + LegacyClientSubnet() netip.Prefix +} + +type DNSTransportRegistry interface { + option.DNSTransportOptionsRegistry + CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (DNSTransport, error) +} + +type DNSTransportManager interface { + Lifecycle + Transports() []DNSTransport + Transport(tag string) (DNSTransport, bool) + Default() DNSTransport + FakeIP() FakeIPTransport + Remove(tag string) error + Create(ctx context.Context, logger log.ContextLogger, tag string, outboundType string, options any) error +} diff --git a/adapter/dns_test.go b/adapter/dns_test.go new file mode 100644 index 00000000..92e5fa4e --- /dev/null +++ b/adapter/dns_test.go @@ -0,0 +1,137 @@ +package adapter + +import ( + "context" + "testing" + + "github.com/sagernet/sing-box/log" + "github.com/stretchr/testify/require" + + mDNS "github.com/miekg/dns" +) + +type testDNSTransport struct { + transportType string + tag string +} + +func (t *testDNSTransport) Start(stage StartStage) error { + return nil +} + +func (t *testDNSTransport) Close() error { + return nil +} + +func (t *testDNSTransport) Type() string { + return t.transportType +} + +func (t *testDNSTransport) Tag() string { + return t.tag +} + +func (t *testDNSTransport) Dependencies() []string { + return nil +} + +func (t *testDNSTransport) Reset() { +} + +func (t *testDNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + return nil, nil +} + +type testDNSTransportManager struct { + transports []DNSTransport + transportByTag map[string]DNSTransport + fakeIPTransport FakeIPTransport +} + +func newTestDNSTransportManager(transports ...DNSTransport) *testDNSTransportManager { + manager := &testDNSTransportManager{ + transports: transports, + transportByTag: make(map[string]DNSTransport), + } + for _, transport := range transports { + manager.transportByTag[transport.Tag()] = transport + } + return manager +} + +func (m *testDNSTransportManager) Start(stage StartStage) error { + return nil +} + +func (m *testDNSTransportManager) Close() error { + return nil +} + +func (m *testDNSTransportManager) Transports() []DNSTransport { + return m.transports +} + +func (m *testDNSTransportManager) Transport(tag string) (DNSTransport, bool) { + transport, loaded := m.transportByTag[tag] + return transport, loaded +} + +func (m *testDNSTransportManager) Default() DNSTransport { + return nil +} + +func (m *testDNSTransportManager) FakeIP() FakeIPTransport { + return m.fakeIPTransport +} + +func (m *testDNSTransportManager) Remove(tag string) error { + return nil +} + +func (m *testDNSTransportManager) Create(ctx context.Context, logger log.ContextLogger, tag string, outboundType string, options any) error { + return nil +} + +func TestLookupDNSTransportLocalAlias(t *testing.T) { + t.Parallel() + + localTransport := &testDNSTransport{ + transportType: "local", + tag: "dns-local", + } + manager := newTestDNSTransportManager(localTransport) + + transport, loaded, ambiguous := LookupDNSTransport(manager, "local") + require.True(t, loaded) + require.False(t, ambiguous) + require.Same(t, localTransport, transport) +} + +func TestLookupDNSTransportExactTagPreferred(t *testing.T) { + t.Parallel() + + localTransport := &testDNSTransport{ + transportType: "local", + tag: "local", + } + manager := newTestDNSTransportManager(localTransport) + + transport, loaded, ambiguous := LookupDNSTransport(manager, "local") + require.True(t, loaded) + require.False(t, ambiguous) + require.Same(t, localTransport, transport) +} + +func TestLookupDNSTransportLocalAliasAmbiguous(t *testing.T) { + t.Parallel() + + manager := newTestDNSTransportManager( + &testDNSTransport{transportType: "local", tag: "dns-local-a"}, + &testDNSTransport{transportType: "local", tag: "dns-local-b"}, + ) + + transport, loaded, ambiguous := LookupDNSTransport(manager, "local") + require.Nil(t, transport) + require.False(t, loaded) + require.True(t, ambiguous) +} diff --git a/dns/router.go b/dns/router.go index 4f18959b..bce41abc 100644 --- a/dns/router.go +++ b/dns/router.go @@ -145,9 +145,13 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, } switch action := currentRule.Action().(type) { case *R.RuleActionDNSRoute: - transport, loaded := r.transport.Transport(action.Server) + transport, loaded, ambiguous := adapter.LookupDNSTransport(r.transport, action.Server) if !loaded { - r.logger.ErrorContext(ctx, "transport not found: ", action.Server) + if ambiguous { + r.logger.ErrorContext(ctx, "transport is ambiguous: ", action.Server) + } else { + r.logger.ErrorContext(ctx, "transport not found: ", action.Server) + } continue } isFakeIP := transport.Type() == C.DNSTypeFakeIP diff --git a/dns/router.go.7853130561294400935 b/dns/router.go.7853130561294400935 new file mode 100644 index 00000000..bce41abc --- /dev/null +++ b/dns/router.go.7853130561294400935 @@ -0,0 +1,464 @@ +package dns + +import ( + "context" + "errors" + "net/netip" + "strings" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/taskmonitor" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + R "github.com/sagernet/sing-box/route/rule" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" + "github.com/sagernet/sing/service" + + mDNS "github.com/miekg/dns" +) + +var _ adapter.DNSRouter = (*Router)(nil) + +type Router struct { + ctx context.Context + logger logger.ContextLogger + transport adapter.DNSTransportManager + outbound adapter.OutboundManager + client adapter.DNSClient + rules []adapter.DNSRule + defaultDomainStrategy C.DomainStrategy + dnsReverseMapping freelru.Cache[netip.Addr, string] + platformInterface adapter.PlatformInterface +} + +func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) *Router { + router := &Router{ + ctx: ctx, + logger: logFactory.NewLogger("dns"), + transport: service.FromContext[adapter.DNSTransportManager](ctx), + outbound: service.FromContext[adapter.OutboundManager](ctx), + rules: make([]adapter.DNSRule, 0, len(options.Rules)), + defaultDomainStrategy: C.DomainStrategy(options.Strategy), + } + router.client = NewClient(ClientOptions{ + DisableCache: options.DNSClientOptions.DisableCache, + DisableExpire: options.DNSClientOptions.DisableExpire, + IndependentCache: options.DNSClientOptions.IndependentCache, + CacheCapacity: options.DNSClientOptions.CacheCapacity, + ClientSubnet: options.DNSClientOptions.ClientSubnet.Build(netip.Prefix{}), + RDRC: func() adapter.RDRCStore { + cacheFile := service.FromContext[adapter.CacheFile](ctx) + if cacheFile == nil { + return nil + } + if !cacheFile.StoreRDRC() { + return nil + } + return cacheFile + }, + Logger: router.logger, + }) + if options.ReverseMapping { + router.dnsReverseMapping = common.Must1(freelru.NewSharded[netip.Addr, string](1024, maphash.NewHasher[netip.Addr]().Hash32)) + } + return router +} + +func (r *Router) Initialize(rules []option.DNSRule) error { + for i, ruleOptions := range rules { + dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true) + if err != nil { + return E.Cause(err, "parse dns rule[", i, "]") + } + r.rules = append(r.rules, dnsRule) + } + return nil +} + +func (r *Router) Start(stage adapter.StartStage) error { + monitor := taskmonitor.New(r.logger, C.StartTimeout) + switch stage { + case adapter.StartStateStart: + monitor.Start("initialize DNS client") + r.client.Start() + monitor.Finish() + + for i, rule := range r.rules { + monitor.Start("initialize DNS rule[", i, "]") + err := rule.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "initialize DNS rule[", i, "]") + } + } + } + return nil +} + +func (r *Router) Close() error { + monitor := taskmonitor.New(r.logger, C.StopTimeout) + var err error + for i, rule := range r.rules { + monitor.Start("close dns rule[", i, "]") + err = E.Append(err, rule.Close(), func(err error) error { + return E.Cause(err, "close dns rule[", i, "]") + }) + monitor.Finish() + } + return err +} + +func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) { + metadata := adapter.ContextFrom(ctx) + if metadata == nil { + panic("no context") + } + var currentRuleIndex int + if ruleIndex != -1 { + currentRuleIndex = ruleIndex + 1 + } + for ; currentRuleIndex < len(r.rules); currentRuleIndex++ { + currentRule := r.rules[currentRuleIndex] + if currentRule.WithAddressLimit() && !isAddressQuery { + continue + } + metadata.ResetRuleCache() + if currentRule.Match(metadata) { + displayRuleIndex := currentRuleIndex + if displayRuleIndex != -1 { + displayRuleIndex += displayRuleIndex + 1 + } + ruleDescription := currentRule.String() + if ruleDescription != "" { + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] ", currentRule, " => ", currentRule.Action()) + } else { + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action()) + } + switch action := currentRule.Action().(type) { + case *R.RuleActionDNSRoute: + transport, loaded, ambiguous := adapter.LookupDNSTransport(r.transport, action.Server) + if !loaded { + if ambiguous { + r.logger.ErrorContext(ctx, "transport is ambiguous: ", action.Server) + } else { + r.logger.ErrorContext(ctx, "transport not found: ", action.Server) + } + continue + } + isFakeIP := transport.Type() == C.DNSTypeFakeIP + if isFakeIP && !allowFakeIP { + continue + } + if action.Strategy != C.DomainStrategyAsIS { + options.Strategy = action.Strategy + } + if isFakeIP || action.DisableCache { + options.DisableCache = true + } + if action.RewriteTTL != nil { + options.RewriteTTL = action.RewriteTTL + } + if action.ClientSubnet.IsValid() { + options.ClientSubnet = action.ClientSubnet + } + if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy { + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = legacyTransport.LegacyStrategy() + } + if !options.ClientSubnet.IsValid() { + options.ClientSubnet = legacyTransport.LegacyClientSubnet() + } + } + return transport, currentRule, currentRuleIndex + case *R.RuleActionDNSRouteOptions: + if action.Strategy != C.DomainStrategyAsIS { + options.Strategy = action.Strategy + } + if action.DisableCache { + options.DisableCache = true + } + if action.RewriteTTL != nil { + options.RewriteTTL = action.RewriteTTL + } + if action.ClientSubnet.IsValid() { + options.ClientSubnet = action.ClientSubnet + } + case *R.RuleActionReject: + return nil, currentRule, currentRuleIndex + case *R.RuleActionPredefined: + return nil, currentRule, currentRuleIndex + } + } + } + transport := r.transport.Default() + if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy { + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = legacyTransport.LegacyStrategy() + } + if !options.ClientSubnet.IsValid() { + options.ClientSubnet = legacyTransport.LegacyClientSubnet() + } + } + return transport, nil, -1 +} + +func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) { + if len(message.Question) != 1 { + r.logger.WarnContext(ctx, "bad question size: ", len(message.Question)) + responseMessage := mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Response: true, + Rcode: mDNS.RcodeFormatError, + }, + Question: message.Question, + } + return &responseMessage, nil + } + r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String())) + var ( + response *mDNS.Msg + transport adapter.DNSTransport + err error + ) + var metadata *adapter.InboundContext + ctx, metadata = adapter.ExtendContext(ctx) + metadata.Destination = M.Socksaddr{} + metadata.QueryType = message.Question[0].Qtype + switch metadata.QueryType { + case mDNS.TypeA: + metadata.IPVersion = 4 + case mDNS.TypeAAAA: + metadata.IPVersion = 6 + } + metadata.Domain = FqdnToDomain(message.Question[0].Name) + if options.Transport != nil { + transport = options.Transport + if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy { + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = legacyTransport.LegacyStrategy() + } + if !options.ClientSubnet.IsValid() { + options.ClientSubnet = legacyTransport.LegacyClientSubnet() + } + } + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = r.defaultDomainStrategy + } + response, err = r.client.Exchange(ctx, transport, message, options, nil) + } else { + var ( + rule adapter.DNSRule + ruleIndex int + ) + ruleIndex = -1 + for { + dnsCtx := adapter.OverrideContext(ctx) + dnsOptions := options + transport, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message), &dnsOptions) + if rule != nil { + switch action := rule.Action().(type) { + case *R.RuleActionReject: + switch action.Method { + case C.RuleActionRejectMethodDefault: + return &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Rcode: mDNS.RcodeRefused, + Response: true, + }, + Question: []mDNS.Question{message.Question[0]}, + }, nil + case C.RuleActionRejectMethodDrop: + return nil, tun.ErrDrop + } + case *R.RuleActionPredefined: + return action.Response(message), nil + } + } + responseCheck := addressLimitResponseCheck(rule, metadata) + if dnsOptions.Strategy == C.DomainStrategyAsIS { + dnsOptions.Strategy = r.defaultDomainStrategy + } + response, err = r.client.Exchange(dnsCtx, transport, message, dnsOptions, responseCheck) + var rejected bool + if err != nil { + if errors.Is(err, ErrResponseRejectedCached) { + rejected = true + r.logger.DebugContext(ctx, E.Cause(err, "response rejected for ", FormatQuestion(message.Question[0].String())), " (cached)") + } else if errors.Is(err, ErrResponseRejected) { + rejected = true + r.logger.DebugContext(ctx, E.Cause(err, "response rejected for ", FormatQuestion(message.Question[0].String()))) + } else if len(message.Question) > 0 { + r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", FormatQuestion(message.Question[0].String()))) + } else { + r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for ")) + } + } + if responseCheck != nil && rejected { + continue + } + break + } + } + if err != nil { + return nil, err + } + if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 { + if transport == nil || transport.Type() != C.DNSTypeFakeIP { + for _, answer := range response.Answer { + switch record := answer.(type) { + case *mDNS.A: + r.dnsReverseMapping.AddWithLifetime(M.AddrFromIP(record.A), FqdnToDomain(record.Hdr.Name), time.Duration(record.Hdr.Ttl)*time.Second) + case *mDNS.AAAA: + r.dnsReverseMapping.AddWithLifetime(M.AddrFromIP(record.AAAA), FqdnToDomain(record.Hdr.Name), time.Duration(record.Hdr.Ttl)*time.Second) + } + } + } + } + return response, nil +} + +func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { + var ( + responseAddrs []netip.Addr + err error + ) + printResult := func() { + if err == nil && len(responseAddrs) == 0 { + err = E.New("empty result") + } + if err != nil { + if errors.Is(err, ErrResponseRejectedCached) { + r.logger.DebugContext(ctx, "response rejected for ", domain, " (cached)") + } else if errors.Is(err, ErrResponseRejected) { + r.logger.DebugContext(ctx, "response rejected for ", domain) + } else { + r.logger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain)) + } + } + if err != nil { + err = E.Cause(err, "lookup ", domain) + } + } + r.logger.DebugContext(ctx, "lookup domain ", domain) + ctx, metadata := adapter.ExtendContext(ctx) + metadata.Destination = M.Socksaddr{} + metadata.Domain = FqdnToDomain(domain) + if options.Transport != nil { + transport := options.Transport + if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy { + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = legacyTransport.LegacyStrategy() + } + if !options.ClientSubnet.IsValid() { + options.ClientSubnet = legacyTransport.LegacyClientSubnet() + } + } + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = r.defaultDomainStrategy + } + responseAddrs, err = r.client.Lookup(ctx, transport, domain, options, nil) + } else { + var ( + transport adapter.DNSTransport + rule adapter.DNSRule + ruleIndex int + ) + ruleIndex = -1 + for { + dnsCtx := adapter.OverrideContext(ctx) + dnsOptions := options + transport, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true, &dnsOptions) + if rule != nil { + switch action := rule.Action().(type) { + case *R.RuleActionReject: + return nil, &R.RejectedError{Cause: action.Error(ctx)} + case *R.RuleActionPredefined: + responseAddrs = nil + if action.Rcode != mDNS.RcodeSuccess { + err = RcodeError(action.Rcode) + } else { + err = nil + for _, answer := range action.Answer { + switch record := answer.(type) { + case *mDNS.A: + responseAddrs = append(responseAddrs, M.AddrFromIP(record.A)) + case *mDNS.AAAA: + responseAddrs = append(responseAddrs, M.AddrFromIP(record.AAAA)) + } + } + } + goto response + } + } + responseCheck := addressLimitResponseCheck(rule, metadata) + if dnsOptions.Strategy == C.DomainStrategyAsIS { + dnsOptions.Strategy = r.defaultDomainStrategy + } + responseAddrs, err = r.client.Lookup(dnsCtx, transport, domain, dnsOptions, responseCheck) + if responseCheck == nil || err == nil { + break + } + printResult() + } + } +response: + printResult() + if len(responseAddrs) > 0 { + r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " ")) + } + return responseAddrs, err +} + +func isAddressQuery(message *mDNS.Msg) bool { + for _, question := range message.Question { + if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA || question.Qtype == mDNS.TypeHTTPS { + return true + } + } + return false +} + +func addressLimitResponseCheck(rule adapter.DNSRule, metadata *adapter.InboundContext) func(responseAddrs []netip.Addr) bool { + if rule == nil || !rule.WithAddressLimit() { + return nil + } + responseMetadata := *metadata + return func(responseAddrs []netip.Addr) bool { + checkMetadata := responseMetadata + checkMetadata.DestinationAddresses = responseAddrs + return rule.MatchAddressLimit(&checkMetadata) + } +} + +func (r *Router) ClearCache() { + r.client.ClearCache() + if r.platformInterface != nil { + r.platformInterface.ClearDNSCache() + } +} + +func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) { + if r.dnsReverseMapping == nil { + return "", false + } + domain, loaded := r.dnsReverseMapping.Get(ip) + return domain, loaded +} + +func (r *Router) ResetNetwork() { + r.ClearCache() + for _, transport := range r.transport.Transports() { + transport.Reset() + } +} diff --git a/dns/transport_dialer.go b/dns/transport_dialer.go index b3ee8082..2f60006b 100644 --- a/dns/transport_dialer.go +++ b/dns/transport_dialer.go @@ -33,8 +33,11 @@ func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) transportDialer := dialer.NewDefaultOutbound(ctx) if options.LegacyAddressResolver != "" { transport := service.FromContext[adapter.DNSTransportManager](ctx) - resolverTransport, loaded := transport.Transport(options.LegacyAddressResolver) + resolverTransport, loaded, ambiguous := adapter.LookupDNSTransport(transport, options.LegacyAddressResolver) if !loaded { + if ambiguous { + return nil, E.New("address resolver is ambiguous: ", options.LegacyAddressResolver) + } return nil, E.New("address resolver not found: ", options.LegacyAddressResolver) } transportDialer = newTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.LegacyAddressStrategy), time.Duration(options.LegacyAddressFallbackDelay)) diff --git a/dns/transport_dialer.go.6124520021629164267 b/dns/transport_dialer.go.6124520021629164267 new file mode 100644 index 00000000..2f60006b --- /dev/null +++ b/dns/transport_dialer.go.6124520021629164267 @@ -0,0 +1,108 @@ +package dns + +import ( + "context" + "net" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" +) + +func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (N.Dialer, error) { + if options.LegacyDefaultDialer { + return dialer.NewDefaultOutbound(ctx), nil + } else { + return dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: options.DialerOptions, + DirectResolver: true, + LegacyDNSDialer: options.Legacy, + }) + } +} + +func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) { + if options.LegacyDefaultDialer { + transportDialer := dialer.NewDefaultOutbound(ctx) + if options.LegacyAddressResolver != "" { + transport := service.FromContext[adapter.DNSTransportManager](ctx) + resolverTransport, loaded, ambiguous := adapter.LookupDNSTransport(transport, options.LegacyAddressResolver) + if !loaded { + if ambiguous { + return nil, E.New("address resolver is ambiguous: ", options.LegacyAddressResolver) + } + return nil, E.New("address resolver not found: ", options.LegacyAddressResolver) + } + transportDialer = newTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.LegacyAddressStrategy), time.Duration(options.LegacyAddressFallbackDelay)) + } else if options.ServerIsDomain() { + return nil, E.New("missing address resolver for server: ", options.Server) + } + return transportDialer, nil + } else { + return dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: options.DialerOptions, + RemoteIsDomain: options.ServerIsDomain(), + DirectResolver: true, + LegacyDNSDialer: options.Legacy, + }) + } +} + +type legacyTransportDialer struct { + dialer N.Dialer + dnsRouter adapter.DNSRouter + transport adapter.DNSTransport + strategy C.DomainStrategy + fallbackDelay time.Duration +} + +func newTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *legacyTransportDialer { + return &legacyTransportDialer{ + dialer, + dnsRouter, + transport, + strategy, + fallbackDelay, + } +} + +func (d *legacyTransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if destination.IsIP() { + return d.dialer.DialContext(ctx, network, destination) + } + addresses, err := d.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{ + Transport: d.transport, + Strategy: d.strategy, + }) + if err != nil { + return nil, err + } + return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay) +} + +func (d *legacyTransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if destination.IsIP() { + return d.dialer.ListenPacket(ctx, destination) + } + addresses, err := d.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{ + Transport: d.transport, + Strategy: d.strategy, + }) + if err != nil { + return nil, err + } + conn, _, err := N.ListenSerial(ctx, d.dialer, destination, addresses) + return conn, err +} + +func (d *legacyTransportDialer) Upstream() any { + return d.dialer +} diff --git a/route/route.go b/route/route.go index 77b66ea4..97aa83e3 100644 --- a/route/route.go +++ b/route/route.go @@ -792,9 +792,15 @@ func (r *Router) actionResolve(ctx context.Context, metadata *adapter.InboundCon if metadata.Destination.IsDomain() { var transport adapter.DNSTransport if action.Server != "" { - var loaded bool - transport, loaded = r.dnsTransport.Transport(action.Server) + var ( + loaded bool + ambiguous bool + ) + transport, loaded, ambiguous = adapter.LookupDNSTransport(r.dnsTransport, action.Server) if !loaded { + if ambiguous { + return E.New("DNS server is ambiguous: ", action.Server) + } return E.New("DNS server not found: ", action.Server) } } diff --git a/route/route.go.3955832358233176917 b/route/route.go.3955832358233176917 new file mode 100644 index 00000000..97aa83e3 --- /dev/null +++ b/route/route.go.3955832358233176917 @@ -0,0 +1,821 @@ +package route + +import ( + "context" + "errors" + "net" + "net/netip" + "strings" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/sniff" + C "github.com/sagernet/sing-box/constant" + R "github.com/sagernet/sing-box/route/rule" + "github.com/sagernet/sing-mux" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/ping" + "github.com/sagernet/sing-vmess" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/bufio/deadline" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/uot" + + "golang.org/x/exp/slices" +) + +// Deprecated: use RouteConnectionEx instead. +func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + done := make(chan interface{}) + err := r.routeConnection(ctx, conn, metadata, N.OnceClose(func(it error) { + close(done) + })) + if err != nil { + return err + } + select { + case <-done: + case <-r.ctx.Done(): + } + return nil +} + +func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + err := r.routeConnection(ctx, conn, metadata, onClose) + if err != nil { + N.CloseOnHandshakeFailure(conn, onClose, err) + if E.IsClosedOrCanceled(err) || R.IsRejected(err) { + r.logger.DebugContext(ctx, "connection closed: ", err) + } else { + r.logger.ErrorContext(ctx, err) + } + } +} + +func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + //nolint:staticcheck + if metadata.InboundDetour != "" { + if metadata.LastInbound == metadata.InboundDetour { + return E.New("routing loop on detour: ", metadata.InboundDetour) + } + detour, loaded := r.inbound.Get(metadata.InboundDetour) + if !loaded { + return E.New("inbound detour not found: ", metadata.InboundDetour) + } + injectable, isInjectable := detour.(adapter.TCPInjectableInbound) + if !isInjectable { + return E.New("inbound detour is not TCP injectable: ", metadata.InboundDetour) + } + metadata.LastInbound = metadata.Inbound + metadata.Inbound = metadata.InboundDetour + metadata.InboundDetour = "" + injectable.NewConnectionEx(ctx, conn, metadata, onClose) + return nil + } + metadata.Network = N.NetworkTCP + switch metadata.Destination.Fqdn { + case mux.Destination.Fqdn: + return E.New("global multiplex is deprecated since sing-box v1.7.0, enable multiplex in Inbound fields instead.") + case vmess.MuxDestination.Fqdn: + return E.New("global multiplex (v2ray legacy) not supported since sing-box v1.7.0.") + case uot.MagicAddress: + return E.New("global UoT not supported since sing-box v1.7.0.") + case uot.LegacyMagicAddress: + return E.New("global UoT (legacy) not supported since sing-box v1.7.0.") + } + if deadline.NeedAdditionalReadDeadline(conn) { + conn = deadline.NewConn(conn) + } + selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, false, conn, nil) + if err != nil { + return err + } + var selectedOutbound adapter.Outbound + if selectedRule != nil { + switch action := selectedRule.Action().(type) { + case *R.RuleActionRoute: + var loaded bool + selectedOutbound, loaded = r.outbound.Outbound(action.Outbound) + if !loaded { + buf.ReleaseMulti(buffers) + return E.New("outbound not found: ", action.Outbound) + } + if !common.Contains(selectedOutbound.Network(), N.NetworkTCP) { + buf.ReleaseMulti(buffers) + return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag()) + } + case *R.RuleActionBypass: + if action.Outbound == "" { + break + } + var loaded bool + selectedOutbound, loaded = r.outbound.Outbound(action.Outbound) + if !loaded { + buf.ReleaseMulti(buffers) + return E.New("outbound not found: ", action.Outbound) + } + if !common.Contains(selectedOutbound.Network(), N.NetworkTCP) { + buf.ReleaseMulti(buffers) + return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag()) + } + case *R.RuleActionReject: + buf.ReleaseMulti(buffers) + if action.Method == C.RuleActionRejectMethodReply { + return E.New("reject method `reply` is not supported for TCP connections") + } + return action.Error(ctx) + case *R.RuleActionHijackDNS: + for _, buffer := range buffers { + conn = bufio.NewCachedConn(conn, buffer) + } + N.CloseOnHandshakeFailure(conn, onClose, r.hijackDNSStream(ctx, conn, metadata)) + return nil + } + } + if selectedRule == nil { + defaultOutbound := r.outbound.Default() + if !common.Contains(defaultOutbound.Network(), N.NetworkTCP) { + buf.ReleaseMulti(buffers) + return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag()) + } + selectedOutbound = defaultOutbound + } + + for _, buffer := range buffers { + conn = bufio.NewCachedConn(conn, buffer) + } + for _, tracker := range r.trackers { + conn = tracker.RoutedConnection(ctx, conn, metadata, selectedRule, selectedOutbound) + } + if outboundHandler, isHandler := selectedOutbound.(adapter.ConnectionHandlerEx); isHandler { + outboundHandler.NewConnectionEx(ctx, conn, metadata, onClose) + } else { + r.connection.NewConnection(ctx, selectedOutbound, conn, metadata, onClose) + } + return nil +} + +func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + done := make(chan interface{}) + err := r.routePacketConnection(ctx, conn, metadata, N.OnceClose(func(it error) { + close(done) + })) + if err != nil { + conn.Close() + if E.IsClosedOrCanceled(err) || R.IsRejected(err) { + r.logger.DebugContext(ctx, "connection closed: ", err) + } else { + r.logger.ErrorContext(ctx, err) + } + } + select { + case <-done: + case <-r.ctx.Done(): + } + return nil +} + +func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + err := r.routePacketConnection(ctx, conn, metadata, onClose) + if err != nil { + N.CloseOnHandshakeFailure(conn, onClose, err) + if E.IsClosedOrCanceled(err) || R.IsRejected(err) { + r.logger.DebugContext(ctx, "connection closed: ", err) + } else { + r.logger.ErrorContext(ctx, err) + } + } +} + +func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + //nolint:staticcheck + if metadata.InboundDetour != "" { + if metadata.LastInbound == metadata.InboundDetour { + return E.New("routing loop on detour: ", metadata.InboundDetour) + } + detour, loaded := r.inbound.Get(metadata.InboundDetour) + if !loaded { + return E.New("inbound detour not found: ", metadata.InboundDetour) + } + injectable, isInjectable := detour.(adapter.UDPInjectableInbound) + if !isInjectable { + return E.New("inbound detour is not UDP injectable: ", metadata.InboundDetour) + } + metadata.LastInbound = metadata.Inbound + metadata.Inbound = metadata.InboundDetour + metadata.InboundDetour = "" + injectable.NewPacketConnectionEx(ctx, conn, metadata, onClose) + return nil + } + // TODO: move to UoT + metadata.Network = N.NetworkUDP + + // Currently we don't have deadline usages for UDP connections + /*if deadline.NeedAdditionalReadDeadline(conn) { + conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn)) + }*/ + + selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, false, nil, conn) + if err != nil { + return err + } + var selectedOutbound adapter.Outbound + var selectReturn bool + if selectedRule != nil { + switch action := selectedRule.Action().(type) { + case *R.RuleActionRoute: + var loaded bool + selectedOutbound, loaded = r.outbound.Outbound(action.Outbound) + if !loaded { + N.ReleaseMultiPacketBuffer(packetBuffers) + return E.New("outbound not found: ", action.Outbound) + } + if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) { + N.ReleaseMultiPacketBuffer(packetBuffers) + return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag()) + } + case *R.RuleActionBypass: + if action.Outbound == "" { + break + } + var loaded bool + selectedOutbound, loaded = r.outbound.Outbound(action.Outbound) + if !loaded { + N.ReleaseMultiPacketBuffer(packetBuffers) + return E.New("outbound not found: ", action.Outbound) + } + if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) { + N.ReleaseMultiPacketBuffer(packetBuffers) + return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag()) + } + case *R.RuleActionReject: + N.ReleaseMultiPacketBuffer(packetBuffers) + if action.Method == C.RuleActionRejectMethodReply { + return E.New("reject method `reply` is not supported for UDP connections") + } + return action.Error(ctx) + case *R.RuleActionHijackDNS: + return r.hijackDNSPacket(ctx, conn, packetBuffers, metadata, onClose) + } + } + if selectedRule == nil || selectReturn { + defaultOutbound := r.outbound.Default() + if !common.Contains(defaultOutbound.Network(), N.NetworkUDP) { + N.ReleaseMultiPacketBuffer(packetBuffers) + return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag()) + } + selectedOutbound = defaultOutbound + } + for _, buffer := range packetBuffers { + conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination) + N.PutPacketBuffer(buffer) + } + for _, tracker := range r.trackers { + conn = tracker.RoutedPacketConnection(ctx, conn, metadata, selectedRule, selectedOutbound) + } + if metadata.FakeIP { + conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination) + } + if outboundHandler, isHandler := selectedOutbound.(adapter.PacketConnectionHandlerEx); isHandler { + outboundHandler.NewPacketConnectionEx(ctx, conn, metadata, onClose) + } else { + r.connection.NewPacketConnection(ctx, selectedOutbound, conn, metadata, onClose) + } + return nil +} + +func (r *Router) PreMatch(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, supportBypass, nil, nil) + if err != nil { + return nil, err + } + var directRouteOutbound adapter.DirectRouteOutbound + if selectedRule != nil { + switch action := selectedRule.Action().(type) { + case *R.RuleActionReject: + switch metadata.Network { + case N.NetworkTCP: + if action.Method == C.RuleActionRejectMethodReply { + return nil, E.New("reject method `reply` is not supported for TCP connections") + } + case N.NetworkUDP: + if action.Method == C.RuleActionRejectMethodReply { + return nil, E.New("reject method `reply` is not supported for UDP connections") + } + } + return nil, action.Error(context.Background()) + case *R.RuleActionBypass: + if supportBypass { + return nil, &R.BypassedError{Cause: tun.ErrBypass} + } + if routeContext == nil { + return nil, nil + } + outbound, loaded := r.outbound.Outbound(action.Outbound) + if !loaded { + return nil, E.New("outbound not found: ", action.Outbound) + } + if !common.Contains(outbound.Network(), metadata.Network) { + return nil, E.New(metadata.Network, " is not supported by outbound: ", action.Outbound) + } + directRouteOutbound = outbound.(adapter.DirectRouteOutbound) + case *R.RuleActionRoute: + if routeContext == nil { + return nil, nil + } + outbound, loaded := r.outbound.Outbound(action.Outbound) + if !loaded { + return nil, E.New("outbound not found: ", action.Outbound) + } + if !common.Contains(outbound.Network(), metadata.Network) { + return nil, E.New(metadata.Network, " is not supported by outbound: ", action.Outbound) + } + directRouteOutbound = outbound.(adapter.DirectRouteOutbound) + } + } + if directRouteOutbound == nil { + if selectedRule != nil || metadata.Network != N.NetworkICMP { + return nil, nil + } + defaultOutbound := r.outbound.Default() + if !common.Contains(defaultOutbound.Network(), metadata.Network) { + return nil, E.New(metadata.Network, " is not supported by default outbound: ", defaultOutbound.Tag()) + } + directRouteOutbound = defaultOutbound.(adapter.DirectRouteOutbound) + } + if metadata.Destination.IsDomain() { + if len(metadata.DestinationAddresses) == 0 { + var strategy C.DomainStrategy + if metadata.Source.IsIPv4() { + strategy = C.DomainStrategyIPv4Only + } else { + strategy = C.DomainStrategyIPv6Only + } + err = r.actionResolve(r.ctx, &metadata, &R.RuleActionResolve{ + Strategy: strategy, + }) + if err != nil { + return nil, err + } + } + var newDestination netip.Addr + if metadata.Source.IsIPv4() { + for _, address := range metadata.DestinationAddresses { + if address.Is4() { + newDestination = address + break + } + } + } else { + for _, address := range metadata.DestinationAddresses { + if address.Is6() { + newDestination = address + break + } + } + } + if !newDestination.IsValid() { + if metadata.Source.IsIPv4() { + return nil, E.New("no IPv4 address found for domain: ", metadata.Destination.Fqdn) + } else { + return nil, E.New("no IPv6 address found for domain: ", metadata.Destination.Fqdn) + } + } + metadata.Destination = M.Socksaddr{ + Addr: newDestination, + } + routeContext = ping.NewContextDestinationWriter(routeContext, metadata.OriginDestination.Addr) + var routeDestination tun.DirectRouteDestination + routeDestination, err = directRouteOutbound.NewDirectRouteConnection(metadata, routeContext, timeout) + if err != nil { + return nil, err + } + return ping.NewDestinationWriter(routeDestination, newDestination), nil + } + return directRouteOutbound.NewDirectRouteConnection(metadata, routeContext, timeout) +} + +func (r *Router) matchRule( + ctx context.Context, metadata *adapter.InboundContext, preMatch bool, supportBypass bool, + inputConn net.Conn, inputPacketConn N.PacketConn, +) ( + selectedRule adapter.Rule, selectedRuleIndex int, + buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error, +) { + if r.processSearcher != nil && metadata.ProcessInfo == nil { + var originDestination netip.AddrPort + if metadata.OriginDestination.IsValid() { + originDestination = metadata.OriginDestination.AddrPort() + } else if metadata.Destination.IsIP() { + originDestination = metadata.Destination.AddrPort() + } + processInfo, fErr := r.findProcessInfoCached(ctx, metadata.Network, metadata.Source.AddrPort(), originDestination) + if fErr != nil { + r.logger.InfoContext(ctx, "failed to search process: ", fErr) + } else { + if processInfo.ProcessPath != "" { + if processInfo.UserName != "" { + r.logger.InfoContext(ctx, "found process path: ", processInfo.ProcessPath, ", user: ", processInfo.UserName) + } else if processInfo.UserId != -1 { + r.logger.InfoContext(ctx, "found process path: ", processInfo.ProcessPath, ", user id: ", processInfo.UserId) + } else { + r.logger.InfoContext(ctx, "found process path: ", processInfo.ProcessPath) + } + } else if len(processInfo.AndroidPackageNames) > 0 { + r.logger.InfoContext(ctx, "found package name: ", strings.Join(processInfo.AndroidPackageNames, ", ")) + } else if processInfo.UserId != -1 { + if processInfo.UserName != "" { + r.logger.InfoContext(ctx, "found user: ", processInfo.UserName) + } else { + r.logger.InfoContext(ctx, "found user id: ", processInfo.UserId) + } + } + metadata.ProcessInfo = processInfo + } + } + if metadata.Destination.Addr.IsValid() && r.dnsTransport.FakeIP() != nil && r.dnsTransport.FakeIP().Store().Contains(metadata.Destination.Addr) { + domain, loaded := r.dnsTransport.FakeIP().Store().Lookup(metadata.Destination.Addr) + if !loaded { + fatalErr = E.New("missing fakeip record, try enable `experimental.cache_file`") + return + } + if domain != "" { + metadata.OriginDestination = metadata.Destination + metadata.Destination = M.Socksaddr{ + Fqdn: domain, + Port: metadata.Destination.Port, + } + metadata.FakeIP = true + r.logger.DebugContext(ctx, "found fakeip domain: ", domain) + } + } else if metadata.Domain == "" { + domain, loaded := r.dns.LookupReverseMapping(metadata.Destination.Addr) + if loaded { + metadata.Domain = domain + r.logger.DebugContext(ctx, "found reserve mapped domain: ", metadata.Domain) + } + } + if metadata.Destination.IsIPv4() { + metadata.IPVersion = 4 + } else if metadata.Destination.IsIPv6() { + metadata.IPVersion = 6 + } + +match: + for currentRuleIndex, currentRule := range r.rules { + metadata.ResetRuleCache() + if !currentRule.Match(metadata) { + continue + } + if !preMatch { + ruleDescription := currentRule.String() + if ruleDescription != "" { + r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action()) + } else { + r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action()) + } + } else { + switch currentRule.Action().Type() { + case C.RuleActionTypeReject: + ruleDescription := currentRule.String() + if ruleDescription != "" { + r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action()) + } else { + r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] => ", currentRule.Action()) + } + } + } + var routeOptions *R.RuleActionRouteOptions + switch action := currentRule.Action().(type) { + case *R.RuleActionRoute: + routeOptions = &action.RuleActionRouteOptions + case *R.RuleActionRouteOptions: + routeOptions = action + } + if routeOptions != nil { + // TODO: add nat + if (routeOptions.OverrideAddress.IsValid() || routeOptions.OverridePort > 0) && !metadata.RouteOriginalDestination.IsValid() { + metadata.RouteOriginalDestination = metadata.Destination + } + if routeOptions.OverrideAddress.IsValid() { + metadata.Destination = M.Socksaddr{ + Addr: routeOptions.OverrideAddress.Addr, + Port: metadata.Destination.Port, + Fqdn: routeOptions.OverrideAddress.Fqdn, + } + metadata.DestinationAddresses = nil + } + if routeOptions.OverridePort > 0 { + metadata.Destination = M.Socksaddr{ + Addr: metadata.Destination.Addr, + Port: routeOptions.OverridePort, + Fqdn: metadata.Destination.Fqdn, + } + } + if routeOptions.NetworkStrategy != nil { + metadata.NetworkStrategy = routeOptions.NetworkStrategy + } + if len(routeOptions.NetworkType) > 0 { + metadata.NetworkType = routeOptions.NetworkType + } + if len(routeOptions.FallbackNetworkType) > 0 { + metadata.FallbackNetworkType = routeOptions.FallbackNetworkType + } + if routeOptions.FallbackDelay != 0 { + metadata.FallbackDelay = routeOptions.FallbackDelay + } + if routeOptions.UDPDisableDomainUnmapping { + metadata.UDPDisableDomainUnmapping = true + } + if routeOptions.UDPConnect { + metadata.UDPConnect = true + } + if routeOptions.UDPTimeout > 0 { + metadata.UDPTimeout = routeOptions.UDPTimeout + } + if routeOptions.TLSFragment { + metadata.TLSFragment = true + metadata.TLSFragmentFallbackDelay = routeOptions.TLSFragmentFallbackDelay + } + if routeOptions.TLSRecordFragment { + metadata.TLSRecordFragment = true + } + } + switch action := currentRule.Action().(type) { + case *R.RuleActionSniff: + if !preMatch { + newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers, packetBuffers) + if newBuffer != nil { + buffers = append(buffers, newBuffer) + } else if len(newPacketBuffers) > 0 { + packetBuffers = append(packetBuffers, newPacketBuffers...) + } + if newErr != nil { + fatalErr = newErr + return + } + } else if metadata.Network != N.NetworkICMP { + selectedRule = currentRule + selectedRuleIndex = currentRuleIndex + break match + } + case *R.RuleActionResolve: + fatalErr = r.actionResolve(ctx, metadata, action) + if fatalErr != nil { + return + } + } + actionType := currentRule.Action().Type() + if actionType == C.RuleActionTypeRoute || + actionType == C.RuleActionTypeReject || + actionType == C.RuleActionTypeHijackDNS { + selectedRule = currentRule + selectedRuleIndex = currentRuleIndex + break match + } + if actionType == C.RuleActionTypeBypass { + bypassAction := currentRule.Action().(*R.RuleActionBypass) + if !supportBypass && bypassAction.Outbound == "" { + continue match + } + selectedRule = currentRule + selectedRuleIndex = currentRuleIndex + break match + } + } + return +} + +func (r *Router) actionSniff( + ctx context.Context, metadata *adapter.InboundContext, action *R.RuleActionSniff, + inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer, inputPacketBuffers []*N.PacketBuffer, +) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) { + if sniff.Skip(metadata) { + r.logger.DebugContext(ctx, "sniff skipped due to port considered as server-first") + return + } else if metadata.Protocol != "" { + r.logger.DebugContext(ctx, "duplicate sniff skipped") + return + } + if inputConn != nil { + if len(action.StreamSniffers) == 0 && len(action.PacketSniffers) > 0 { + return + } else if slices.Equal(metadata.SnifferNames, action.SnifferNames) && metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) { + r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError) + return + } + var streamSniffers []sniff.StreamSniffer + if len(action.StreamSniffers) > 0 { + streamSniffers = action.StreamSniffers + } else { + streamSniffers = []sniff.StreamSniffer{ + sniff.TLSClientHello, + sniff.HTTPHost, + sniff.StreamDomainNameQuery, + sniff.BitTorrent, + sniff.SSH, + sniff.RDP, + } + } + sniffBuffer := buf.NewPacket() + err := sniff.PeekStream( + ctx, + metadata, + inputConn, + inputBuffers, + sniffBuffer, + action.Timeout, + streamSniffers..., + ) + metadata.SnifferNames = action.SnifferNames + metadata.SniffError = err + if err == nil { + //goland:noinspection GoDeprecation + if action.OverrideDestination && M.IsDomainName(metadata.Domain) { + metadata.Destination = M.Socksaddr{ + Fqdn: metadata.Domain, + Port: metadata.Destination.Port, + } + } + if metadata.Domain != "" && metadata.Client != "" { + r.logger.DebugContext(ctx, "sniffed protocol: ", metadata.Protocol, ", domain: ", metadata.Domain, ", client: ", metadata.Client) + } else if metadata.Domain != "" { + r.logger.DebugContext(ctx, "sniffed protocol: ", metadata.Protocol, ", domain: ", metadata.Domain) + } else { + r.logger.DebugContext(ctx, "sniffed protocol: ", metadata.Protocol) + } + } + if !sniffBuffer.IsEmpty() { + buffer = sniffBuffer + } else { + sniffBuffer.Release() + } + } else if inputPacketConn != nil { + if len(action.PacketSniffers) == 0 && len(action.StreamSniffers) > 0 { + return + } else if slices.Equal(metadata.SnifferNames, action.SnifferNames) && metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) { + r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError) + return + } + quicMoreData := func() bool { + return slices.Equal(metadata.SnifferNames, action.SnifferNames) && errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) + } + var packetSniffers []sniff.PacketSniffer + if len(action.PacketSniffers) > 0 { + packetSniffers = action.PacketSniffers + } else { + packetSniffers = []sniff.PacketSniffer{ + sniff.DomainNameQuery, + sniff.QUICClientHello, + sniff.STUNMessage, + sniff.UTP, + sniff.UDPTracker, + sniff.DTLSRecord, + sniff.NTP, + } + } + var err error + for _, packetBuffer := range inputPacketBuffers { + if quicMoreData() { + err = sniff.PeekPacket( + ctx, + metadata, + packetBuffer.Buffer.Bytes(), + sniff.QUICClientHello, + ) + } else { + err = sniff.PeekPacket( + ctx, metadata, + packetBuffer.Buffer.Bytes(), + packetSniffers..., + ) + } + metadata.SnifferNames = action.SnifferNames + metadata.SniffError = err + if errors.Is(err, sniff.ErrNeedMoreData) { + // TODO: replace with generic message when there are more multi-packet protocols + r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello") + continue + } + goto finally + } + packetBuffers = inputPacketBuffers + for { + var ( + sniffBuffer = buf.NewPacket() + destination M.Socksaddr + done = make(chan struct{}) + ) + go func() { + sniffTimeout := C.ReadPayloadTimeout + if action.Timeout > 0 { + sniffTimeout = action.Timeout + } + inputPacketConn.SetReadDeadline(time.Now().Add(sniffTimeout)) + destination, err = inputPacketConn.ReadPacket(sniffBuffer) + inputPacketConn.SetReadDeadline(time.Time{}) + close(done) + }() + select { + case <-done: + case <-ctx.Done(): + inputPacketConn.Close() + fatalErr = ctx.Err() + return + } + if err != nil { + sniffBuffer.Release() + if !errors.Is(err, context.DeadlineExceeded) { + fatalErr = err + return + } + } else { + if quicMoreData() { + err = sniff.PeekPacket( + ctx, + metadata, + sniffBuffer.Bytes(), + sniff.QUICClientHello, + ) + } else { + err = sniff.PeekPacket( + ctx, metadata, + sniffBuffer.Bytes(), + packetSniffers..., + ) + } + packetBuffer := N.NewPacketBuffer() + *packetBuffer = N.PacketBuffer{ + Buffer: sniffBuffer, + Destination: destination, + } + packetBuffers = append(packetBuffers, packetBuffer) + metadata.SnifferNames = action.SnifferNames + metadata.SniffError = err + if errors.Is(err, sniff.ErrNeedMoreData) { + // TODO: replace with generic message when there are more multi-packet protocols + r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello") + continue + } + } + goto finally + } + finally: + if err == nil { + //goland:noinspection GoDeprecation + if action.OverrideDestination && M.IsDomainName(metadata.Domain) { + metadata.Destination = M.Socksaddr{ + Fqdn: metadata.Domain, + Port: metadata.Destination.Port, + } + } + if metadata.Domain != "" && metadata.Client != "" { + r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain, ", client: ", metadata.Client) + } else if metadata.Domain != "" { + r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain) + } else if metadata.Client != "" { + r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", client: ", metadata.Client) + } else { + r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol) + } + } + } + return +} + +func (r *Router) actionResolve(ctx context.Context, metadata *adapter.InboundContext, action *R.RuleActionResolve) error { + if metadata.Destination.IsDomain() { + var transport adapter.DNSTransport + if action.Server != "" { + var ( + loaded bool + ambiguous bool + ) + transport, loaded, ambiguous = adapter.LookupDNSTransport(r.dnsTransport, action.Server) + if !loaded { + if ambiguous { + return E.New("DNS server is ambiguous: ", action.Server) + } + return E.New("DNS server not found: ", action.Server) + } + } + addresses, err := r.dns.Lookup(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn, adapter.DNSQueryOptions{ + Transport: transport, + Strategy: action.Strategy, + DisableCache: action.DisableCache, + RewriteTTL: action.RewriteTTL, + ClientSubnet: action.ClientSubnet, + }) + if err != nil { + return err + } + metadata.DestinationAddresses = addresses + r.logger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") + } + return nil +}