Fix TCP exchange for local/dhcp DNS servers

This commit is contained in:
世界
2025-09-12 18:49:41 +08:00
parent 146383499e
commit 5d1d1a1456
2 changed files with 131 additions and 90 deletions

View File

@@ -2,12 +2,13 @@ package dhcp
import (
"context"
"errors"
"math/rand"
"strings"
"time"
"syscall"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
@@ -83,7 +84,7 @@ func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn
server := servers[j]
question := message.Question[0]
question.Name = fqdn
response, err := t.exchangeOne(ctx, server, question, C.DNSTimeout, false, true)
response, err := t.exchangeOne(ctx, server, question)
if err != nil {
lastErr = err
continue
@@ -94,62 +95,77 @@ func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn
return nil, E.Cause(lastErr, fqdn)
}
func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question) (*mDNS.Msg, error) {
if server.Port == 0 {
server.Port = 53
}
var networks []string
if useTCP {
networks = []string{N.NetworkTCP}
} else {
networks = []string{N.NetworkUDP, N.NetworkTCP}
}
request := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Id: uint16(rand.Uint32()),
RecursionDesired: true,
AuthenticatedData: ad,
AuthenticatedData: true,
},
Question: []mDNS.Question{question},
Compress: true,
}
request.SetEdns0(buf.UDPBufferSize, false)
buffer := buf.Get(buf.UDPBufferSize)
defer buf.Put(buffer)
for _, network := range networks {
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
conn, err := t.dialer.DialContext(ctx, network, server)
if err != nil {
return nil, err
}
defer conn.Close()
if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
conn.SetDeadline(deadline)
}
rawMessage, err := request.PackBuffer(buffer)
if err != nil {
return nil, E.Cause(err, "pack request")
}
_, err = conn.Write(rawMessage)
if err != nil {
return nil, E.Cause(err, "write request")
}
n, err := conn.Read(buffer)
if err != nil {
return nil, E.Cause(err, "read response")
}
var response mDNS.Msg
err = response.Unpack(buffer[:n])
if err != nil {
return nil, E.Cause(err, "unpack response")
}
if response.Truncated && network == N.NetworkUDP {
continue
}
return &response, nil
return t.exchangeUDP(ctx, server, request)
}
func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg) (*mDNS.Msg, error) {
conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server)
if err != nil {
return nil, err
}
panic("unexpected")
defer conn.Close()
if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
conn.SetDeadline(deadline)
}
buffer := buf.Get(1 + request.Len())
defer buf.Put(buffer)
rawMessage, err := request.PackBuffer(buffer)
if err != nil {
return nil, E.Cause(err, "pack request")
}
_, err = conn.Write(rawMessage)
if err != nil {
if errors.Is(err, syscall.EMSGSIZE) {
return t.exchangeTCP(ctx, server, request)
}
return nil, E.Cause(err, "write request")
}
n, err := conn.Read(buffer)
if err != nil {
if errors.Is(err, syscall.EMSGSIZE) {
return t.exchangeTCP(ctx, server, request)
}
return nil, E.Cause(err, "read response")
}
var response mDNS.Msg
err = response.Unpack(buffer[:n])
if err != nil {
return nil, E.Cause(err, "unpack response")
}
if response.Truncated {
return t.exchangeTCP(ctx, server, request)
}
return &response, nil
}
func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg) (*mDNS.Msg, error) {
conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server)
if err != nil {
return nil, err
}
defer conn.Close()
if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
conn.SetDeadline(deadline)
}
err = transport.WriteMessage(conn, 0, request)
if err != nil {
return nil, err
}
return transport.ReadMessage(conn)
}
func (t *Transport) nameList(name string) []string {

View File

@@ -10,6 +10,7 @@ import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing-box/dns/transport/hosts"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
@@ -151,12 +152,6 @@ func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, questio
if server.Port == 0 {
server.Port = 53
}
var networks []string
if useTCP {
networks = []string{N.NetworkTCP}
} else {
networks = []string{N.NetworkUDP, N.NetworkTCP}
}
request := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Id: uint16(rand.Uint32()),
@@ -167,43 +162,73 @@ func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, questio
Compress: true,
}
request.SetEdns0(buf.UDPBufferSize, false)
buffer := buf.Get(buf.UDPBufferSize)
defer buf.Put(buffer)
for _, network := range networks {
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
conn, err := t.dialer.DialContext(ctx, network, server)
if err != nil {
return nil, err
}
defer conn.Close()
if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
conn.SetDeadline(deadline)
}
rawMessage, err := request.PackBuffer(buffer)
if err != nil {
return nil, E.Cause(err, "pack request")
}
_, err = conn.Write(rawMessage)
if err != nil {
if errors.Is(err, syscall.EMSGSIZE) && network == N.NetworkUDP {
continue
}
return nil, E.Cause(err, "write request")
}
n, err := conn.Read(buffer)
if err != nil {
return nil, E.Cause(err, "read response")
}
var response mDNS.Msg
err = response.Unpack(buffer[:n])
if err != nil {
return nil, E.Cause(err, "unpack response")
}
if response.Truncated && network == N.NetworkUDP {
continue
}
return &response, nil
if !useTCP {
return t.exchangeUDP(ctx, server, request, timeout)
} else {
return t.exchangeTCP(ctx, server, request, timeout)
}
panic("unexpected")
}
func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) {
conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server)
if err != nil {
return nil, err
}
defer conn.Close()
if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
newDeadline := time.Now().Add(timeout)
if deadline.After(newDeadline) {
deadline = newDeadline
}
conn.SetDeadline(deadline)
}
buffer := buf.Get(1 + request.Len())
defer buf.Put(buffer)
rawMessage, err := request.PackBuffer(buffer)
if err != nil {
return nil, E.Cause(err, "pack request")
}
_, err = conn.Write(rawMessage)
if err != nil {
if errors.Is(err, syscall.EMSGSIZE) {
return t.exchangeTCP(ctx, server, request, timeout)
}
return nil, E.Cause(err, "write request")
}
n, err := conn.Read(buffer)
if err != nil {
if errors.Is(err, syscall.EMSGSIZE) {
return t.exchangeTCP(ctx, server, request, timeout)
}
return nil, E.Cause(err, "read response")
}
var response mDNS.Msg
err = response.Unpack(buffer[:n])
if err != nil {
return nil, E.Cause(err, "unpack response")
}
if response.Truncated {
return t.exchangeTCP(ctx, server, request, timeout)
}
return &response, nil
}
func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) {
conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server)
if err != nil {
return nil, err
}
defer conn.Close()
if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
newDeadline := time.Now().Add(timeout)
if deadline.After(newDeadline) {
deadline = newDeadline
}
conn.SetDeadline(deadline)
}
err = transport.WriteMessage(conn, 0, request)
if err != nil {
return nil, err
}
return transport.ReadMessage(conn)
}