106
common/ktls/ktls.go
Normal file
106
common/ktls/ktls.go
Normal file
@@ -0,0 +1,106 @@
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing-box/common/badtls"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
aTLS.Conn
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
conn net.Conn
|
||||
rawConn *badtls.RawConn
|
||||
syscallConn syscall.Conn
|
||||
rawSyscallConn syscall.RawConn
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
kernelTx bool
|
||||
kernelRx bool
|
||||
}
|
||||
|
||||
func NewConn(ctx context.Context, logger logger.ContextLogger, conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
|
||||
err := Load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
syscallConn, isSyscallConn := N.CastReader[interface {
|
||||
io.Reader
|
||||
syscall.Conn
|
||||
}](conn.NetConn())
|
||||
if !isSyscallConn {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
rawSyscallConn, err := syscallConn.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawConn, err := badtls.NewRawConn(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if *rawConn.Vers != tls.VersionTLS13 {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
for rawConn.RawInput.Len() > 0 {
|
||||
err = rawConn.ReadRecord()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rawConn.Hand.Len() > 0 {
|
||||
err = rawConn.HandlePostHandshakeMessage()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "handle post-handshake messages")
|
||||
}
|
||||
}
|
||||
}
|
||||
kConn := &Conn{
|
||||
Conn: conn,
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
conn: conn.NetConn(),
|
||||
rawConn: rawConn,
|
||||
syscallConn: syscallConn,
|
||||
rawSyscallConn: rawSyscallConn,
|
||||
}
|
||||
err = kConn.setupKernel(txOffload, rxOffload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return kConn, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *Conn) SyscallConnForRead() syscall.Conn {
|
||||
if !c.kernelRx {
|
||||
return nil
|
||||
}
|
||||
if !*c.rawConn.IsClient {
|
||||
c.logger.WarnContext(c.ctx, "ktls: RX splice is unavailable on the server size, since it will cause an unknown failure")
|
||||
return nil
|
||||
}
|
||||
c.logger.DebugContext(c.ctx, "ktls: RX splice requested")
|
||||
return c.syscallConn
|
||||
}
|
||||
|
||||
func (c *Conn) SyscallConnForWrite() syscall.Conn {
|
||||
if !c.kernelTx {
|
||||
return nil
|
||||
}
|
||||
c.logger.DebugContext(c.ctx, "ktls: TX splice requested")
|
||||
return c.syscallConn
|
||||
}
|
||||
80
common/ktls/ktls_alert.go
Normal file
80
common/ktls/ktls_alert.go
Normal file
@@ -0,0 +1,80 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
// alert level
|
||||
alertLevelWarning = 1
|
||||
alertLevelError = 2
|
||||
)
|
||||
|
||||
const (
|
||||
alertCloseNotify = 0
|
||||
alertUnexpectedMessage = 10
|
||||
alertBadRecordMAC = 20
|
||||
alertDecryptionFailed = 21
|
||||
alertRecordOverflow = 22
|
||||
alertDecompressionFailure = 30
|
||||
alertHandshakeFailure = 40
|
||||
alertBadCertificate = 42
|
||||
alertUnsupportedCertificate = 43
|
||||
alertCertificateRevoked = 44
|
||||
alertCertificateExpired = 45
|
||||
alertCertificateUnknown = 46
|
||||
alertIllegalParameter = 47
|
||||
alertUnknownCA = 48
|
||||
alertAccessDenied = 49
|
||||
alertDecodeError = 50
|
||||
alertDecryptError = 51
|
||||
alertExportRestriction = 60
|
||||
alertProtocolVersion = 70
|
||||
alertInsufficientSecurity = 71
|
||||
alertInternalError = 80
|
||||
alertInappropriateFallback = 86
|
||||
alertUserCanceled = 90
|
||||
alertNoRenegotiation = 100
|
||||
alertMissingExtension = 109
|
||||
alertUnsupportedExtension = 110
|
||||
alertCertificateUnobtainable = 111
|
||||
alertUnrecognizedName = 112
|
||||
alertBadCertificateStatusResponse = 113
|
||||
alertBadCertificateHashValue = 114
|
||||
alertUnknownPSKIdentity = 115
|
||||
alertCertificateRequired = 116
|
||||
alertNoApplicationProtocol = 120
|
||||
alertECHRequired = 121
|
||||
)
|
||||
|
||||
func (c *Conn) sendAlertLocked(err uint8) error {
|
||||
switch err {
|
||||
case alertNoRenegotiation, alertCloseNotify:
|
||||
c.rawConn.Tmp[0] = alertLevelWarning
|
||||
default:
|
||||
c.rawConn.Tmp[0] = alertLevelError
|
||||
}
|
||||
c.rawConn.Tmp[1] = byte(err)
|
||||
|
||||
_, writeErr := c.writeRecordLocked(recordTypeAlert, c.rawConn.Tmp[0:2])
|
||||
if err == alertCloseNotify {
|
||||
// closeNotify is a special case in that it isn't an error.
|
||||
return writeErr
|
||||
}
|
||||
|
||||
return c.rawConn.Out.SetErrorLocked(&net.OpError{Op: "local error", Err: tls.AlertError(err)})
|
||||
}
|
||||
|
||||
// sendAlert sends a TLS alert message.
|
||||
func (c *Conn) sendAlert(err uint8) error {
|
||||
c.rawConn.Out.Lock()
|
||||
defer c.rawConn.Out.Unlock()
|
||||
return c.sendAlertLocked(err)
|
||||
}
|
||||
326
common/ktls/ktls_cipher_suites_linux.go
Normal file
326
common/ktls/ktls_cipher_suites_linux.go
Normal file
@@ -0,0 +1,326 @@
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing-box/common/badtls"
|
||||
)
|
||||
|
||||
type kernelCryptoCipherType uint16
|
||||
|
||||
const (
|
||||
TLS_CIPHER_AES_GCM_128 kernelCryptoCipherType = 51
|
||||
TLS_CIPHER_AES_GCM_128_IV_SIZE kernelCryptoCipherType = 8
|
||||
TLS_CIPHER_AES_GCM_128_KEY_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_AES_GCM_128_SALT_SIZE kernelCryptoCipherType = 4
|
||||
TLS_CIPHER_AES_GCM_128_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
TLS_CIPHER_AES_GCM_256 kernelCryptoCipherType = 52
|
||||
TLS_CIPHER_AES_GCM_256_IV_SIZE kernelCryptoCipherType = 8
|
||||
TLS_CIPHER_AES_GCM_256_KEY_SIZE kernelCryptoCipherType = 32
|
||||
TLS_CIPHER_AES_GCM_256_SALT_SIZE kernelCryptoCipherType = 4
|
||||
TLS_CIPHER_AES_GCM_256_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
TLS_CIPHER_AES_CCM_128 kernelCryptoCipherType = 53
|
||||
TLS_CIPHER_AES_CCM_128_IV_SIZE kernelCryptoCipherType = 8
|
||||
TLS_CIPHER_AES_CCM_128_KEY_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_AES_CCM_128_SALT_SIZE kernelCryptoCipherType = 4
|
||||
TLS_CIPHER_AES_CCM_128_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
TLS_CIPHER_CHACHA20_POLY1305 kernelCryptoCipherType = 54
|
||||
TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE kernelCryptoCipherType = 12
|
||||
TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE kernelCryptoCipherType = 32
|
||||
TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE kernelCryptoCipherType = 0
|
||||
TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
// TLS_CIPHER_SM4_GCM kernelCryptoCipherType = 55
|
||||
// TLS_CIPHER_SM4_GCM_IV_SIZE kernelCryptoCipherType = 8
|
||||
// TLS_CIPHER_SM4_GCM_KEY_SIZE kernelCryptoCipherType = 16
|
||||
// TLS_CIPHER_SM4_GCM_SALT_SIZE kernelCryptoCipherType = 4
|
||||
// TLS_CIPHER_SM4_GCM_TAG_SIZE kernelCryptoCipherType = 16
|
||||
// TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
// TLS_CIPHER_SM4_CCM kernelCryptoCipherType = 56
|
||||
// TLS_CIPHER_SM4_CCM_IV_SIZE kernelCryptoCipherType = 8
|
||||
// TLS_CIPHER_SM4_CCM_KEY_SIZE kernelCryptoCipherType = 16
|
||||
// TLS_CIPHER_SM4_CCM_SALT_SIZE kernelCryptoCipherType = 4
|
||||
// TLS_CIPHER_SM4_CCM_TAG_SIZE kernelCryptoCipherType = 16
|
||||
// TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
TLS_CIPHER_ARIA_GCM_128 kernelCryptoCipherType = 57
|
||||
TLS_CIPHER_ARIA_GCM_128_IV_SIZE kernelCryptoCipherType = 8
|
||||
TLS_CIPHER_ARIA_GCM_128_KEY_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_ARIA_GCM_128_SALT_SIZE kernelCryptoCipherType = 4
|
||||
TLS_CIPHER_ARIA_GCM_128_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
TLS_CIPHER_ARIA_GCM_256 kernelCryptoCipherType = 58
|
||||
TLS_CIPHER_ARIA_GCM_256_IV_SIZE kernelCryptoCipherType = 8
|
||||
TLS_CIPHER_ARIA_GCM_256_KEY_SIZE kernelCryptoCipherType = 32
|
||||
TLS_CIPHER_ARIA_GCM_256_SALT_SIZE kernelCryptoCipherType = 4
|
||||
TLS_CIPHER_ARIA_GCM_256_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
)
|
||||
|
||||
type kernelCrypto interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
type kernelCryptoInfo struct {
|
||||
version uint16
|
||||
cipher_type kernelCryptoCipherType
|
||||
}
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoAES128GCM{}
|
||||
|
||||
type kernelCryptoAES128GCM struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_AES_GCM_128_IV_SIZE]byte
|
||||
key [TLS_CIPHER_AES_GCM_128_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_AES_GCM_128_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoAES128GCM) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_AES_GCM_128
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoAES256GCM{}
|
||||
|
||||
type kernelCryptoAES256GCM struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_AES_GCM_256_IV_SIZE]byte
|
||||
key [TLS_CIPHER_AES_GCM_256_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_AES_GCM_256_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoAES256GCM) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_AES_GCM_256
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoAES128CCM{}
|
||||
|
||||
type kernelCryptoAES128CCM struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_AES_CCM_128_IV_SIZE]byte
|
||||
key [TLS_CIPHER_AES_CCM_128_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_AES_CCM_128_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoAES128CCM) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_AES_CCM_128
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoChacha20Poly1035{}
|
||||
|
||||
type kernelCryptoChacha20Poly1035 struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE]byte
|
||||
key [TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoChacha20Poly1035) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_CHACHA20_POLY1305
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
// var _ kernelCrypto = &kernelCryptoSM4GCM{}
|
||||
|
||||
// type kernelCryptoSM4GCM struct {
|
||||
// kernelCryptoInfo
|
||||
// iv [TLS_CIPHER_SM4_GCM_IV_SIZE]byte
|
||||
// key [TLS_CIPHER_SM4_GCM_KEY_SIZE]byte
|
||||
// salt [TLS_CIPHER_SM4_GCM_SALT_SIZE]byte
|
||||
// rec_seq [TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE]byte
|
||||
// }
|
||||
|
||||
// func (crypto *kernelCryptoSM4GCM) String() string {
|
||||
// crypto.cipher_type = TLS_CIPHER_SM4_GCM
|
||||
// return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
// }
|
||||
|
||||
// var _ kernelCrypto = &kernelCryptoSM4CCM{}
|
||||
|
||||
// type kernelCryptoSM4CCM struct {
|
||||
// kernelCryptoInfo
|
||||
// iv [TLS_CIPHER_SM4_CCM_IV_SIZE]byte
|
||||
// key [TLS_CIPHER_SM4_CCM_KEY_SIZE]byte
|
||||
// salt [TLS_CIPHER_SM4_CCM_SALT_SIZE]byte
|
||||
// rec_seq [TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE]byte
|
||||
// }
|
||||
|
||||
// func (crypto *kernelCryptoSM4CCM) String() string {
|
||||
// crypto.cipher_type = TLS_CIPHER_SM4_CCM
|
||||
// return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
// }
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoARIA128GCM{}
|
||||
|
||||
type kernelCryptoARIA128GCM struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_ARIA_GCM_128_IV_SIZE]byte
|
||||
key [TLS_CIPHER_ARIA_GCM_128_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_ARIA_GCM_128_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoARIA128GCM) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_ARIA_GCM_128
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoARIA256GCM{}
|
||||
|
||||
type kernelCryptoARIA256GCM struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_ARIA_GCM_256_IV_SIZE]byte
|
||||
key [TLS_CIPHER_ARIA_GCM_256_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_ARIA_GCM_256_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoARIA256GCM) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_ARIA_GCM_256
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
func kernelCipher(kernel *Support, hc *badtls.RawHalfConn, cipherSuite uint16, isRX bool) kernelCrypto {
|
||||
if !kernel.TLS {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch *hc.Version {
|
||||
case tls.VersionTLS12:
|
||||
if isRX && !kernel.TLS_Version13_RX {
|
||||
return nil
|
||||
}
|
||||
|
||||
case tls.VersionTLS13:
|
||||
if !kernel.TLS_Version13 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if isRX && !kernel.TLS_Version13_RX {
|
||||
return nil
|
||||
}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
var key, iv []byte
|
||||
if *hc.Version == tls.VersionTLS13 {
|
||||
key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), *hc.TrafficSecret)
|
||||
/*if isRX {
|
||||
key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), keyLog.RemoteTrafficSecret)
|
||||
} else {
|
||||
key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), keyLog.TrafficSecret)
|
||||
}*/
|
||||
} else {
|
||||
// csPtr := cipherSuiteByID(cipherSuite)
|
||||
// keysFromMasterSecret(*hc.Version, csPtr, keyLog.Secret, keyLog.Random)
|
||||
return nil
|
||||
}
|
||||
|
||||
switch cipherSuite {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
|
||||
crypto := new(kernelCryptoAES128GCM)
|
||||
|
||||
crypto.version = *hc.Version
|
||||
copy(crypto.key[:], key)
|
||||
copy(crypto.iv[:], iv[4:])
|
||||
copy(crypto.salt[:], iv[:4])
|
||||
crypto.rec_seq = *hc.Seq
|
||||
|
||||
return crypto
|
||||
case tls.TLS_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
|
||||
if !kernel.TLS_AES_256_GCM {
|
||||
return nil
|
||||
}
|
||||
|
||||
crypto := new(kernelCryptoAES256GCM)
|
||||
|
||||
crypto.version = *hc.Version
|
||||
copy(crypto.key[:], key)
|
||||
copy(crypto.iv[:], iv[4:])
|
||||
copy(crypto.salt[:], iv[:4])
|
||||
crypto.rec_seq = *hc.Seq
|
||||
|
||||
return crypto
|
||||
//case tls.TLS_AES_128_CCM_SHA256, tls.TLS_RSA_WITH_AES_128_CCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_SHA256:
|
||||
// if !kernel.TLS_AES_128_CCM {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// crypto := new(kernelCryptoAES128CCM)
|
||||
//
|
||||
// crypto.version = *hc.Version
|
||||
// copy(crypto.key[:], key)
|
||||
// copy(crypto.iv[:], iv[4:])
|
||||
// copy(crypto.salt[:], iv[:4])
|
||||
// crypto.rec_seq = *hc.Seq
|
||||
//
|
||||
// return crypto
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
|
||||
if !kernel.TLS_CHACHA20_POLY1305 {
|
||||
return nil
|
||||
}
|
||||
|
||||
crypto := new(kernelCryptoChacha20Poly1035)
|
||||
|
||||
crypto.version = *hc.Version
|
||||
copy(crypto.key[:], key)
|
||||
copy(crypto.iv[:], iv)
|
||||
crypto.rec_seq = *hc.Seq
|
||||
|
||||
return crypto
|
||||
//case tls.TLS_RSA_WITH_ARIA_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256:
|
||||
// if !kernel.TLS_ARIA_GCM {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// crypto := new(kernelCryptoARIA128GCM)
|
||||
//
|
||||
// crypto.version = *hc.Version
|
||||
// copy(crypto.key[:], key)
|
||||
// copy(crypto.iv[:], iv[4:])
|
||||
// copy(crypto.salt[:], iv[:4])
|
||||
// crypto.rec_seq = *hc.Seq
|
||||
//
|
||||
// return crypto
|
||||
//case tls.TLS_RSA_WITH_ARIA_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384:
|
||||
// if !kernel.TLS_ARIA_GCM {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// crypto := new(kernelCryptoARIA256GCM)
|
||||
//
|
||||
// crypto.version = *hc.Version
|
||||
// copy(crypto.key[:], key)
|
||||
// copy(crypto.iv[:], iv[4:])
|
||||
// copy(crypto.salt[:], iv[:4])
|
||||
// crypto.rec_seq = *hc.Seq
|
||||
//
|
||||
// return crypto
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
67
common/ktls/ktls_close.go
Normal file
67
common/ktls/ktls_close.go
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
if !c.kernelTx {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// Interlock with Conn.Write above.
|
||||
var x int32
|
||||
for {
|
||||
x = c.rawConn.ActiveCall.Load()
|
||||
if x&1 != 0 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if c.rawConn.ActiveCall.CompareAndSwap(x, x|1) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if x != 0 {
|
||||
// io.Writer and io.Closer should not be used concurrently.
|
||||
// If Close is called while a Write is currently in-flight,
|
||||
// interpret that as a sign that this Close is really just
|
||||
// being used to break the Write and/or clean up resources and
|
||||
// avoid sending the alertCloseNotify, which may block
|
||||
// waiting on handshakeMutex or the c.out mutex.
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
var alertErr error
|
||||
if c.rawConn.IsHandshakeComplete.Load() {
|
||||
if err := c.closeNotify(); err != nil {
|
||||
alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return alertErr
|
||||
}
|
||||
|
||||
func (c *Conn) closeNotify() error {
|
||||
c.rawConn.Out.Lock()
|
||||
defer c.rawConn.Out.Unlock()
|
||||
|
||||
if !*c.rawConn.CloseNotifySent {
|
||||
// Set a Write Deadline to prevent possibly blocking forever.
|
||||
c.SetWriteDeadline(time.Now().Add(time.Second * 5))
|
||||
*c.rawConn.CloseNotifyErr = c.sendAlertLocked(alertCloseNotify)
|
||||
*c.rawConn.CloseNotifySent = true
|
||||
// Any subsequent writes will fail.
|
||||
c.SetWriteDeadline(time.Now())
|
||||
}
|
||||
return *c.rawConn.CloseNotifyErr
|
||||
}
|
||||
24
common/ktls/ktls_const.go
Normal file
24
common/ktls/ktls_const.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
const (
|
||||
maxPlaintext = 16384 // maximum plaintext payload length
|
||||
maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
|
||||
maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3
|
||||
recordHeaderLen = 5 // record header length
|
||||
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
|
||||
maxHandshakeCertificateMsg = 262144 // maximum certificate message size (256 KiB)
|
||||
maxUselessRecords = 16 // maximum number of consecutive non-advancing records
|
||||
)
|
||||
|
||||
const (
|
||||
recordTypeChangeCipherSpec = 20
|
||||
recordTypeAlert = 21
|
||||
recordTypeHandshake = 22
|
||||
recordTypeApplicationData = 23
|
||||
)
|
||||
238
common/ktls/ktls_handshake_messages.go
Normal file
238
common/ktls/ktls_handshake_messages.go
Normal file
@@ -0,0 +1,238 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
)
|
||||
|
||||
// The marshalingFunction type is an adapter to allow the use of ordinary
|
||||
// functions as cryptobyte.MarshalingValue.
|
||||
type marshalingFunction func(b *cryptobyte.Builder) error
|
||||
|
||||
func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
|
||||
return f(b)
|
||||
}
|
||||
|
||||
// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
|
||||
// the length of the sequence is not the value specified, it produces an error.
|
||||
func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
|
||||
b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
|
||||
if len(v) != n {
|
||||
return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
|
||||
}
|
||||
b.AddBytes(v)
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
|
||||
func addUint64(b *cryptobyte.Builder, v uint64) {
|
||||
b.AddUint32(uint32(v >> 32))
|
||||
b.AddUint32(uint32(v))
|
||||
}
|
||||
|
||||
// readUint64 decodes a big-endian, 64-bit value into out and advances over it.
|
||||
// It reports whether the read was successful.
|
||||
func readUint64(s *cryptobyte.String, out *uint64) bool {
|
||||
var hi, lo uint32
|
||||
if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
|
||||
return false
|
||||
}
|
||||
*out = uint64(hi)<<32 | uint64(lo)
|
||||
return true
|
||||
}
|
||||
|
||||
// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
|
||||
// []byte instead of a cryptobyte.String.
|
||||
func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
|
||||
return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
|
||||
}
|
||||
|
||||
// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
|
||||
// []byte instead of a cryptobyte.String.
|
||||
func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
|
||||
return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
|
||||
}
|
||||
|
||||
// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
|
||||
// []byte instead of a cryptobyte.String.
|
||||
func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
|
||||
return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
|
||||
}
|
||||
|
||||
type keyUpdateMsg struct {
|
||||
updateRequested bool
|
||||
}
|
||||
|
||||
func (m *keyUpdateMsg) marshal() ([]byte, error) {
|
||||
var b cryptobyte.Builder
|
||||
b.AddUint8(typeKeyUpdate)
|
||||
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
if m.updateRequested {
|
||||
b.AddUint8(1)
|
||||
} else {
|
||||
b.AddUint8(0)
|
||||
}
|
||||
})
|
||||
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func (m *keyUpdateMsg) unmarshal(data []byte) bool {
|
||||
s := cryptobyte.String(data)
|
||||
|
||||
var updateRequested uint8
|
||||
if !s.Skip(4) || // message type and uint24 length field
|
||||
!s.ReadUint8(&updateRequested) || !s.Empty() {
|
||||
return false
|
||||
}
|
||||
switch updateRequested {
|
||||
case 0:
|
||||
m.updateRequested = false
|
||||
case 1:
|
||||
m.updateRequested = true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TLS handshake message types.
|
||||
const (
|
||||
typeHelloRequest uint8 = 0
|
||||
typeClientHello uint8 = 1
|
||||
typeServerHello uint8 = 2
|
||||
typeNewSessionTicket uint8 = 4
|
||||
typeEndOfEarlyData uint8 = 5
|
||||
typeEncryptedExtensions uint8 = 8
|
||||
typeCertificate uint8 = 11
|
||||
typeServerKeyExchange uint8 = 12
|
||||
typeCertificateRequest uint8 = 13
|
||||
typeServerHelloDone uint8 = 14
|
||||
typeCertificateVerify uint8 = 15
|
||||
typeClientKeyExchange uint8 = 16
|
||||
typeFinished uint8 = 20
|
||||
typeCertificateStatus uint8 = 22
|
||||
typeKeyUpdate uint8 = 24
|
||||
typeCompressedCertificate uint8 = 25
|
||||
typeMessageHash uint8 = 254 // synthetic message
|
||||
)
|
||||
|
||||
// TLS compression types.
|
||||
const (
|
||||
compressionNone uint8 = 0
|
||||
)
|
||||
|
||||
// TLS extension numbers
|
||||
const (
|
||||
extensionServerName uint16 = 0
|
||||
extensionStatusRequest uint16 = 5
|
||||
extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7
|
||||
extensionSupportedPoints uint16 = 11
|
||||
extensionSignatureAlgorithms uint16 = 13
|
||||
extensionALPN uint16 = 16
|
||||
extensionSCT uint16 = 18
|
||||
extensionPadding uint16 = 21
|
||||
extensionExtendedMasterSecret uint16 = 23
|
||||
extensionCompressCertificate uint16 = 27 // compress_certificate in TLS 1.3
|
||||
extensionSessionTicket uint16 = 35
|
||||
extensionPreSharedKey uint16 = 41
|
||||
extensionEarlyData uint16 = 42
|
||||
extensionSupportedVersions uint16 = 43
|
||||
extensionCookie uint16 = 44
|
||||
extensionPSKModes uint16 = 45
|
||||
extensionCertificateAuthorities uint16 = 47
|
||||
extensionSignatureAlgorithmsCert uint16 = 50
|
||||
extensionKeyShare uint16 = 51
|
||||
extensionQUICTransportParameters uint16 = 57
|
||||
extensionALPS uint16 = 17513
|
||||
extensionRenegotiationInfo uint16 = 0xff01
|
||||
extensionECHOuterExtensions uint16 = 0xfd00
|
||||
extensionEncryptedClientHello uint16 = 0xfe0d
|
||||
)
|
||||
|
||||
type handshakeMessage interface {
|
||||
marshal() ([]byte, error)
|
||||
unmarshal([]byte) bool
|
||||
}
|
||||
type newSessionTicketMsgTLS13 struct {
|
||||
lifetime uint32
|
||||
ageAdd uint32
|
||||
nonce []byte
|
||||
label []byte
|
||||
maxEarlyData uint32
|
||||
}
|
||||
|
||||
func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
|
||||
var b cryptobyte.Builder
|
||||
b.AddUint8(typeNewSessionTicket)
|
||||
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddUint32(m.lifetime)
|
||||
b.AddUint32(m.ageAdd)
|
||||
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(m.nonce)
|
||||
})
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(m.label)
|
||||
})
|
||||
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
if m.maxEarlyData > 0 {
|
||||
b.AddUint16(extensionEarlyData)
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddUint32(m.maxEarlyData)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
|
||||
*m = newSessionTicketMsgTLS13{}
|
||||
s := cryptobyte.String(data)
|
||||
|
||||
var extensions cryptobyte.String
|
||||
if !s.Skip(4) || // message type and uint24 length field
|
||||
!s.ReadUint32(&m.lifetime) ||
|
||||
!s.ReadUint32(&m.ageAdd) ||
|
||||
!readUint8LengthPrefixed(&s, &m.nonce) ||
|
||||
!readUint16LengthPrefixed(&s, &m.label) ||
|
||||
!s.ReadUint16LengthPrefixed(&extensions) ||
|
||||
!s.Empty() {
|
||||
return false
|
||||
}
|
||||
|
||||
for !extensions.Empty() {
|
||||
var extension uint16
|
||||
var extData cryptobyte.String
|
||||
if !extensions.ReadUint16(&extension) ||
|
||||
!extensions.ReadUint16LengthPrefixed(&extData) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch extension {
|
||||
case extensionEarlyData:
|
||||
if !extData.ReadUint32(&m.maxEarlyData) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
// Ignore unknown extensions.
|
||||
continue
|
||||
}
|
||||
|
||||
if !extData.Empty() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
173
common/ktls/ktls_key_update.go
Normal file
173
common/ktls/ktls_key_update.go
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// handlePostHandshakeMessage processes a handshake message arrived after the
|
||||
// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation.
|
||||
func (c *Conn) handlePostHandshakeMessage() error {
|
||||
if *c.rawConn.Vers != tls.VersionTLS13 {
|
||||
return errors.New("ktls: kernel does not support TLS 1.2 renegotiation")
|
||||
}
|
||||
|
||||
msg, err := c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//c.retryCount++
|
||||
//if c.retryCount > maxUselessRecords {
|
||||
// c.sendAlert(alertUnexpectedMessage)
|
||||
// return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
|
||||
//}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *newSessionTicketMsgTLS13:
|
||||
// return errors.New("ktls: received new session ticket")
|
||||
return nil
|
||||
case *keyUpdateMsg:
|
||||
return c.handleKeyUpdate(msg)
|
||||
}
|
||||
// The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
|
||||
// as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
|
||||
// unexpected_message alert here doesn't provide it with enough information to distinguish
|
||||
// this condition from other unexpected messages. This is probably fine.
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
|
||||
}
|
||||
|
||||
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
|
||||
//if c.quic != nil {
|
||||
// c.sendAlert(alertUnexpectedMessage)
|
||||
// return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
|
||||
//}
|
||||
|
||||
cipherSuite := cipherSuiteTLS13ByID(*c.rawConn.CipherSuite)
|
||||
if cipherSuite == nil {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertInternalError))
|
||||
}
|
||||
|
||||
newSecret := nextTrafficSecret(cipherSuite, *c.rawConn.In.TrafficSecret)
|
||||
c.rawConn.In.SetTrafficSecret(cipherSuite, 0 /*tls.QUICEncryptionLevelInitial*/, newSecret)
|
||||
|
||||
err := c.resetupRX()
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return c.rawConn.In.SetErrorLocked(fmt.Errorf("ktls: resetupRX failed: %w", err))
|
||||
}
|
||||
|
||||
if keyUpdate.updateRequested {
|
||||
c.rawConn.Out.Lock()
|
||||
defer c.rawConn.Out.Unlock()
|
||||
|
||||
resetup, err := c.resetupTX()
|
||||
if err != nil {
|
||||
c.sendAlertLocked(alertInternalError)
|
||||
return c.rawConn.Out.SetErrorLocked(fmt.Errorf("ktls: resetupTX failed: %w", err))
|
||||
}
|
||||
|
||||
msg := &keyUpdateMsg{}
|
||||
msgBytes, err := msg.marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
|
||||
if err != nil {
|
||||
// Surface the error at the next write.
|
||||
c.rawConn.Out.SetErrorLocked(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
newSecret := nextTrafficSecret(cipherSuite, *c.rawConn.Out.TrafficSecret)
|
||||
c.rawConn.Out.SetTrafficSecret(cipherSuite, 0 /*QUICEncryptionLevelInitial*/, newSecret)
|
||||
|
||||
err = resetup()
|
||||
if err != nil {
|
||||
return c.rawConn.Out.SetErrorLocked(fmt.Errorf("ktls: resetupTX failed: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) readHandshakeBytes(n int) error {
|
||||
//if c.quic != nil {
|
||||
// return c.quicReadHandshakeBytes(n)
|
||||
//}
|
||||
for c.rawConn.Hand.Len() < n {
|
||||
if err := c.readRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) readHandshake(transcript io.Writer) (any, error) {
|
||||
if err := c.readHandshakeBytes(4); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := c.rawConn.Hand.Bytes()
|
||||
|
||||
maxHandshakeSize := maxHandshake
|
||||
// hasVers indicates we're past the first message, forcing someone trying to
|
||||
// make us just allocate a large buffer to at least do the initial part of
|
||||
// the handshake first.
|
||||
//if c.haveVers && data[0] == typeCertificate {
|
||||
// Since certificate messages are likely to be the only messages that
|
||||
// can be larger than maxHandshake, we use a special limit for just
|
||||
// those messages.
|
||||
//maxHandshakeSize = maxHandshakeCertificateMsg
|
||||
//}
|
||||
|
||||
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
||||
if n > maxHandshakeSize {
|
||||
c.sendAlertLocked(alertInternalError)
|
||||
return nil, c.rawConn.In.SetErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
|
||||
}
|
||||
if err := c.readHandshakeBytes(4 + n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data = c.rawConn.Hand.Next(4 + n)
|
||||
return c.unmarshalHandshakeMessage(data, transcript)
|
||||
}
|
||||
|
||||
func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript io.Writer) (any, error) {
|
||||
var m handshakeMessage
|
||||
switch data[0] {
|
||||
case typeNewSessionTicket:
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 {
|
||||
m = new(newSessionTicketMsgTLS13)
|
||||
} else {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
case typeKeyUpdate:
|
||||
m = new(keyUpdateMsg)
|
||||
default:
|
||||
return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
|
||||
// The handshake message unmarshalers
|
||||
// expect to be able to keep references to data,
|
||||
// so pass in a fresh copy that won't be overwritten.
|
||||
data = append([]byte(nil), data...)
|
||||
|
||||
if !m.unmarshal(data) {
|
||||
return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecodeError))
|
||||
}
|
||||
|
||||
if transcript != nil {
|
||||
transcript.Write(data)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
329
common/ktls/ktls_linux.go
Normal file
329
common/ktls/ktls_linux.go
Normal file
@@ -0,0 +1,329 @@
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing-box/common/badversion"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/shell"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// mod from https://gitlab.com/go-extension/tls
|
||||
|
||||
const (
|
||||
TLS_TX = 1
|
||||
TLS_RX = 2
|
||||
TLS_TX_ZEROCOPY_RO = 3 // TX zerocopy (only sendfile now)
|
||||
TLS_RX_EXPECT_NO_PAD = 4 // Attempt opportunistic zero-copy, TLS 1.3 only
|
||||
|
||||
TLS_SET_RECORD_TYPE = 1
|
||||
TLS_GET_RECORD_TYPE = 2
|
||||
)
|
||||
|
||||
type Support struct {
|
||||
TLS, TLS_RX bool
|
||||
TLS_Version13, TLS_Version13_RX bool
|
||||
|
||||
TLS_TX_ZEROCOPY bool
|
||||
TLS_RX_NOPADDING bool
|
||||
|
||||
TLS_AES_256_GCM bool
|
||||
TLS_AES_128_CCM bool
|
||||
TLS_CHACHA20_POLY1305 bool
|
||||
TLS_SM4 bool
|
||||
TLS_ARIA_GCM bool
|
||||
|
||||
TLS_Version13_KeyUpdate bool
|
||||
}
|
||||
|
||||
var KernelSupport = sync.OnceValues(func() (*Support, error) {
|
||||
var uname unix.Utsname
|
||||
err := unix.Uname(&uname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kernelVersion := badversion.Parse(strings.Trim(string(uname.Release[:]), "\x00"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var support Support
|
||||
switch {
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 6, Minor: 14}):
|
||||
support.TLS_Version13_KeyUpdate = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 6, Minor: 1}):
|
||||
support.TLS_ARIA_GCM = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 6}):
|
||||
support.TLS_Version13_RX = true
|
||||
support.TLS_RX_NOPADDING = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 19}):
|
||||
support.TLS_TX_ZEROCOPY = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 16}):
|
||||
support.TLS_SM4 = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 11}):
|
||||
support.TLS_CHACHA20_POLY1305 = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 2}):
|
||||
support.TLS_AES_128_CCM = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 1}):
|
||||
support.TLS_AES_256_GCM = true
|
||||
support.TLS_Version13 = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 4, Minor: 17}):
|
||||
support.TLS_RX = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 4, Minor: 13}):
|
||||
support.TLS = true
|
||||
}
|
||||
|
||||
if support.TLS && support.TLS_Version13 {
|
||||
_, err := os.Stat("/sys/module/tls")
|
||||
if err != nil {
|
||||
if os.Getuid() == 0 {
|
||||
output, err := shell.Exec("modprobe", "tls").Read()
|
||||
if err != nil {
|
||||
return nil, E.Extend(E.Cause(err, "modprobe tls"), output)
|
||||
}
|
||||
} else {
|
||||
return nil, E.New("ktls: kernel TLS module not loaded")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &support, nil
|
||||
})
|
||||
|
||||
func Load() error {
|
||||
support, err := KernelSupport()
|
||||
if err != nil {
|
||||
return E.Cause(err, "ktls: check availability")
|
||||
}
|
||||
if !support.TLS || !support.TLS_Version13 {
|
||||
return E.New("ktls: kernel does not support TLS 1.3")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) setupKernel(txOffload, rxOffload bool) error {
|
||||
if !txOffload && !rxOffload {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
support, err := KernelSupport()
|
||||
if err != nil {
|
||||
return E.Cause(err, "check availability")
|
||||
}
|
||||
if !support.TLS || !support.TLS_Version13 {
|
||||
return E.New("kernel does not support TLS 1.3")
|
||||
}
|
||||
c.rawConn.Out.Lock()
|
||||
defer c.rawConn.Out.Unlock()
|
||||
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptString(int(fd), unix.SOL_TCP, unix.TCP_ULP, "tls")
|
||||
})
|
||||
if err != nil {
|
||||
return os.NewSyscallError("setsockopt", err)
|
||||
}
|
||||
|
||||
if txOffload {
|
||||
txCrypto := kernelCipher(support, c.rawConn.Out, *c.rawConn.CipherSuite, false)
|
||||
if txCrypto == nil {
|
||||
return E.New("unsupported cipher suite")
|
||||
}
|
||||
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, txCrypto.String())
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if support.TLS_TX_ZEROCOPY {
|
||||
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_TX_ZEROCOPY_RO, 1)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.kernelTx = true
|
||||
c.logger.DebugContext(c.ctx, "ktls: kernel TLS TX enabled")
|
||||
}
|
||||
|
||||
if rxOffload {
|
||||
rxCrypto := kernelCipher(support, c.rawConn.In, *c.rawConn.CipherSuite, true)
|
||||
if rxCrypto == nil {
|
||||
return E.New("unsupported cipher suite")
|
||||
}
|
||||
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, rxCrypto.String())
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *c.rawConn.Vers >= tls.VersionTLS13 && support.TLS_RX_NOPADDING {
|
||||
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_RX_EXPECT_NO_PAD, 1)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.kernelRx = true
|
||||
c.logger.DebugContext(c.ctx, "ktls: kernel TLS RX enabled")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) resetupTX() (func() error, error) {
|
||||
if !c.kernelTx {
|
||||
return nil, nil
|
||||
}
|
||||
support, err := KernelSupport()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !support.TLS_Version13_KeyUpdate {
|
||||
return nil, errors.New("ktls: kernel does not support rekey")
|
||||
}
|
||||
txCrypto := kernelCipher(support, c.rawConn.Out, *c.rawConn.CipherSuite, false)
|
||||
if txCrypto == nil {
|
||||
return nil, errors.New("ktls: set kernelCipher on unsupported tls session")
|
||||
}
|
||||
return func() error {
|
||||
return control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, txCrypto.String())
|
||||
})
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) resetupRX() error {
|
||||
if !c.kernelRx {
|
||||
return nil
|
||||
}
|
||||
support, err := KernelSupport()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !support.TLS_Version13_KeyUpdate {
|
||||
return errors.New("ktls: kernel does not support rekey")
|
||||
}
|
||||
rxCrypto := kernelCipher(support, c.rawConn.In, *c.rawConn.CipherSuite, true)
|
||||
if rxCrypto == nil {
|
||||
return errors.New("ktls: set kernelCipher on unsupported tls session")
|
||||
}
|
||||
return control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, rxCrypto.String())
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) readKernelRecord() (uint8, []byte, error) {
|
||||
if c.rawConn.RawInput.Len() < maxPlaintext {
|
||||
c.rawConn.RawInput.Grow(maxPlaintext - c.rawConn.RawInput.Len())
|
||||
}
|
||||
|
||||
data := c.rawConn.RawInput.Bytes()[:maxPlaintext]
|
||||
|
||||
// cmsg for record type
|
||||
buffer := make([]byte, unix.CmsgSpace(1))
|
||||
cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0]))
|
||||
cmsg.SetLen(unix.CmsgLen(1))
|
||||
|
||||
var iov unix.Iovec
|
||||
iov.Base = &data[0]
|
||||
iov.SetLen(len(data))
|
||||
|
||||
var msg unix.Msghdr
|
||||
msg.Control = &buffer[0]
|
||||
msg.Controllen = cmsg.Len
|
||||
msg.Iov = &iov
|
||||
msg.Iovlen = 1
|
||||
|
||||
var n int
|
||||
var err error
|
||||
er := c.rawSyscallConn.Read(func(fd uintptr) bool {
|
||||
n, err = recvmsg(int(fd), &msg, 0)
|
||||
return err != unix.EAGAIN
|
||||
})
|
||||
if er != nil {
|
||||
return 0, nil, er
|
||||
}
|
||||
switch err {
|
||||
case nil:
|
||||
case syscall.EINVAL:
|
||||
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertProtocolVersion))
|
||||
case syscall.EMSGSIZE:
|
||||
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow))
|
||||
case syscall.EBADMSG:
|
||||
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecryptError))
|
||||
default:
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if n <= 0 {
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
|
||||
if cmsg.Level == unix.SOL_TLS && cmsg.Type == TLS_GET_RECORD_TYPE {
|
||||
typ := buffer[unix.CmsgLen(0)]
|
||||
return typ, data[:n], nil
|
||||
}
|
||||
|
||||
return recordTypeApplicationData, data[:n], nil
|
||||
}
|
||||
|
||||
func (c *Conn) writeKernelRecord(typ uint16, data []byte) (int, error) {
|
||||
if typ == recordTypeApplicationData {
|
||||
return c.conn.Write(data)
|
||||
}
|
||||
|
||||
// cmsg for record type
|
||||
buffer := make([]byte, unix.CmsgSpace(1))
|
||||
cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0]))
|
||||
cmsg.SetLen(unix.CmsgLen(1))
|
||||
buffer[unix.CmsgLen(0)] = byte(typ)
|
||||
cmsg.Level = unix.SOL_TLS
|
||||
cmsg.Type = TLS_SET_RECORD_TYPE
|
||||
|
||||
var iov unix.Iovec
|
||||
iov.Base = &data[0]
|
||||
iov.SetLen(len(data))
|
||||
|
||||
var msg unix.Msghdr
|
||||
msg.Control = &buffer[0]
|
||||
msg.Controllen = cmsg.Len
|
||||
msg.Iov = &iov
|
||||
msg.Iovlen = 1
|
||||
|
||||
var n int
|
||||
var err error
|
||||
ew := c.rawSyscallConn.Write(func(fd uintptr) bool {
|
||||
n, err = sendmsg(int(fd), &msg, 0)
|
||||
return err != unix.EAGAIN
|
||||
})
|
||||
if ew != nil {
|
||||
return 0, ew
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
//go:linkname recvmsg golang.org/x/sys/unix.recvmsg
|
||||
func recvmsg(fd int, msg *unix.Msghdr, flags int) (n int, err error)
|
||||
|
||||
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
|
||||
func sendmsg(fd int, msg *unix.Msghdr, flags int) (n int, err error)
|
||||
24
common/ktls/ktls_prf.go
Normal file
24
common/ktls/ktls_prf.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import "unsafe"
|
||||
|
||||
//go:linkname cipherSuiteByID github.com/metacubex/utls.cipherSuiteByID
|
||||
func cipherSuiteByID(id uint16) unsafe.Pointer
|
||||
|
||||
//go:linkname keysFromMasterSecret github.com/metacubex/utls.keysFromMasterSecret
|
||||
func keysFromMasterSecret(version uint16, suite unsafe.Pointer, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte)
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID github.com/metacubex/utls.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) unsafe.Pointer
|
||||
|
||||
//go:linkname nextTrafficSecret github.com/metacubex/utls.(*cipherSuiteTLS13).nextTrafficSecret
|
||||
func nextTrafficSecret(cs unsafe.Pointer, trafficSecret []byte) []byte
|
||||
|
||||
//go:linkname trafficKey github.com/metacubex/utls.(*cipherSuiteTLS13).trafficKey
|
||||
func trafficKey(cs unsafe.Pointer, trafficSecret []byte) (key, iv []byte)
|
||||
292
common/ktls/ktls_read.go
Normal file
292
common/ktls/ktls_read.go
Normal file
@@ -0,0 +1,292 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
if !c.kernelRx {
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
// Put this after Handshake, in case people were calling
|
||||
// Read(nil) for the side effect of the Handshake.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
c.rawConn.In.Lock()
|
||||
defer c.rawConn.In.Unlock()
|
||||
|
||||
for c.rawConn.Input.Len() == 0 {
|
||||
if err := c.readRecord(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for c.rawConn.Hand.Len() > 0 {
|
||||
if err := c.handlePostHandshakeMessage(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
n, _ := c.rawConn.Input.Read(b)
|
||||
|
||||
// If a close-notify alert is waiting, read it so that we can return (n,
|
||||
// EOF) instead of (n, nil), to signal to the HTTP response reading
|
||||
// goroutine that the connection is now closed. This eliminates a race
|
||||
// where the HTTP response reading goroutine would otherwise not observe
|
||||
// the EOF until its next read, by which time a client goroutine might
|
||||
// have already tried to reuse the HTTP connection for a new request.
|
||||
// See https://golang.org/cl/76400046 and https://golang.org/issue/3514
|
||||
if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.RawInput.Len() > 0 &&
|
||||
c.rawConn.RawInput.Bytes()[0] == recordTypeAlert {
|
||||
if err := c.readRecord(); err != nil {
|
||||
return n, err // will be io.EOF on closeNotify
|
||||
}
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) readRecord() error {
|
||||
if *c.rawConn.In.Err != nil {
|
||||
return *c.rawConn.In.Err
|
||||
}
|
||||
|
||||
typ, data, err := c.readRawRecord()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(data) > maxPlaintext {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow))
|
||||
}
|
||||
|
||||
// Application Data messages are always protected.
|
||||
if c.rawConn.In.Cipher == nil && typ == recordTypeApplicationData {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
|
||||
//if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
|
||||
// This is a state-advancing message: reset the retry count.
|
||||
// c.retryCount = 0
|
||||
//}
|
||||
|
||||
// Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 && typ != recordTypeHandshake && c.rawConn.Hand.Len() > 0 {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
|
||||
switch typ {
|
||||
default:
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
case recordTypeAlert:
|
||||
//if c.quic != nil {
|
||||
// return c.rawConn.In.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
//}
|
||||
if len(data) != 2 {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
if data[1] == alertCloseNotify {
|
||||
return c.rawConn.In.SetErrorLocked(io.EOF)
|
||||
}
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 {
|
||||
// TLS 1.3 removed warning-level alerts except for alertUserCanceled
|
||||
// (RFC 8446, § 6.1). Since at least one major implementation
|
||||
// (https://bugs.openjdk.org/browse/JDK-8323517) misuses this alert,
|
||||
// many TLS stacks now ignore it outright when seen in a TLS 1.3
|
||||
// handshake (e.g. BoringSSL, NSS, Rustls).
|
||||
if data[1] == alertUserCanceled {
|
||||
// Like TLS 1.2 alertLevelWarning alerts, we drop the record and retry.
|
||||
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
|
||||
}
|
||||
return c.rawConn.In.SetErrorLocked(&net.OpError{Op: "remote error", Err: tls.AlertError(data[1])})
|
||||
}
|
||||
switch data[0] {
|
||||
case alertLevelWarning:
|
||||
// Drop the record on the floor and retry.
|
||||
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
|
||||
case alertLevelError:
|
||||
return c.rawConn.In.SetErrorLocked(&net.OpError{Op: "remote error", Err: tls.AlertError(data[1])})
|
||||
default:
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
|
||||
case recordTypeChangeCipherSpec:
|
||||
if len(data) != 1 || data[0] != 1 {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecodeError))
|
||||
}
|
||||
// Handshake messages are not allowed to fragment across the CCS.
|
||||
if c.rawConn.Hand.Len() > 0 {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
// In TLS 1.3, change_cipher_spec records are ignored until the
|
||||
// Finished. See RFC 8446, Appendix D.4. Note that according to Section
|
||||
// 5, a server can send a ChangeCipherSpec before its ServerHello, when
|
||||
// c.vers is still unset. That's not useful though and suspicious if the
|
||||
// server then selects a lower protocol version, so don't allow that.
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 {
|
||||
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
|
||||
}
|
||||
// if !expectChangeCipherSpec {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
//}
|
||||
//if err := c.rawConn.In.changeCipherSpec(); err != nil {
|
||||
// return c.rawConn.In.setErrorLocked(c.sendAlert(err.(alert)))
|
||||
//}
|
||||
|
||||
case recordTypeApplicationData:
|
||||
// Some OpenSSL servers send empty records in order to randomize the
|
||||
// CBC RawIV. Ignore a limited number of empty records.
|
||||
if len(data) == 0 {
|
||||
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
|
||||
}
|
||||
// Note that data is owned by c.rawInput, following the Next call above,
|
||||
// to avoid copying the plaintext. This is safe because c.rawInput is
|
||||
// not read from or written to until c.input is drained.
|
||||
c.rawConn.Input.Reset(data)
|
||||
case recordTypeHandshake:
|
||||
if len(data) == 0 {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
c.rawConn.Hand.Write(data)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
func (c *Conn) readRawRecord() (typ uint8, data []byte, err error) {
|
||||
// Read from kernel.
|
||||
if c.kernelRx {
|
||||
return c.readKernelRecord()
|
||||
}
|
||||
|
||||
// Read header, payload.
|
||||
if err = c.readFromUntil(c.conn, recordHeaderLen); err != nil {
|
||||
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
|
||||
// is an error, but popular web sites seem to do this, so we accept it
|
||||
// if and only if at the record boundary.
|
||||
if err == io.ErrUnexpectedEOF && c.rawConn.RawInput.Len() == 0 {
|
||||
err = io.EOF
|
||||
}
|
||||
if e, ok := err.(net.Error); !ok || !e.Temporary() {
|
||||
c.rawConn.In.SetErrorLocked(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
hdr := c.rawConn.RawInput.Bytes()[:recordHeaderLen]
|
||||
typ = hdr[0]
|
||||
|
||||
vers := uint16(hdr[1])<<8 | uint16(hdr[2])
|
||||
expectedVers := *c.rawConn.Vers
|
||||
if expectedVers == tls.VersionTLS13 {
|
||||
// All TLS 1.3 records are expected to have 0x0303 (1.2) after
|
||||
// the initial hello (RFC 8446 Section 5.1).
|
||||
expectedVers = tls.VersionTLS12
|
||||
}
|
||||
n := int(hdr[3])<<8 | int(hdr[4])
|
||||
if /*c.haveVers && */ vers != expectedVers {
|
||||
c.sendAlert(alertProtocolVersion)
|
||||
msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
|
||||
err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(nil, msg))
|
||||
return
|
||||
}
|
||||
//if !c.haveVers {
|
||||
// // First message, be extra suspicious: this might not be a TLS
|
||||
// // client. Bail out before reading a full 'body', if possible.
|
||||
// // The current max version is 3.3 so if the version is >= 16.0,
|
||||
// // it's probably not real.
|
||||
// if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
|
||||
// err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
|
||||
// return
|
||||
// }
|
||||
//}
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
|
||||
c.sendAlert(alertRecordOverflow)
|
||||
msg := fmt.Sprintf("oversized record received with length %d", n)
|
||||
err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(nil, msg))
|
||||
return
|
||||
}
|
||||
if err = c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
|
||||
if e, ok := err.(net.Error); !ok || !e.Temporary() {
|
||||
c.rawConn.In.SetErrorLocked(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Process message.
|
||||
record := c.rawConn.RawInput.Next(recordHeaderLen + n)
|
||||
data, typ, err = c.rawConn.In.Decrypt(record)
|
||||
if err != nil {
|
||||
err = c.rawConn.In.SetErrorLocked(c.sendAlert(uint8(err.(tls.AlertError))))
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like
|
||||
// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3.
|
||||
func (c *Conn) retryReadRecord( /*expectChangeCipherSpec bool*/ ) error {
|
||||
//c.retryCount++
|
||||
//if c.retryCount > maxUselessRecords {
|
||||
// c.sendAlert(alertUnexpectedMessage)
|
||||
// return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
|
||||
//}
|
||||
return c.readRecord( /*expectChangeCipherSpec*/ )
|
||||
}
|
||||
|
||||
// atLeastReader reads from R, stopping with EOF once at least N bytes have been
|
||||
// read. It is different from an io.LimitedReader in that it doesn't cut short
|
||||
// the last Read call, and in that it considers an early EOF an error.
|
||||
type atLeastReader struct {
|
||||
R io.Reader
|
||||
N int64
|
||||
}
|
||||
|
||||
func (r *atLeastReader) Read(p []byte) (int, error) {
|
||||
if r.N <= 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, err := r.R.Read(p)
|
||||
r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809
|
||||
if r.N > 0 && err == io.EOF {
|
||||
return n, io.ErrUnexpectedEOF
|
||||
}
|
||||
if r.N <= 0 && err == nil {
|
||||
return n, io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// readFromUntil reads from r into c.rawConn.RawInput until c.rawConn.RawInput contains
|
||||
// at least n bytes or else returns an error.
|
||||
func (c *Conn) readFromUntil(r io.Reader, n int) error {
|
||||
if c.rawConn.RawInput.Len() >= n {
|
||||
return nil
|
||||
}
|
||||
needs := n - c.rawConn.RawInput.Len()
|
||||
// There might be extra input waiting on the wire. Make a best effort
|
||||
// attempt to fetch it so that it can be used in (*Conn).Read to
|
||||
// "predict" closeNotify alerts.
|
||||
c.rawConn.RawInput.Grow(needs + bytes.MinRead)
|
||||
_, err := c.rawConn.RawInput.ReadFrom(&atLeastReader{r, int64(needs)})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err tls.RecordHeaderError) {
|
||||
err.Msg = msg
|
||||
err.Conn = conn
|
||||
copy(err.RecordHeader[:], c.rawConn.RawInput.Bytes())
|
||||
return err
|
||||
}
|
||||
41
common/ktls/ktls_read_wait.go
Normal file
41
common/ktls/ktls_read_wait.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
c.readWaitOptions = options
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Conn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
c.rawConn.In.Lock()
|
||||
defer c.rawConn.In.Unlock()
|
||||
for c.rawConn.Input.Len() == 0 {
|
||||
err = c.readRecord()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
buffer = c.readWaitOptions.NewBuffer()
|
||||
n, err := c.rawConn.Input.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.Input.Len() > 0 &&
|
||||
c.rawConn.RawInput.Bytes()[0] == recordTypeAlert {
|
||||
_ = c.rawConn.ReadRecord()
|
||||
}
|
||||
c.readWaitOptions.PostReturn(buffer)
|
||||
return
|
||||
}
|
||||
15
common/ktls/ktls_stub.go
Normal file
15
common/ktls/ktls_stub.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !linux || !go1.25 || without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
func NewConn(ctx context.Context, logger logger.ContextLogger, conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
154
common/ktls/ktls_write.go
Normal file
154
common/ktls/ktls_write.go
Normal file
@@ -0,0 +1,154 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && !without_badtls
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
if !c.kernelTx {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
// interlock with Close below
|
||||
for {
|
||||
x := c.rawConn.ActiveCall.Load()
|
||||
if x&1 != 0 {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
if c.rawConn.ActiveCall.CompareAndSwap(x, x+2) {
|
||||
break
|
||||
}
|
||||
}
|
||||
defer c.rawConn.ActiveCall.Add(-2)
|
||||
|
||||
//if err := c.Conn.HandshakeContext(context.Background()); err != nil {
|
||||
// return 0, err
|
||||
//}
|
||||
|
||||
c.rawConn.Out.Lock()
|
||||
defer c.rawConn.Out.Unlock()
|
||||
|
||||
if err := *c.rawConn.Out.Err; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if !c.rawConn.IsHandshakeComplete.Load() {
|
||||
return 0, tls.AlertError(alertInternalError)
|
||||
}
|
||||
|
||||
if *c.rawConn.CloseNotifySent {
|
||||
// return 0, errShutdown
|
||||
return 0, errors.New("tls: protocol is shutdown")
|
||||
}
|
||||
|
||||
// TLS 1.0 is susceptible to a chosen-plaintext
|
||||
// attack when using block mode ciphers due to predictable IVs.
|
||||
// This can be prevented by splitting each Application Data
|
||||
// record into two records, effectively randomizing the RawIV.
|
||||
//
|
||||
// https://www.openssl.org/~bodo/tls-cbc.txt
|
||||
// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
|
||||
// https://www.imperialviolet.org/2012/01/15/beastfollowup.html
|
||||
|
||||
var m int
|
||||
if len(b) > 1 && *c.rawConn.Vers == tls.VersionTLS10 {
|
||||
if _, ok := (*c.rawConn.Out.Cipher).(cipher.BlockMode); ok {
|
||||
n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
|
||||
if err != nil {
|
||||
return n, c.rawConn.Out.SetErrorLocked(err)
|
||||
}
|
||||
m, b = 1, b[1:]
|
||||
}
|
||||
}
|
||||
|
||||
n, err := c.writeRecordLocked(recordTypeApplicationData, b)
|
||||
return n + m, c.rawConn.Out.SetErrorLocked(err)
|
||||
}
|
||||
|
||||
func (c *Conn) writeRecordLocked(typ uint16, data []byte) (n int, err error) {
|
||||
if !c.kernelTx {
|
||||
return c.rawConn.WriteRecordLocked(typ, data)
|
||||
}
|
||||
/*for len(data) > 0 {
|
||||
m := len(data)
|
||||
if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
|
||||
m = maxPayload
|
||||
}
|
||||
_, err = c.writeKernelRecord(typ, data[:m])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += m
|
||||
data = data[m:]
|
||||
}*/
|
||||
return c.writeKernelRecord(typ, data)
|
||||
}
|
||||
|
||||
const (
|
||||
// tcpMSSEstimate is a conservative estimate of the TCP maximum segment
|
||||
// size (MSS). A constant is used, rather than querying the kernel for
|
||||
// the actual MSS, to avoid complexity. The value here is the IPv6
|
||||
// minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40
|
||||
// bytes) and a TCP header with timestamps (32 bytes).
|
||||
tcpMSSEstimate = 1208
|
||||
|
||||
// recordSizeBoostThreshold is the number of bytes of application data
|
||||
// sent after which the TLS record size will be increased to the
|
||||
// maximum.
|
||||
recordSizeBoostThreshold = 128 * 1024
|
||||
)
|
||||
|
||||
func (c *Conn) maxPayloadSizeForWrite(typ uint16) int {
|
||||
if /*c.config.DynamicRecordSizingDisabled ||*/ typ != recordTypeApplicationData {
|
||||
return maxPlaintext
|
||||
}
|
||||
|
||||
if *c.rawConn.PacketsSent >= recordSizeBoostThreshold {
|
||||
return maxPlaintext
|
||||
}
|
||||
|
||||
// Subtract TLS overheads to get the maximum payload size.
|
||||
payloadBytes := tcpMSSEstimate - recordHeaderLen - c.rawConn.Out.ExplicitNonceLen()
|
||||
if rawCipher := *c.rawConn.Out.Cipher; rawCipher != nil {
|
||||
switch ciph := rawCipher.(type) {
|
||||
case cipher.Stream:
|
||||
payloadBytes -= (*c.rawConn.Out.Mac).Size()
|
||||
case cipher.AEAD:
|
||||
payloadBytes -= ciph.Overhead()
|
||||
/*case cbcMode:
|
||||
blockSize := ciph.BlockSize()
|
||||
// The payload must fit in a multiple of blockSize, with
|
||||
// room for at least one padding byte.
|
||||
payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
|
||||
// The RawMac is appended before padding so affects the
|
||||
// payload size directly.
|
||||
payloadBytes -= c.out.mac.Size()*/
|
||||
default:
|
||||
panic("unknown cipher type")
|
||||
}
|
||||
}
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 {
|
||||
payloadBytes-- // encrypted ContentType
|
||||
}
|
||||
|
||||
// Allow packet growth in arithmetic progression up to max.
|
||||
pkt := *c.rawConn.PacketsSent
|
||||
*c.rawConn.PacketsSent++
|
||||
if pkt > 1000 {
|
||||
return maxPlaintext // avoid overflow in multiply below
|
||||
}
|
||||
|
||||
n := payloadBytes * int(pkt+1)
|
||||
if n > maxPlaintext {
|
||||
n = maxPlaintext
|
||||
}
|
||||
return n
|
||||
}
|
||||
Reference in New Issue
Block a user