diff --git a/route/rule/match_state.go b/route/rule/match_state.go new file mode 100644 index 00000000..f537d5de --- /dev/null +++ b/route/rule/match_state.go @@ -0,0 +1,108 @@ +package rule + +import "github.com/sagernet/sing-box/adapter" + +type ruleMatchState uint8 + +const ( + ruleMatchSourceAddress ruleMatchState = 1 << iota + ruleMatchSourcePort + ruleMatchDestinationAddress + ruleMatchDestinationPort +) + +type ruleMatchStateSet uint16 + +func singleRuleMatchState(state ruleMatchState) ruleMatchStateSet { + return 1 << state +} + +func emptyRuleMatchState() ruleMatchStateSet { + return singleRuleMatchState(0) +} + +func (s ruleMatchStateSet) isEmpty() bool { + return s == 0 +} + +func (s ruleMatchStateSet) contains(state ruleMatchState) bool { + return s&(1< 0 +} + +func (r *abstractDefaultRule) destinationIPCIDRMatchesDestination(metadata *adapter.InboundContext) bool { + return !metadata.IgnoreDestinationIPCIDRMatch && !metadata.IPCIDRMatchSource && len(r.destinationIPCIDRItems) > 0 +} + +func (r *abstractDefaultRule) requiresSourceAddressMatch(metadata *adapter.InboundContext) bool { + return len(r.sourceAddressItems) > 0 || r.destinationIPCIDRMatchesSource(metadata) +} + +func (r *abstractDefaultRule) requiresDestinationAddressMatch(metadata *adapter.InboundContext) bool { + return len(r.destinationAddressItems) > 0 || r.destinationIPCIDRMatchesDestination(metadata) +} + +func (r *abstractDefaultRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { if len(r.allItems) == 0 { - return true + return emptyRuleMatchState() } - - if len(r.sourceAddressItems) > 0 && !metadata.SourceAddressMatch { + var baseState ruleMatchState + if len(r.sourceAddressItems) > 0 { metadata.DidMatch = true - for _, item := range r.sourceAddressItems { - if item.Match(metadata) { - metadata.SourceAddressMatch = true - break - } + if matchAnyItem(r.sourceAddressItems, metadata) { + baseState |= ruleMatchSourceAddress } } - - if len(r.sourcePortItems) > 0 && !metadata.SourcePortMatch { + if r.destinationIPCIDRMatchesSource(metadata) && !baseState.has(ruleMatchSourceAddress) { metadata.DidMatch = true - for _, item := range r.sourcePortItems { - if item.Match(metadata) { - metadata.SourcePortMatch = true - break - } + if matchAnyItem(r.destinationIPCIDRItems, metadata) { + baseState |= ruleMatchSourceAddress + } + } else if r.destinationIPCIDRMatchesSource(metadata) { + metadata.DidMatch = true + } + if len(r.sourcePortItems) > 0 { + metadata.DidMatch = true + if matchAnyItem(r.sourcePortItems, metadata) { + baseState |= ruleMatchSourcePort } } - - if len(r.destinationAddressItems) > 0 && !metadata.DestinationAddressMatch { + if len(r.destinationAddressItems) > 0 { metadata.DidMatch = true - for _, item := range r.destinationAddressItems { - if item.Match(metadata) { - metadata.DestinationAddressMatch = true - break - } + if matchAnyItem(r.destinationAddressItems, metadata) { + baseState |= ruleMatchDestinationAddress } } - - if !metadata.IgnoreDestinationIPCIDRMatch && len(r.destinationIPCIDRItems) > 0 && !metadata.DestinationAddressMatch { + if r.destinationIPCIDRMatchesDestination(metadata) && !baseState.has(ruleMatchDestinationAddress) { metadata.DidMatch = true - for _, item := range r.destinationIPCIDRItems { - if item.Match(metadata) { - metadata.DestinationAddressMatch = true - break - } + if matchAnyItem(r.destinationIPCIDRItems, metadata) { + baseState |= ruleMatchDestinationAddress + } + } else if r.destinationIPCIDRMatchesDestination(metadata) { + metadata.DidMatch = true + } + if len(r.destinationPortItems) > 0 { + metadata.DidMatch = true + if matchAnyItem(r.destinationPortItems, metadata) { + baseState |= ruleMatchDestinationPort } } - - if len(r.destinationPortItems) > 0 && !metadata.DestinationPortMatch { - metadata.DidMatch = true - for _, item := range r.destinationPortItems { - if item.Match(metadata) { - metadata.DestinationPortMatch = true - break - } - } - } - for _, item := range r.items { metadata.DidMatch = true if !item.Match(metadata) { - return r.invert + return r.invertedFailure() } } - - if len(r.sourceAddressItems) > 0 && !metadata.SourceAddressMatch { - return r.invert + stateSet := singleRuleMatchState(baseState) + if r.ruleSetItem != nil { + metadata.DidMatch = true + ruleSetStates := matchRuleItemStates(r.ruleSetItem, metadata) + if ruleSetStates.isEmpty() { + return r.invertedFailure() + } + stateSet = ruleSetStates.withBase(baseState) } - - if len(r.sourcePortItems) > 0 && !metadata.SourcePortMatch { - return r.invert - } - - if ((!metadata.IgnoreDestinationIPCIDRMatch && len(r.destinationIPCIDRItems) > 0) || len(r.destinationAddressItems) > 0) && !metadata.DestinationAddressMatch { - return r.invert - } - - if len(r.destinationPortItems) > 0 && !metadata.DestinationPortMatch { - return r.invert - } - - if !metadata.DidMatch { + stateSet = stateSet.filter(func(state ruleMatchState) bool { + if r.requiresSourceAddressMatch(metadata) && !state.has(ruleMatchSourceAddress) { + return false + } + if len(r.sourcePortItems) > 0 && !state.has(ruleMatchSourcePort) { + return false + } + if r.requiresDestinationAddressMatch(metadata) && !state.has(ruleMatchDestinationAddress) { + return false + } + if len(r.destinationPortItems) > 0 && !state.has(ruleMatchDestinationPort) { + return false + } return true + }) + if stateSet.isEmpty() { + return r.invertedFailure() } + if r.invert { + // DNS pre-lookup defers destination address-limit checks until the response phase. + if metadata.IgnoreDestinationIPCIDRMatch && stateSet == emptyRuleMatchState() && !metadata.DidMatch && len(r.destinationIPCIDRItems) > 0 { + return emptyRuleMatchState() + } + return 0 + } + return stateSet +} - return !r.invert +func (r *abstractDefaultRule) invertedFailure() ruleMatchStateSet { + if r.invert { + return emptyRuleMatchState() + } + return 0 } func (r *abstractDefaultRule) Action() adapter.RuleAction { @@ -191,17 +221,42 @@ func (r *abstractLogicalRule) Close() error { } func (r *abstractLogicalRule) Match(metadata *adapter.InboundContext) bool { + return !r.matchStates(metadata).isEmpty() +} + +func (r *abstractLogicalRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { + var stateSet ruleMatchStateSet if r.mode == C.LogicalTypeAnd { - return common.All(r.rules, func(it adapter.HeadlessRule) bool { - metadata.ResetRuleCache() - return it.Match(metadata) - }) != r.invert + stateSet = emptyRuleMatchState() + for _, rule := range r.rules { + nestedMetadata := *metadata + nestedMetadata.ResetRuleCache() + nestedStateSet := matchHeadlessRuleStates(rule, &nestedMetadata) + if nestedStateSet.isEmpty() { + if r.invert { + return emptyRuleMatchState() + } + return 0 + } + stateSet = stateSet.combine(nestedStateSet) + } } else { - return common.Any(r.rules, func(it adapter.HeadlessRule) bool { - metadata.ResetRuleCache() - return it.Match(metadata) - }) != r.invert + for _, rule := range r.rules { + nestedMetadata := *metadata + nestedMetadata.ResetRuleCache() + stateSet = stateSet.merge(matchHeadlessRuleStates(rule, &nestedMetadata)) + } + if stateSet.isEmpty() { + if r.invert { + return emptyRuleMatchState() + } + return 0 + } } + if r.invert { + return 0 + } + return stateSet } func (r *abstractLogicalRule) Action() adapter.RuleAction { @@ -222,3 +277,13 @@ func (r *abstractLogicalRule) String() string { return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" } } + +func matchAnyItem(items []RuleItem, metadata *adapter.InboundContext) bool { + return common.Any(items, func(it RuleItem) bool { + return it.Match(metadata) + }) +} + +func (s ruleMatchState) has(target ruleMatchState) bool { + return s&target != 0 +} diff --git a/route/rule/rule_abstract_test.go b/route/rule/rule_abstract_test.go index 2d2e8ba8..ace3dec6 100644 --- a/route/rule/rule_abstract_test.go +++ b/route/rule/rule_abstract_test.go @@ -78,9 +78,9 @@ func newRuleSetOnlyRule(ruleSetMatched bool, invert bool) *DefaultRule { } return &DefaultRule{ abstractDefaultRule: abstractDefaultRule{ - items: []RuleItem{ruleSetItem}, - allItems: []RuleItem{ruleSetItem}, - invert: invert, + ruleSetItem: ruleSetItem, + allItems: []RuleItem{ruleSetItem}, + invert: invert, }, } } diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index 202fb3b3..b921c8b2 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -47,6 +47,10 @@ type DefaultRule struct { abstractDefaultRule } +func (r *DefaultRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { + return r.abstractDefaultRule.matchStates(metadata) +} + type RuleItem interface { Match(metadata *adapter.InboundContext) bool String() string @@ -275,7 +279,7 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio matchSource = true } item := NewRuleSetItem(router, options.RuleSet, matchSource, false) - rule.items = append(rule.items, item) + rule.ruleSetItem = item rule.allItems = append(rule.allItems, item) } return rule, nil @@ -287,6 +291,10 @@ type LogicalRule struct { abstractLogicalRule } +func (r *LogicalRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { + return r.abstractLogicalRule.matchStates(metadata) +} + func NewLogicalRule(ctx context.Context, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { action, err := NewRuleAction(ctx, logger, options.RuleAction) if err != nil { diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index 9235dd6f..04f0f236 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -47,6 +47,10 @@ type DefaultDNSRule struct { abstractDefaultRule } +func (r *DefaultDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { + return r.abstractDefaultRule.matchStates(metadata) +} + func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { rule := &DefaultDNSRule{ abstractDefaultRule: abstractDefaultRule{ @@ -271,7 +275,7 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op matchSource = true } item := NewRuleSetItem(router, options.RuleSet, matchSource, options.RuleSetIPCIDRAcceptEmpty) - rule.items = append(rule.items, item) + rule.ruleSetItem = item rule.allItems = append(rule.allItems, item) } return rule, nil @@ -285,12 +289,9 @@ func (r *DefaultDNSRule) WithAddressLimit() bool { if len(r.destinationIPCIDRItems) > 0 { return true } - for _, rawRule := range r.items { - ruleSet, isRuleSet := rawRule.(*RuleSetItem) - if !isRuleSet { - continue - } - if ruleSet.ContainsDestinationIPCIDRRule() { + if r.ruleSetItem != nil { + ruleSet, isRuleSet := r.ruleSetItem.(*RuleSetItem) + if isRuleSet && ruleSet.ContainsDestinationIPCIDRRule() { return true } } @@ -302,11 +303,11 @@ func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool { defer func() { metadata.IgnoreDestinationIPCIDRMatch = false }() - return r.abstractDefaultRule.Match(metadata) + return !r.matchStates(metadata).isEmpty() } func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool { - return r.abstractDefaultRule.Match(metadata) + return !r.matchStates(metadata).isEmpty() } var _ adapter.DNSRule = (*LogicalDNSRule)(nil) @@ -315,6 +316,10 @@ type LogicalDNSRule struct { abstractLogicalRule } +func (r *LogicalDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { + return r.abstractLogicalRule.matchStates(metadata) +} + func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { r := &LogicalDNSRule{ abstractLogicalRule: abstractLogicalRule{ @@ -362,29 +367,13 @@ func (r *LogicalDNSRule) WithAddressLimit() bool { } func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool { - if r.mode == C.LogicalTypeAnd { - return common.All(r.rules, func(it adapter.HeadlessRule) bool { - metadata.ResetRuleCache() - return it.(adapter.DNSRule).Match(metadata) - }) != r.invert - } else { - return common.Any(r.rules, func(it adapter.HeadlessRule) bool { - metadata.ResetRuleCache() - return it.(adapter.DNSRule).Match(metadata) - }) != r.invert - } + metadata.IgnoreDestinationIPCIDRMatch = true + defer func() { + metadata.IgnoreDestinationIPCIDRMatch = false + }() + return !r.matchStates(metadata).isEmpty() } func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool { - if r.mode == C.LogicalTypeAnd { - return common.All(r.rules, func(it adapter.HeadlessRule) bool { - metadata.ResetRuleCache() - return it.(adapter.DNSRule).MatchAddressLimit(metadata) - }) != r.invert - } else { - return common.Any(r.rules, func(it adapter.HeadlessRule) bool { - metadata.ResetRuleCache() - return it.(adapter.DNSRule).MatchAddressLimit(metadata) - }) != r.invert - } + return !r.matchStates(metadata).isEmpty() } diff --git a/route/rule/rule_headless.go b/route/rule/rule_headless.go index 689e6e3e..f180bacc 100644 --- a/route/rule/rule_headless.go +++ b/route/rule/rule_headless.go @@ -34,6 +34,10 @@ type DefaultHeadlessRule struct { abstractDefaultRule } +func (r *DefaultHeadlessRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { + return r.abstractDefaultRule.matchStates(metadata) +} + func NewDefaultHeadlessRule(ctx context.Context, options option.DefaultHeadlessRule) (*DefaultHeadlessRule, error) { networkManager := service.FromContext[adapter.NetworkManager](ctx) rule := &DefaultHeadlessRule{ @@ -199,6 +203,10 @@ type LogicalHeadlessRule struct { abstractLogicalRule } +func (r *LogicalHeadlessRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { + return r.abstractLogicalRule.matchStates(metadata) +} + func NewLogicalHeadlessRule(ctx context.Context, options option.LogicalHeadlessRule) (*LogicalHeadlessRule, error) { r := &LogicalHeadlessRule{ abstractLogicalRule{ diff --git a/route/rule/rule_item_rule_set.go b/route/rule/rule_item_rule_set.go index 858bb877..0916279d 100644 --- a/route/rule/rule_item_rule_set.go +++ b/route/rule/rule_item_rule_set.go @@ -41,16 +41,19 @@ func (r *RuleSetItem) Start() error { } func (r *RuleSetItem) Match(metadata *adapter.InboundContext) bool { + return !r.matchStates(metadata).isEmpty() +} + +func (r *RuleSetItem) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { + var stateSet ruleMatchStateSet for _, ruleSet := range r.setList { nestedMetadata := *metadata nestedMetadata.ResetRuleMatchCache() nestedMetadata.IPCIDRMatchSource = r.ipCidrMatchSource nestedMetadata.IPCIDRAcceptEmpty = r.ipCidrAcceptEmpty - if ruleSet.Match(&nestedMetadata) { - return true - } + stateSet = stateSet.merge(matchHeadlessRuleStates(ruleSet, &nestedMetadata)) } - return false + return stateSet } func (r *RuleSetItem) ContainsDestinationIPCIDRRule() bool { diff --git a/route/rule/rule_set_local.go b/route/rule/rule_set_local.go index 8409831b..ec0f91b2 100644 --- a/route/rule/rule_set_local.go +++ b/route/rule/rule_set_local.go @@ -202,12 +202,15 @@ func (s *LocalRuleSet) Close() error { } func (s *LocalRuleSet) Match(metadata *adapter.InboundContext) bool { + return !s.matchStates(metadata).isEmpty() +} + +func (s *LocalRuleSet) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { + var stateSet ruleMatchStateSet for _, rule := range s.rules { nestedMetadata := *metadata nestedMetadata.ResetRuleMatchCache() - if rule.Match(&nestedMetadata) { - return true - } + stateSet = stateSet.merge(matchHeadlessRuleStates(rule, &nestedMetadata)) } - return false + return stateSet } diff --git a/route/rule/rule_set_remote.go b/route/rule/rule_set_remote.go index 81a8d3fc..c85dc859 100644 --- a/route/rule/rule_set_remote.go +++ b/route/rule/rule_set_remote.go @@ -322,12 +322,15 @@ func (s *RemoteRuleSet) Close() error { } func (s *RemoteRuleSet) Match(metadata *adapter.InboundContext) bool { + return !s.matchStates(metadata).isEmpty() +} + +func (s *RemoteRuleSet) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { + var stateSet ruleMatchStateSet for _, rule := range s.rules { nestedMetadata := *metadata nestedMetadata.ResetRuleMatchCache() - if rule.Match(&nestedMetadata) { - return true - } + stateSet = stateSet.merge(matchHeadlessRuleStates(rule, &nestedMetadata)) } - return false + return stateSet } diff --git a/route/rule/rule_set_semantics_test.go b/route/rule/rule_set_semantics_test.go new file mode 100644 index 00000000..27461ce6 --- /dev/null +++ b/route/rule/rule_set_semantics_test.go @@ -0,0 +1,620 @@ +package rule + +import ( + "context" + "net/netip" + "strings" + "testing" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/convertor/adguard" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + slogger "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/stretchr/testify/require" +) + +func TestRouteRuleSetMergeDestinationAddressGroup(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + metadata adapter.InboundContext + inner adapter.HeadlessRule + }{ + { + name: "domain", + metadata: testMetadata("www.example.com"), + inner: headlessDefaultRule(t, func(rule *abstractDefaultRule) { addDestinationAddressItem(t, rule, []string{"www.example.com"}, nil) }), + }, + { + name: "domain_suffix", + metadata: testMetadata("www.example.com"), + inner: headlessDefaultRule(t, func(rule *abstractDefaultRule) { addDestinationAddressItem(t, rule, nil, []string{"example.com"}) }), + }, + { + name: "domain_keyword", + metadata: testMetadata("www.example.com"), + inner: headlessDefaultRule(t, func(rule *abstractDefaultRule) { addDestinationKeywordItem(rule, []string{"example"}) }), + }, + { + name: "domain_regex", + metadata: testMetadata("www.example.com"), + inner: headlessDefaultRule(t, func(rule *abstractDefaultRule) { addDestinationRegexItem(t, rule, []string{`^www\.example\.com$`}) }), + }, + { + name: "ip_cidr", + metadata: func() adapter.InboundContext { + metadata := testMetadata("lookup.example") + metadata.DestinationAddresses = []netip.Addr{netip.MustParseAddr("8.8.8.8")} + return metadata + }(), + inner: headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{"8.8.8.0/24"}) + }), + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + ruleSet := newLocalRuleSetForTest("merge-destination", testCase.inner) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }) + require.True(t, rule.Match(&testCase.metadata)) + }) + } +} + +func TestRouteRuleSetMergeSourceAndPortGroups(t *testing.T) { + t.Parallel() + t.Run("source address", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("merge-source-address", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addSourceAddressItem(t, rule, []string{"10.0.0.0/8"}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addSourceAddressItem(t, rule, []string{"198.51.100.0/24"}) + }) + require.True(t, rule.Match(&metadata)) + }) + t.Run("source address via ruleset ipcidr match source", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("merge-source-address-ipcidr", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{"10.0.0.0/8"}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{ + setList: []adapter.RuleSet{ruleSet}, + ipCidrMatchSource: true, + }) + addSourceAddressItem(t, rule, []string{"198.51.100.0/24"}) + }) + require.True(t, rule.Match(&metadata)) + }) + t.Run("destination port", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("merge-destination-port", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationPortItem(rule, []uint16{443}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addDestinationPortItem(rule, []uint16{8443}) + }) + require.True(t, rule.Match(&metadata)) + }) + t.Run("destination port range", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("merge-destination-port-range", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationPortRangeItem(t, rule, []string{"400:500"}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addDestinationPortItem(rule, []uint16{8443}) + }) + require.True(t, rule.Match(&metadata)) + }) + t.Run("source port", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("merge-source-port", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addSourcePortItem(rule, []uint16{1000}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addSourcePortItem(rule, []uint16{2000}) + }) + require.True(t, rule.Match(&metadata)) + }) + t.Run("source port range", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("merge-source-port-range", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addSourcePortRangeItem(t, rule, []string{"900:1100"}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addSourcePortItem(rule, []uint16{2000}) + }) + require.True(t, rule.Match(&metadata)) + }) +} + +func TestRouteRuleSetOtherFieldsStayAnd(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("other-fields-and", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addOtherItem(rule, NewNetworkItem([]string{N.NetworkUDP})) + }) + require.False(t, rule.Match(&metadata)) +} + +func TestRouteRuleSetOrSemantics(t *testing.T) { + t.Parallel() + t.Run("later ruleset can satisfy outer group", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + emptyStateSet := newLocalRuleSetForTest("network-only", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addOtherItem(rule, NewNetworkItem([]string{N.NetworkTCP})) + })) + destinationStateSet := newLocalRuleSetForTest("domain-only", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{emptyStateSet, destinationStateSet}}) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }) + require.True(t, rule.Match(&metadata)) + }) + t.Run("later rule in same set can satisfy outer group", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest( + "rule-set-or", + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addOtherItem(rule, NewNetworkItem([]string{N.NetworkTCP})) + }), + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + }), + ) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }) + require.True(t, rule.Match(&metadata)) + }) + t.Run("cross ruleset union is not allowed", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + sourceStateSet := newLocalRuleSetForTest("source-only", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addSourcePortItem(rule, []uint16{1000}) + })) + destinationStateSet := newLocalRuleSetForTest("destination-only", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{sourceStateSet, destinationStateSet}}) + addSourcePortItem(rule, []uint16{2000}) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }) + require.False(t, rule.Match(&metadata)) + }) +} + +func TestRouteRuleSetLogicalSemantics(t *testing.T) { + t.Parallel() + t.Run("logical or keeps all successful branch states", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("logical-or", headlessLogicalRule( + C.LogicalTypeOr, + false, + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addOtherItem(rule, NewNetworkItem([]string{N.NetworkTCP})) + }), + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + }), + )) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }) + require.True(t, rule.Match(&metadata)) + }) + t.Run("logical and unions child states", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("logical-and", headlessLogicalRule( + C.LogicalTypeAnd, + false, + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + }), + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addSourcePortItem(rule, []uint16{1000}) + }), + )) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + addSourcePortItem(rule, []uint16{2000}) + }) + require.True(t, rule.Match(&metadata)) + }) + t.Run("invert success does not contribute positive state", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("invert", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + rule.invert = true + addDestinationAddressItem(t, rule, nil, []string{"cn"}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }) + require.False(t, rule.Match(&metadata)) + }) +} + +func TestRouteRuleSetNoLeakageRegressions(t *testing.T) { + t.Parallel() + t.Run("same ruleset failed branch does not leak", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest( + "same-set", + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + addSourcePortItem(rule, []uint16{1}) + }), + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + addSourcePortItem(rule, []uint16{1000}) + }), + ) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + }) + require.False(t, rule.Match(&metadata)) + }) + t.Run("adguard exclusion remains isolated across rulesets", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("im.qq.com") + excludeSet := newLocalRuleSetForTest("adguard", mustAdGuardRule(t, "@@||im.qq.com^\n||whatever1.com^\n")) + otherSet := newLocalRuleSetForTest("other", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"whatever2.com"}) + })) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{excludeSet, otherSet}}) + }) + require.False(t, rule.Match(&metadata)) + }) +} + +func TestDefaultRuleDoesNotReuseGroupedMatchCacheAcrossEvaluations(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + }) + require.True(t, rule.Match(&metadata)) + + metadata.Destination.Fqdn = "www.example.org" + require.False(t, rule.Match(&metadata)) +} + +func TestRouteRuleSetRemoteUsesSameSemantics(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newRemoteRuleSetForTest( + "remote", + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addOtherItem(rule, NewNetworkItem([]string{N.NetworkTCP})) + }), + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + }), + ) + rule := routeRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }) + require.True(t, rule.Match(&metadata)) +} + +func TestDNSRuleSetSemantics(t *testing.T) { + t.Parallel() + t.Run("match address limit merges destination group", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("dns-merge", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + })) + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }) + require.True(t, rule.MatchAddressLimit(&metadata)) + }) + t.Run("dns keeps ruleset or semantics", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + emptyStateSet := newLocalRuleSetForTest("dns-empty", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addOtherItem(rule, NewNetworkItem([]string{N.NetworkTCP})) + })) + destinationStateSet := newLocalRuleSetForTest("dns-destination", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationAddressItem(t, rule, nil, []string{"example.com"}) + })) + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{emptyStateSet, destinationStateSet}}) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }) + require.True(t, rule.MatchAddressLimit(&metadata)) + }) + t.Run("ruleset ip cidr flags stay scoped", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("www.example.com") + ruleSet := newLocalRuleSetForTest("dns-ipcidr", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + })) + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{ + setList: []adapter.RuleSet{ruleSet}, + ipCidrAcceptEmpty: true, + }) + }) + require.True(t, rule.MatchAddressLimit(&metadata)) + require.False(t, metadata.IPCIDRMatchSource) + require.False(t, metadata.IPCIDRAcceptEmpty) + }) +} + +func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + build func(*testing.T, *abstractDefaultRule) + matchedAddrs []netip.Addr + unmatchedAddrs []netip.Addr + }{ + { + name: "ip_cidr", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }, + matchedAddrs: []netip.Addr{netip.MustParseAddr("203.0.113.1")}, + unmatchedAddrs: []netip.Addr{netip.MustParseAddr("8.8.8.8")}, + }, + { + name: "ip_is_private", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPIsPrivateItem(rule) + }, + matchedAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, + unmatchedAddrs: []netip.Addr{netip.MustParseAddr("8.8.8.8")}, + }, + { + name: "ip_accept_any", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPAcceptAnyItem(rule) + }, + matchedAddrs: []netip.Addr{netip.MustParseAddr("203.0.113.1")}, + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + rule.invert = true + testCase.build(t, rule) + }) + + preLookupMetadata := testMetadata("lookup.example") + require.True(t, rule.Match(&preLookupMetadata)) + + matchedMetadata := testMetadata("lookup.example") + matchedMetadata.DestinationAddresses = testCase.matchedAddrs + require.False(t, rule.MatchAddressLimit(&matchedMetadata)) + + unmatchedMetadata := testMetadata("lookup.example") + unmatchedMetadata.DestinationAddresses = testCase.unmatchedAddrs + require.True(t, rule.MatchAddressLimit(&unmatchedMetadata)) + }) + } + t.Run("mixed resolved and deferred fields keep old pre lookup false", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("lookup.example") + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + rule.invert = true + addOtherItem(rule, NewNetworkItem([]string{N.NetworkTCP})) + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }) + require.False(t, rule.Match(&metadata)) + }) + t.Run("ruleset only deferred fields keep old pre lookup false", func(t *testing.T) { + t.Parallel() + metadata := testMetadata("lookup.example") + ruleSet := newLocalRuleSetForTest("dns-ruleset-ipcidr", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + })) + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + rule.invert = true + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + }) + require.False(t, rule.Match(&metadata)) + }) +} + +func routeRuleForTest(build func(*abstractDefaultRule)) *DefaultRule { + rule := &DefaultRule{} + build(&rule.abstractDefaultRule) + return rule +} + +func dnsRuleForTest(build func(*abstractDefaultRule)) *DefaultDNSRule { + rule := &DefaultDNSRule{} + build(&rule.abstractDefaultRule) + return rule +} + +func headlessDefaultRule(t *testing.T, build func(*abstractDefaultRule)) *DefaultHeadlessRule { + t.Helper() + rule := &DefaultHeadlessRule{} + build(&rule.abstractDefaultRule) + return rule +} + +func headlessLogicalRule(mode string, invert bool, rules ...adapter.HeadlessRule) *LogicalHeadlessRule { + return &LogicalHeadlessRule{ + abstractLogicalRule: abstractLogicalRule{ + rules: rules, + mode: mode, + invert: invert, + }, + } +} + +func newLocalRuleSetForTest(tag string, rules ...adapter.HeadlessRule) *LocalRuleSet { + return &LocalRuleSet{ + tag: tag, + rules: rules, + } +} + +func newRemoteRuleSetForTest(tag string, rules ...adapter.HeadlessRule) *RemoteRuleSet { + return &RemoteRuleSet{ + options: option.RuleSet{Tag: tag}, + rules: rules, + } +} + +func mustAdGuardRule(t *testing.T, content string) adapter.HeadlessRule { + t.Helper() + rules, err := adguard.ToOptions(strings.NewReader(content), slogger.NOP()) + require.NoError(t, err) + require.Len(t, rules, 1) + rule, err := NewHeadlessRule(context.Background(), rules[0]) + require.NoError(t, err) + return rule +} + +func testMetadata(domain string) adapter.InboundContext { + return adapter.InboundContext{ + Network: N.NetworkTCP, + Source: M.Socksaddr{ + Addr: netip.MustParseAddr("10.0.0.1"), + Port: 1000, + }, + Destination: M.Socksaddr{ + Fqdn: domain, + Port: 443, + }, + } +} + +func addRuleSetItem(rule *abstractDefaultRule, item *RuleSetItem) { + rule.ruleSetItem = item + rule.allItems = append(rule.allItems, item) +} + +func addOtherItem(rule *abstractDefaultRule, item RuleItem) { + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) +} + +func addSourceAddressItem(t *testing.T, rule *abstractDefaultRule, cidrs []string) { + t.Helper() + item, err := NewIPCIDRItem(true, cidrs) + require.NoError(t, err) + rule.sourceAddressItems = append(rule.sourceAddressItems, item) + rule.allItems = append(rule.allItems, item) +} + +func addDestinationAddressItem(t *testing.T, rule *abstractDefaultRule, domains []string, suffixes []string) { + t.Helper() + item, err := NewDomainItem(domains, suffixes) + require.NoError(t, err) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) +} + +func addDestinationKeywordItem(rule *abstractDefaultRule, keywords []string) { + item := NewDomainKeywordItem(keywords) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) +} + +func addDestinationRegexItem(t *testing.T, rule *abstractDefaultRule, regexes []string) { + t.Helper() + item, err := NewDomainRegexItem(regexes) + require.NoError(t, err) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) +} + +func addDestinationIPCIDRItem(t *testing.T, rule *abstractDefaultRule, cidrs []string) { + t.Helper() + item, err := NewIPCIDRItem(false, cidrs) + require.NoError(t, err) + rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item) + rule.allItems = append(rule.allItems, item) +} + +func addDestinationIPIsPrivateItem(rule *abstractDefaultRule) { + item := NewIPIsPrivateItem(false) + rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item) + rule.allItems = append(rule.allItems, item) +} + +func addDestinationIPAcceptAnyItem(rule *abstractDefaultRule) { + item := NewIPAcceptAnyItem() + rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item) + rule.allItems = append(rule.allItems, item) +} + +func addSourcePortItem(rule *abstractDefaultRule, ports []uint16) { + item := NewPortItem(true, ports) + rule.sourcePortItems = append(rule.sourcePortItems, item) + rule.allItems = append(rule.allItems, item) +} + +func addSourcePortRangeItem(t *testing.T, rule *abstractDefaultRule, ranges []string) { + t.Helper() + item, err := NewPortRangeItem(true, ranges) + require.NoError(t, err) + rule.sourcePortItems = append(rule.sourcePortItems, item) + rule.allItems = append(rule.allItems, item) +} + +func addDestinationPortItem(rule *abstractDefaultRule, ports []uint16) { + item := NewPortItem(false, ports) + rule.destinationPortItems = append(rule.destinationPortItems, item) + rule.allItems = append(rule.allItems, item) +} + +func addDestinationPortRangeItem(t *testing.T, rule *abstractDefaultRule, ranges []string) { + t.Helper() + item, err := NewPortRangeItem(false, ranges) + require.NoError(t, err) + rule.destinationPortItems = append(rule.destinationPortItems, item) + rule.allItems = append(rule.allItems, item) +}