diff --git a/dns/transport/dhcp/dhcp_shared.go b/dns/transport/dhcp/dhcp_shared.go index 086d1d68..2a6b5773 100644 --- a/dns/transport/dhcp/dhcp_shared.go +++ b/dns/transport/dhcp/dhcp_shared.go @@ -2,12 +2,13 @@ package dhcp import ( "context" + "errors" "math/rand" "strings" - "time" + "syscall" - C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -83,7 +84,7 @@ func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn server := servers[j] question := message.Question[0] question.Name = fqdn - response, err := t.exchangeOne(ctx, server, question, C.DNSTimeout, false, true) + response, err := t.exchangeOne(ctx, server, question) if err != nil { lastErr = err continue @@ -94,62 +95,77 @@ func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn return nil, E.Cause(lastErr, fqdn) } -func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) { +func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question) (*mDNS.Msg, error) { if server.Port == 0 { server.Port = 53 } - var networks []string - if useTCP { - networks = []string{N.NetworkTCP} - } else { - networks = []string{N.NetworkUDP, N.NetworkTCP} - } request := &mDNS.Msg{ MsgHdr: mDNS.MsgHdr{ Id: uint16(rand.Uint32()), RecursionDesired: true, - AuthenticatedData: ad, + AuthenticatedData: true, }, Question: []mDNS.Question{question}, Compress: true, } request.SetEdns0(buf.UDPBufferSize, false) - buffer := buf.Get(buf.UDPBufferSize) - defer buf.Put(buffer) - for _, network := range networks { - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) - defer cancel() - conn, err := t.dialer.DialContext(ctx, network, server) - if err != nil { - return nil, err - } - defer conn.Close() - if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { - conn.SetDeadline(deadline) - } - rawMessage, err := request.PackBuffer(buffer) - if err != nil { - return nil, E.Cause(err, "pack request") - } - _, err = conn.Write(rawMessage) - if err != nil { - return nil, E.Cause(err, "write request") - } - n, err := conn.Read(buffer) - if err != nil { - return nil, E.Cause(err, "read response") - } - var response mDNS.Msg - err = response.Unpack(buffer[:n]) - if err != nil { - return nil, E.Cause(err, "unpack response") - } - if response.Truncated && network == N.NetworkUDP { - continue - } - return &response, nil + return t.exchangeUDP(ctx, server, request) +} + +func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg) (*mDNS.Msg, error) { + conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server) + if err != nil { + return nil, err } - panic("unexpected") + defer conn.Close() + if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { + conn.SetDeadline(deadline) + } + buffer := buf.Get(1 + request.Len()) + defer buf.Put(buffer) + rawMessage, err := request.PackBuffer(buffer) + if err != nil { + return nil, E.Cause(err, "pack request") + } + _, err = conn.Write(rawMessage) + if err != nil { + if errors.Is(err, syscall.EMSGSIZE) { + return t.exchangeTCP(ctx, server, request) + } + return nil, E.Cause(err, "write request") + } + n, err := conn.Read(buffer) + if err != nil { + if errors.Is(err, syscall.EMSGSIZE) { + return t.exchangeTCP(ctx, server, request) + } + return nil, E.Cause(err, "read response") + } + var response mDNS.Msg + err = response.Unpack(buffer[:n]) + if err != nil { + return nil, E.Cause(err, "unpack response") + } + if response.Truncated { + return t.exchangeTCP(ctx, server, request) + } + return &response, nil +} + +func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg) (*mDNS.Msg, error) { + conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server) + if err != nil { + return nil, err + } + defer conn.Close() + if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { + conn.SetDeadline(deadline) + } + err = transport.WriteMessage(conn, 0, request) + if err != nil { + return nil, err + } + return transport.ReadMessage(conn) } func (t *Transport) nameList(name string) []string { diff --git a/dns/transport/local/local.go b/dns/transport/local/local.go index 8187681a..ec3baad1 100644 --- a/dns/transport/local/local.go +++ b/dns/transport/local/local.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport" "github.com/sagernet/sing-box/dns/transport/hosts" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" @@ -151,12 +152,6 @@ func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, questio if server.Port == 0 { server.Port = 53 } - var networks []string - if useTCP { - networks = []string{N.NetworkTCP} - } else { - networks = []string{N.NetworkUDP, N.NetworkTCP} - } request := &mDNS.Msg{ MsgHdr: mDNS.MsgHdr{ Id: uint16(rand.Uint32()), @@ -167,43 +162,73 @@ func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, questio Compress: true, } request.SetEdns0(buf.UDPBufferSize, false) - buffer := buf.Get(buf.UDPBufferSize) - defer buf.Put(buffer) - for _, network := range networks { - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) - defer cancel() - conn, err := t.dialer.DialContext(ctx, network, server) - if err != nil { - return nil, err - } - defer conn.Close() - if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { - conn.SetDeadline(deadline) - } - rawMessage, err := request.PackBuffer(buffer) - if err != nil { - return nil, E.Cause(err, "pack request") - } - _, err = conn.Write(rawMessage) - if err != nil { - if errors.Is(err, syscall.EMSGSIZE) && network == N.NetworkUDP { - continue - } - return nil, E.Cause(err, "write request") - } - n, err := conn.Read(buffer) - if err != nil { - return nil, E.Cause(err, "read response") - } - var response mDNS.Msg - err = response.Unpack(buffer[:n]) - if err != nil { - return nil, E.Cause(err, "unpack response") - } - if response.Truncated && network == N.NetworkUDP { - continue - } - return &response, nil + if !useTCP { + return t.exchangeUDP(ctx, server, request, timeout) + } else { + return t.exchangeTCP(ctx, server, request, timeout) } - panic("unexpected") +} + +func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) { + conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server) + if err != nil { + return nil, err + } + defer conn.Close() + if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { + newDeadline := time.Now().Add(timeout) + if deadline.After(newDeadline) { + deadline = newDeadline + } + conn.SetDeadline(deadline) + } + buffer := buf.Get(1 + request.Len()) + defer buf.Put(buffer) + rawMessage, err := request.PackBuffer(buffer) + if err != nil { + return nil, E.Cause(err, "pack request") + } + _, err = conn.Write(rawMessage) + if err != nil { + if errors.Is(err, syscall.EMSGSIZE) { + return t.exchangeTCP(ctx, server, request, timeout) + } + return nil, E.Cause(err, "write request") + } + n, err := conn.Read(buffer) + if err != nil { + if errors.Is(err, syscall.EMSGSIZE) { + return t.exchangeTCP(ctx, server, request, timeout) + } + return nil, E.Cause(err, "read response") + } + var response mDNS.Msg + err = response.Unpack(buffer[:n]) + if err != nil { + return nil, E.Cause(err, "unpack response") + } + if response.Truncated { + return t.exchangeTCP(ctx, server, request, timeout) + } + return &response, nil +} + +func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) { + conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server) + if err != nil { + return nil, err + } + defer conn.Close() + if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { + newDeadline := time.Now().Add(timeout) + if deadline.After(newDeadline) { + deadline = newDeadline + } + conn.SetDeadline(deadline) + } + err = transport.WriteMessage(conn, 0, request) + if err != nil { + return nil, err + } + return transport.ReadMessage(conn) }