First Commmit
This commit is contained in:
299
experimental/cachefile/dns_cache.go
Normal file
299
experimental/cachefile/dns_cache.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package cachefile
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/bbolt"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
var bucketDNSCache = []byte("dns_cache")
|
||||
|
||||
func (c *CacheFile) StoreDNS() bool {
|
||||
return c.storeDNS
|
||||
}
|
||||
|
||||
func (c *CacheFile) LoadDNSCache(transportName string, qName string, qType uint16) (rawMessage []byte, expireAt time.Time, loaded bool) {
|
||||
c.saveDNSCacheAccess.RLock()
|
||||
entry, cached := c.saveDNSCache[saveCacheKey{transportName, qName, qType}]
|
||||
c.saveDNSCacheAccess.RUnlock()
|
||||
if cached {
|
||||
return entry.rawMessage, entry.expireAt, true
|
||||
}
|
||||
key := buf.Get(2 + len(qName))
|
||||
binary.BigEndian.PutUint16(key, qType)
|
||||
copy(key[2:], qName)
|
||||
defer buf.Put(key)
|
||||
err := c.view(func(tx *bbolt.Tx) error {
|
||||
bucket := c.bucket(tx, bucketDNSCache)
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
bucket = bucket.Bucket([]byte(transportName))
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
content := bucket.Get(key)
|
||||
if len(content) < 8 {
|
||||
return nil
|
||||
}
|
||||
expireAt = time.Unix(int64(binary.BigEndian.Uint64(content[:8])), 0)
|
||||
rawMessage = make([]byte, len(content)-8)
|
||||
copy(rawMessage, content[8:])
|
||||
loaded = true
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, time.Time{}, false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *CacheFile) SaveDNSCache(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time) error {
|
||||
return c.batch(func(tx *bbolt.Tx) error {
|
||||
bucket, err := c.createBucket(tx, bucketDNSCache)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bucket, err = bucket.CreateBucketIfNotExists([]byte(transportName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := buf.Get(2 + len(qName))
|
||||
binary.BigEndian.PutUint16(key, qType)
|
||||
copy(key[2:], qName)
|
||||
defer buf.Put(key)
|
||||
value := buf.Get(8 + len(rawMessage))
|
||||
defer buf.Put(value)
|
||||
binary.BigEndian.PutUint64(value[:8], uint64(expireAt.Unix()))
|
||||
copy(value[8:], rawMessage)
|
||||
return bucket.Put(key, value)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *CacheFile) SaveDNSCacheAsync(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time, logger logger.Logger) {
|
||||
saveKey := saveCacheKey{transportName, qName, qType}
|
||||
if !c.queueDNSCacheSave(saveKey, rawMessage, expireAt) {
|
||||
return
|
||||
}
|
||||
go c.flushPendingDNSCache(saveKey, logger)
|
||||
}
|
||||
|
||||
func (c *CacheFile) queueDNSCacheSave(saveKey saveCacheKey, rawMessage []byte, expireAt time.Time) bool {
|
||||
c.saveDNSCacheAccess.Lock()
|
||||
defer c.saveDNSCacheAccess.Unlock()
|
||||
entry := c.saveDNSCache[saveKey]
|
||||
entry.rawMessage = append([]byte(nil), rawMessage...)
|
||||
entry.expireAt = expireAt
|
||||
entry.sequence++
|
||||
startFlush := !entry.saving
|
||||
entry.saving = true
|
||||
c.saveDNSCache[saveKey] = entry
|
||||
return startFlush
|
||||
}
|
||||
|
||||
func (c *CacheFile) flushPendingDNSCache(saveKey saveCacheKey, logger logger.Logger) {
|
||||
c.flushPendingDNSCacheWith(saveKey, logger, func(entry saveDNSCacheEntry) error {
|
||||
return c.SaveDNSCache(saveKey.TransportName, saveKey.QuestionName, saveKey.QType, entry.rawMessage, entry.expireAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *CacheFile) flushPendingDNSCacheWith(saveKey saveCacheKey, logger logger.Logger, save func(saveDNSCacheEntry) error) {
|
||||
for {
|
||||
c.saveDNSCacheAccess.RLock()
|
||||
entry, loaded := c.saveDNSCache[saveKey]
|
||||
c.saveDNSCacheAccess.RUnlock()
|
||||
if !loaded {
|
||||
return
|
||||
}
|
||||
err := save(entry)
|
||||
if err != nil {
|
||||
logger.Warn("save DNS cache: ", err)
|
||||
}
|
||||
c.saveDNSCacheAccess.Lock()
|
||||
currentEntry, loaded := c.saveDNSCache[saveKey]
|
||||
if !loaded {
|
||||
c.saveDNSCacheAccess.Unlock()
|
||||
return
|
||||
}
|
||||
if currentEntry.sequence != entry.sequence {
|
||||
c.saveDNSCacheAccess.Unlock()
|
||||
continue
|
||||
}
|
||||
delete(c.saveDNSCache, saveKey)
|
||||
c.saveDNSCacheAccess.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CacheFile) ClearDNSCache() error {
|
||||
c.saveDNSCacheAccess.Lock()
|
||||
clear(c.saveDNSCache)
|
||||
c.saveDNSCacheAccess.Unlock()
|
||||
return c.batch(func(tx *bbolt.Tx) error {
|
||||
if c.cacheID == nil {
|
||||
bucket := tx.Bucket(bucketDNSCache)
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
return tx.DeleteBucket(bucketDNSCache)
|
||||
}
|
||||
bucket := tx.Bucket(c.cacheID)
|
||||
if bucket == nil || bucket.Bucket(bucketDNSCache) == nil {
|
||||
return nil
|
||||
}
|
||||
return bucket.DeleteBucket(bucketDNSCache)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *CacheFile) loopCacheCleanup(interval time.Duration, cleanupFunc func()) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cleanupFunc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CacheFile) cleanupDNSCache() {
|
||||
now := time.Now()
|
||||
err := c.batch(func(tx *bbolt.Tx) error {
|
||||
bucket := c.bucket(tx, bucketDNSCache)
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
var emptyTransports [][]byte
|
||||
err := bucket.ForEachBucket(func(transportName []byte) error {
|
||||
transportBucket := bucket.Bucket(transportName)
|
||||
if transportBucket == nil {
|
||||
return nil
|
||||
}
|
||||
var expiredKeys [][]byte
|
||||
err := transportBucket.ForEach(func(key, value []byte) error {
|
||||
if len(value) < 8 {
|
||||
expiredKeys = append(expiredKeys, append([]byte(nil), key...))
|
||||
return nil
|
||||
}
|
||||
if c.disableExpire {
|
||||
return nil
|
||||
}
|
||||
expireAt := time.Unix(int64(binary.BigEndian.Uint64(value[:8])), 0)
|
||||
if now.After(expireAt.Add(c.optimisticTimeout)) {
|
||||
expiredKeys = append(expiredKeys, append([]byte(nil), key...))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range expiredKeys {
|
||||
err = transportBucket.Delete(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
first, _ := transportBucket.Cursor().First()
|
||||
if first == nil {
|
||||
emptyTransports = append(emptyTransports, append([]byte(nil), transportName...))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range emptyTransports {
|
||||
err = bucket.DeleteBucket(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Warn("cleanup DNS cache: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CacheFile) clearRDRC() {
|
||||
c.saveRDRCAccess.Lock()
|
||||
clear(c.saveRDRC)
|
||||
c.saveRDRCAccess.Unlock()
|
||||
err := c.batch(func(tx *bbolt.Tx) error {
|
||||
if c.cacheID == nil {
|
||||
if tx.Bucket(bucketRDRC) == nil {
|
||||
return nil
|
||||
}
|
||||
return tx.DeleteBucket(bucketRDRC)
|
||||
}
|
||||
bucket := tx.Bucket(c.cacheID)
|
||||
if bucket == nil || bucket.Bucket(bucketRDRC) == nil {
|
||||
return nil
|
||||
}
|
||||
return bucket.DeleteBucket(bucketRDRC)
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Warn("clear RDRC: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CacheFile) cleanupRDRC() {
|
||||
now := time.Now()
|
||||
err := c.batch(func(tx *bbolt.Tx) error {
|
||||
bucket := c.bucket(tx, bucketRDRC)
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
var emptyTransports [][]byte
|
||||
err := bucket.ForEachBucket(func(transportName []byte) error {
|
||||
transportBucket := bucket.Bucket(transportName)
|
||||
if transportBucket == nil {
|
||||
return nil
|
||||
}
|
||||
var expiredKeys [][]byte
|
||||
err := transportBucket.ForEach(func(key, value []byte) error {
|
||||
if len(value) < 8 {
|
||||
expiredKeys = append(expiredKeys, append([]byte(nil), key...))
|
||||
return nil
|
||||
}
|
||||
expiresAt := time.Unix(int64(binary.BigEndian.Uint64(value)), 0)
|
||||
if now.After(expiresAt) {
|
||||
expiredKeys = append(expiredKeys, append([]byte(nil), key...))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range expiredKeys {
|
||||
err = transportBucket.Delete(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
first, _ := transportBucket.Cursor().First()
|
||||
if first == nil {
|
||||
emptyTransports = append(emptyTransports, append([]byte(nil), transportName...))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range emptyTransports {
|
||||
err = bucket.DeleteBucket(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Warn("cleanup RDRC: ", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user