First Commmit

This commit is contained in:
CN-JS-HuiBai
2026-04-14 22:41:14 +08:00
commit 9f867b19da
1086 changed files with 147554 additions and 0 deletions

177
service/ssmapi/api.go Normal file
View File

@@ -0,0 +1,177 @@
package ssmapi
import (
"net/http"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common/logger"
sHTTP "github.com/sagernet/sing/protocol/http"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
)
type APIServer struct {
logger logger.Logger
traffic *TrafficManager
user *UserManager
}
func NewAPIServer(logger logger.Logger, traffic *TrafficManager, user *UserManager) *APIServer {
return &APIServer{
logger: logger,
traffic: traffic,
user: user,
}
}
func (s *APIServer) Route(r chi.Router) {
r.Route("/server/v1", func(r chi.Router) {
r.Use(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
s.logger.Debug(request.Method, " ", request.RequestURI, " ", sHTTP.SourceAddress(request))
handler.ServeHTTP(writer, request)
})
})
r.Get("/", s.getServerInfo)
r.Get("/users", s.listUser)
r.Post("/users", s.addUser)
r.Get("/users/{username}", s.getUser)
r.Put("/users/{username}", s.updateUser)
r.Delete("/users/{username}", s.deleteUser)
r.Get("/stats", s.getStats)
})
}
func (s *APIServer) getServerInfo(writer http.ResponseWriter, request *http.Request) {
render.JSON(writer, request, render.M{
"server": "sing-box " + C.Version,
"apiVersion": "v1",
})
}
type UserObject struct {
UserName string `json:"username"`
Password string `json:"uPSK,omitempty"`
DownlinkBytes int64 `json:"downlinkBytes"`
UplinkBytes int64 `json:"uplinkBytes"`
DownlinkPackets int64 `json:"downlinkPackets"`
UplinkPackets int64 `json:"uplinkPackets"`
TCPSessions int64 `json:"tcpSessions"`
UDPSessions int64 `json:"udpSessions"`
}
func (s *APIServer) listUser(writer http.ResponseWriter, request *http.Request) {
render.JSON(writer, request, render.M{
"users": s.user.List(),
})
}
func (s *APIServer) addUser(writer http.ResponseWriter, request *http.Request) {
var addRequest struct {
UserName string `json:"username"`
Password string `json:"uPSK"`
}
err := render.DecodeJSON(request.Body, &addRequest)
if err != nil {
render.Status(request, http.StatusBadRequest)
render.PlainText(writer, request, err.Error())
return
}
err = s.user.Add(addRequest.UserName, addRequest.Password)
if err != nil {
render.Status(request, http.StatusBadRequest)
render.PlainText(writer, request, err.Error())
return
}
writer.WriteHeader(http.StatusCreated)
}
func (s *APIServer) getUser(writer http.ResponseWriter, request *http.Request) {
userName := chi.URLParam(request, "username")
if userName == "" {
writer.WriteHeader(http.StatusBadRequest)
return
}
uPSK, loaded := s.user.Get(userName)
if !loaded {
writer.WriteHeader(http.StatusNotFound)
return
}
user := UserObject{
UserName: userName,
Password: uPSK,
}
s.traffic.ReadUser(&user)
render.JSON(writer, request, user)
}
func (s *APIServer) updateUser(writer http.ResponseWriter, request *http.Request) {
userName := chi.URLParam(request, "username")
if userName == "" {
writer.WriteHeader(http.StatusBadRequest)
return
}
var updateRequest struct {
Password string `json:"uPSK"`
}
err := render.DecodeJSON(request.Body, &updateRequest)
if err != nil {
render.Status(request, http.StatusBadRequest)
render.PlainText(writer, request, err.Error())
return
}
_, loaded := s.user.Get(userName)
if !loaded {
writer.WriteHeader(http.StatusNotFound)
return
}
err = s.user.Update(userName, updateRequest.Password)
if err != nil {
render.Status(request, http.StatusBadRequest)
render.PlainText(writer, request, err.Error())
return
}
writer.WriteHeader(http.StatusNoContent)
}
func (s *APIServer) deleteUser(writer http.ResponseWriter, request *http.Request) {
userName := chi.URLParam(request, "username")
if userName == "" {
writer.WriteHeader(http.StatusBadRequest)
return
}
_, loaded := s.user.Get(userName)
if !loaded {
writer.WriteHeader(http.StatusNotFound)
return
}
err := s.user.Delete(userName)
if err != nil {
render.Status(request, http.StatusBadRequest)
render.PlainText(writer, request, err.Error())
return
}
writer.WriteHeader(http.StatusNoContent)
}
func (s *APIServer) getStats(writer http.ResponseWriter, request *http.Request) {
requireClear := request.URL.Query().Get("clear") == "true"
users := s.user.List()
s.traffic.ReadUsers(users, requireClear)
for i := range users {
users[i].Password = ""
}
uplinkBytes, downlinkBytes, uplinkPackets, downlinkPackets, tcpSessions, udpSessions := s.traffic.ReadGlobal(requireClear)
render.JSON(writer, request, render.M{
"uplinkBytes": uplinkBytes,
"downlinkBytes": downlinkBytes,
"uplinkPackets": uplinkPackets,
"downlinkPackets": downlinkPackets,
"tcpSessions": tcpSessions,
"udpSessions": udpSessions,
"users": users,
})
}

239
service/ssmapi/cache.go Normal file
View File

@@ -0,0 +1,239 @@
package ssmapi
import (
"bytes"
"os"
"path/filepath"
"sort"
"sync/atomic"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson"
"github.com/sagernet/sing/service/filemanager"
)
type Cache struct {
Endpoints *badjson.TypedMap[string, *EndpointCache] `json:"endpoints"`
}
type EndpointCache struct {
GlobalUplink int64 `json:"global_uplink"`
GlobalDownlink int64 `json:"global_downlink"`
GlobalUplinkPackets int64 `json:"global_uplink_packets"`
GlobalDownlinkPackets int64 `json:"global_downlink_packets"`
GlobalTCPSessions int64 `json:"global_tcp_sessions"`
GlobalUDPSessions int64 `json:"global_udp_sessions"`
UserUplink *badjson.TypedMap[string, int64] `json:"user_uplink"`
UserDownlink *badjson.TypedMap[string, int64] `json:"user_downlink"`
UserUplinkPackets *badjson.TypedMap[string, int64] `json:"user_uplink_packets"`
UserDownlinkPackets *badjson.TypedMap[string, int64] `json:"user_downlink_packets"`
UserTCPSessions *badjson.TypedMap[string, int64] `json:"user_tcp_sessions"`
UserUDPSessions *badjson.TypedMap[string, int64] `json:"user_udp_sessions"`
Users *badjson.TypedMap[string, string] `json:"users"`
}
func (s *Service) loadCache() error {
if s.cachePath == "" {
return nil
}
basePath := filemanager.BasePath(s.ctx, s.cachePath)
cacheBinary, err := os.ReadFile(basePath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
err = s.decodeCache(cacheBinary)
if err != nil {
os.RemoveAll(basePath)
return err
}
s.cacheMutex.Lock()
s.lastSavedCache = cacheBinary
s.cacheMutex.Unlock()
return nil
}
func (s *Service) saveCache() error {
if s.cachePath == "" {
return nil
}
cacheBinary, err := s.encodeCache()
if err != nil {
return err
}
s.cacheMutex.Lock()
defer s.cacheMutex.Unlock()
if bytes.Equal(s.lastSavedCache, cacheBinary) {
return nil
}
return s.writeCache(cacheBinary)
}
func (s *Service) writeCache(cacheBinary []byte) error {
basePath := filemanager.BasePath(s.ctx, s.cachePath)
err := os.MkdirAll(filepath.Dir(basePath), 0o777)
if err != nil {
return err
}
err = os.WriteFile(basePath, cacheBinary, 0o644)
if err != nil {
return err
}
s.lastSavedCache = cacheBinary
return nil
}
func (s *Service) decodeCache(cacheBinary []byte) error {
if len(cacheBinary) == 0 {
return nil
}
cache, err := json.UnmarshalExtended[*Cache](cacheBinary)
if err != nil {
return err
}
if cache.Endpoints == nil || cache.Endpoints.Size() == 0 {
return nil
}
for _, entry := range cache.Endpoints.Entries() {
trafficManager, loaded := s.traffics[entry.Key]
if !loaded {
continue
}
trafficManager.globalUplink.Store(entry.Value.GlobalUplink)
trafficManager.globalDownlink.Store(entry.Value.GlobalDownlink)
trafficManager.globalUplinkPackets.Store(entry.Value.GlobalUplinkPackets)
trafficManager.globalDownlinkPackets.Store(entry.Value.GlobalDownlinkPackets)
trafficManager.globalTCPSessions.Store(entry.Value.GlobalTCPSessions)
trafficManager.globalUDPSessions.Store(entry.Value.GlobalUDPSessions)
trafficManager.userUplink = typedAtomicInt64Map(entry.Value.UserUplink)
trafficManager.userDownlink = typedAtomicInt64Map(entry.Value.UserDownlink)
trafficManager.userUplinkPackets = typedAtomicInt64Map(entry.Value.UserUplinkPackets)
trafficManager.userDownlinkPackets = typedAtomicInt64Map(entry.Value.UserDownlinkPackets)
trafficManager.userTCPSessions = typedAtomicInt64Map(entry.Value.UserTCPSessions)
trafficManager.userUDPSessions = typedAtomicInt64Map(entry.Value.UserUDPSessions)
userManager, loaded := s.users[entry.Key]
if !loaded {
continue
}
userManager.usersMap = typedMap(entry.Value.Users)
_ = userManager.postUpdate(false)
}
return nil
}
func (s *Service) encodeCache() ([]byte, error) {
endpoints := new(badjson.TypedMap[string, *EndpointCache])
for tag, traffic := range s.traffics {
var (
userUplink = new(badjson.TypedMap[string, int64])
userDownlink = new(badjson.TypedMap[string, int64])
userUplinkPackets = new(badjson.TypedMap[string, int64])
userDownlinkPackets = new(badjson.TypedMap[string, int64])
userTCPSessions = new(badjson.TypedMap[string, int64])
userUDPSessions = new(badjson.TypedMap[string, int64])
userMap = new(badjson.TypedMap[string, string])
)
for user, uplink := range traffic.userUplink {
if uplink.Load() > 0 {
userUplink.Put(user, uplink.Load())
}
}
for user, downlink := range traffic.userDownlink {
if downlink.Load() > 0 {
userDownlink.Put(user, downlink.Load())
}
}
for user, uplinkPackets := range traffic.userUplinkPackets {
if uplinkPackets.Load() > 0 {
userUplinkPackets.Put(user, uplinkPackets.Load())
}
}
for user, downlinkPackets := range traffic.userDownlinkPackets {
if downlinkPackets.Load() > 0 {
userDownlinkPackets.Put(user, downlinkPackets.Load())
}
}
for user, tcpSessions := range traffic.userTCPSessions {
if tcpSessions.Load() > 0 {
userTCPSessions.Put(user, tcpSessions.Load())
}
}
for user, udpSessions := range traffic.userUDPSessions {
if udpSessions.Load() > 0 {
userUDPSessions.Put(user, udpSessions.Load())
}
}
userManager := s.users[tag]
if userManager != nil && len(userManager.usersMap) > 0 {
userMap = new(badjson.TypedMap[string, string])
for username, password := range userManager.usersMap {
if username != "" && password != "" {
userMap.Put(username, password)
}
}
}
endpoints.Put(tag, &EndpointCache{
GlobalUplink: traffic.globalUplink.Load(),
GlobalDownlink: traffic.globalDownlink.Load(),
GlobalUplinkPackets: traffic.globalUplinkPackets.Load(),
GlobalDownlinkPackets: traffic.globalDownlinkPackets.Load(),
GlobalTCPSessions: traffic.globalTCPSessions.Load(),
GlobalUDPSessions: traffic.globalUDPSessions.Load(),
UserUplink: sortTypedMap(userUplink),
UserDownlink: sortTypedMap(userDownlink),
UserUplinkPackets: sortTypedMap(userUplinkPackets),
UserDownlinkPackets: sortTypedMap(userDownlinkPackets),
UserTCPSessions: sortTypedMap(userTCPSessions),
UserUDPSessions: sortTypedMap(userUDPSessions),
Users: sortTypedMap(userMap),
})
}
var buffer bytes.Buffer
encoder := json.NewEncoder(&buffer)
encoder.SetIndent("", " ")
err := encoder.Encode(&Cache{
Endpoints: sortTypedMap(endpoints),
})
if err != nil {
return nil, err
}
return buffer.Bytes(), nil
}
func sortTypedMap[T comparable](trafficMap *badjson.TypedMap[string, T]) *badjson.TypedMap[string, T] {
if trafficMap == nil {
return nil
}
keys := trafficMap.Keys()
sort.Strings(keys)
sortedMap := new(badjson.TypedMap[string, T])
for _, key := range keys {
value, _ := trafficMap.Get(key)
sortedMap.Put(key, value)
}
return sortedMap
}
func typedAtomicInt64Map(trafficMap *badjson.TypedMap[string, int64]) map[string]*atomic.Int64 {
result := make(map[string]*atomic.Int64)
if trafficMap != nil {
for _, entry := range trafficMap.Entries() {
counter := new(atomic.Int64)
counter.Store(entry.Value)
result[entry.Key] = counter
}
}
return result
}
func typedMap[T comparable](trafficMap *badjson.TypedMap[string, T]) map[string]T {
result := make(map[string]T)
if trafficMap != nil {
for _, entry := range trafficMap.Entries() {
result[entry.Key] = entry.Value
}
}
return result
}

163
service/ssmapi/server.go Normal file
View File

@@ -0,0 +1,163 @@
package ssmapi
import (
"context"
"errors"
"net/http"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/listener"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/sagernet/sing/service"
"github.com/go-chi/chi/v5"
"golang.org/x/net/http2"
)
func RegisterService(registry *boxService.Registry) {
boxService.Register[option.SSMAPIServiceOptions](registry, C.TypeSSMAPI, NewService)
}
type Service struct {
boxService.Adapter
ctx context.Context
cancel context.CancelFunc
logger log.ContextLogger
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
traffics map[string]*TrafficManager
users map[string]*UserManager
cachePath string
saveTicker *time.Ticker
lastSavedCache []byte
cacheMutex sync.Mutex
}
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.SSMAPIServiceOptions) (adapter.Service, error) {
ctx, cancel := context.WithCancel(ctx)
chiRouter := chi.NewRouter()
s := &Service{
Adapter: boxService.NewAdapter(C.TypeSSMAPI, tag),
ctx: ctx,
cancel: cancel,
logger: logger,
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
httpServer: &http.Server{
Handler: chiRouter,
},
traffics: make(map[string]*TrafficManager),
users: make(map[string]*UserManager),
cachePath: options.CachePath,
}
inboundManager := service.FromContext[adapter.InboundManager](ctx)
if options.Servers.Size() == 0 {
return nil, E.New("missing servers")
}
for i, entry := range options.Servers.Entries() {
inbound, loaded := inboundManager.Get(entry.Value)
if !loaded {
return nil, E.New("parse SSM server[", i, "]: inbound ", entry.Value, " not found")
}
managedServer, isManaged := inbound.(adapter.ManagedSSMServer)
if !isManaged {
return nil, E.New("parse SSM server[", i, "]: inbound/", inbound.Type(), "[", inbound.Tag(), "] is not a SSM server")
}
traffic := NewTrafficManager()
managedServer.SetTracker(traffic)
user := NewUserManager(managedServer, traffic)
chiRouter.Route(entry.Key, NewAPIServer(logger, traffic, user).Route)
s.traffics[entry.Key] = traffic
s.users[entry.Key] = user
}
if options.TLS != nil {
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
s.tlsConfig = tlsConfig
}
return s, nil
}
func (s *Service) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
err := s.loadCache()
if err != nil {
s.logger.Error(E.Cause(err, "load cache"))
}
s.saveTicker = time.NewTicker(1 * time.Minute)
go s.loopSaveCache()
if s.tlsConfig != nil {
err = s.tlsConfig.Start()
if err != nil {
return E.Cause(err, "create TLS config")
}
}
tcpListener, err := s.listener.ListenTCP()
if err != nil {
return err
}
if s.tlsConfig != nil {
if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
}
tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
}
go func() {
err = s.httpServer.Serve(tcpListener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
s.logger.Error("serve error: ", err)
}
}()
return nil
}
func (s *Service) loopSaveCache() {
for {
select {
case <-s.ctx.Done():
return
case <-s.saveTicker.C:
err := s.saveCache()
if err != nil {
s.logger.Error(E.Cause(err, "save cache"))
}
}
}
}
func (s *Service) Close() error {
if s.cancel != nil {
s.cancel()
}
if s.saveTicker != nil {
s.saveTicker.Stop()
}
err := s.saveCache()
if err != nil {
s.logger.Error(E.Cause(err, "save cache"))
}
return common.Close(
common.PtrOrNil(s.httpServer),
common.PtrOrNil(s.listener),
s.tlsConfig,
)
}

223
service/ssmapi/traffic.go Normal file
View File

@@ -0,0 +1,223 @@
package ssmapi
import (
"net"
"sync"
"sync/atomic"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
var _ adapter.SSMTracker = (*TrafficManager)(nil)
type TrafficManager struct {
globalUplink atomic.Int64
globalDownlink atomic.Int64
globalUplinkPackets atomic.Int64
globalDownlinkPackets atomic.Int64
globalTCPSessions atomic.Int64
globalUDPSessions atomic.Int64
userAccess sync.Mutex
userUplink map[string]*atomic.Int64
userDownlink map[string]*atomic.Int64
userUplinkPackets map[string]*atomic.Int64
userDownlinkPackets map[string]*atomic.Int64
userTCPSessions map[string]*atomic.Int64
userUDPSessions map[string]*atomic.Int64
}
func NewTrafficManager() *TrafficManager {
manager := &TrafficManager{
userUplink: make(map[string]*atomic.Int64),
userDownlink: make(map[string]*atomic.Int64),
userUplinkPackets: make(map[string]*atomic.Int64),
userDownlinkPackets: make(map[string]*atomic.Int64),
userTCPSessions: make(map[string]*atomic.Int64),
userUDPSessions: make(map[string]*atomic.Int64),
}
return manager
}
func (s *TrafficManager) UpdateUsers(users []string) {
s.userAccess.Lock()
defer s.userAccess.Unlock()
newUserUplink := make(map[string]*atomic.Int64)
newUserDownlink := make(map[string]*atomic.Int64)
newUserUplinkPackets := make(map[string]*atomic.Int64)
newUserDownlinkPackets := make(map[string]*atomic.Int64)
newUserTCPSessions := make(map[string]*atomic.Int64)
newUserUDPSessions := make(map[string]*atomic.Int64)
for _, user := range users {
if counter, loaded := s.userUplink[user]; loaded {
newUserUplink[user] = counter
}
if counter, loaded := s.userDownlink[user]; loaded {
newUserDownlink[user] = counter
}
if counter, loaded := s.userUplinkPackets[user]; loaded {
newUserUplinkPackets[user] = counter
}
if counter, loaded := s.userDownlinkPackets[user]; loaded {
newUserDownlinkPackets[user] = counter
}
if counter, loaded := s.userTCPSessions[user]; loaded {
newUserTCPSessions[user] = counter
}
if counter, loaded := s.userUDPSessions[user]; loaded {
newUserUDPSessions[user] = counter
}
}
s.userUplink = newUserUplink
s.userDownlink = newUserDownlink
s.userUplinkPackets = newUserUplinkPackets
s.userDownlinkPackets = newUserDownlinkPackets
s.userTCPSessions = newUserTCPSessions
s.userUDPSessions = newUserUDPSessions
}
func (s *TrafficManager) userCounter(user string) (*atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64) {
s.userAccess.Lock()
defer s.userAccess.Unlock()
upCounter, loaded := s.userUplink[user]
if !loaded {
upCounter = new(atomic.Int64)
s.userUplink[user] = upCounter
}
downCounter, loaded := s.userDownlink[user]
if !loaded {
downCounter = new(atomic.Int64)
s.userDownlink[user] = downCounter
}
upPacketsCounter, loaded := s.userUplinkPackets[user]
if !loaded {
upPacketsCounter = new(atomic.Int64)
s.userUplinkPackets[user] = upPacketsCounter
}
downPacketsCounter, loaded := s.userDownlinkPackets[user]
if !loaded {
downPacketsCounter = new(atomic.Int64)
s.userDownlinkPackets[user] = downPacketsCounter
}
tcpSessionsCounter, loaded := s.userTCPSessions[user]
if !loaded {
tcpSessionsCounter = new(atomic.Int64)
s.userTCPSessions[user] = tcpSessionsCounter
}
udpSessionsCounter, loaded := s.userUDPSessions[user]
if !loaded {
udpSessionsCounter = new(atomic.Int64)
s.userUDPSessions[user] = udpSessionsCounter
}
return upCounter, downCounter, upPacketsCounter, downPacketsCounter, tcpSessionsCounter, udpSessionsCounter
}
func (s *TrafficManager) TrackConnection(conn net.Conn, metadata adapter.InboundContext) net.Conn {
s.globalTCPSessions.Add(1)
var readCounter []*atomic.Int64
var writeCounter []*atomic.Int64
readCounter = append(readCounter, &s.globalUplink)
writeCounter = append(writeCounter, &s.globalDownlink)
upCounter, downCounter, _, _, tcpSessionCounter, _ := s.userCounter(metadata.User)
readCounter = append(readCounter, upCounter)
writeCounter = append(writeCounter, downCounter)
tcpSessionCounter.Add(1)
return bufio.NewInt64CounterConn(conn, readCounter, writeCounter)
}
func (s *TrafficManager) TrackPacketConnection(conn N.PacketConn, metadata adapter.InboundContext) N.PacketConn {
s.globalUDPSessions.Add(1)
var readCounter []*atomic.Int64
var readPacketCounter []*atomic.Int64
var writeCounter []*atomic.Int64
var writePacketCounter []*atomic.Int64
readCounter = append(readCounter, &s.globalUplink)
writeCounter = append(writeCounter, &s.globalDownlink)
readPacketCounter = append(readPacketCounter, &s.globalUplinkPackets)
writePacketCounter = append(writePacketCounter, &s.globalDownlinkPackets)
upCounter, downCounter, upPacketsCounter, downPacketsCounter, _, udpSessionCounter := s.userCounter(metadata.User)
readCounter = append(readCounter, upCounter)
writeCounter = append(writeCounter, downCounter)
readPacketCounter = append(readPacketCounter, upPacketsCounter)
writePacketCounter = append(writePacketCounter, downPacketsCounter)
udpSessionCounter.Add(1)
return bufio.NewInt64CounterPacketConn(conn, readCounter, readPacketCounter, writeCounter, writePacketCounter)
}
func (s *TrafficManager) ReadUser(user *UserObject) {
s.userAccess.Lock()
defer s.userAccess.Unlock()
s.readUser(user, false)
}
func (s *TrafficManager) readUser(user *UserObject, swap bool) {
if counter, loaded := s.userUplink[user.UserName]; loaded {
if swap {
user.UplinkBytes = counter.Swap(0)
} else {
user.UplinkBytes = counter.Load()
}
}
if counter, loaded := s.userDownlink[user.UserName]; loaded {
if swap {
user.DownlinkBytes = counter.Swap(0)
} else {
user.DownlinkBytes = counter.Load()
}
}
if counter, loaded := s.userUplinkPackets[user.UserName]; loaded {
if swap {
user.UplinkPackets = counter.Swap(0)
} else {
user.UplinkPackets = counter.Load()
}
}
if counter, loaded := s.userDownlinkPackets[user.UserName]; loaded {
if swap {
user.DownlinkPackets = counter.Swap(0)
} else {
user.DownlinkPackets = counter.Load()
}
}
if counter, loaded := s.userTCPSessions[user.UserName]; loaded {
if swap {
user.TCPSessions = counter.Swap(0)
} else {
user.TCPSessions = counter.Load()
}
}
if counter, loaded := s.userUDPSessions[user.UserName]; loaded {
if swap {
user.UDPSessions = counter.Swap(0)
} else {
user.UDPSessions = counter.Load()
}
}
}
func (s *TrafficManager) ReadUsers(users []*UserObject, swap bool) {
s.userAccess.Lock()
defer s.userAccess.Unlock()
for _, user := range users {
s.readUser(user, swap)
}
}
func (s *TrafficManager) ReadGlobal(swap bool) (uplinkBytes int64, downlinkBytes int64, uplinkPackets int64, downlinkPackets int64, tcpSessions int64, udpSessions int64) {
if swap {
return s.globalUplink.Swap(0),
s.globalDownlink.Swap(0),
s.globalUplinkPackets.Swap(0),
s.globalDownlinkPackets.Swap(0),
s.globalTCPSessions.Swap(0),
s.globalUDPSessions.Swap(0)
} else {
return s.globalUplink.Load(),
s.globalDownlink.Load(),
s.globalUplinkPackets.Load(),
s.globalDownlinkPackets.Load(),
s.globalTCPSessions.Load(),
s.globalUDPSessions.Load()
}
}

87
service/ssmapi/user.go Normal file
View File

@@ -0,0 +1,87 @@
package ssmapi
import (
"sync"
"github.com/sagernet/sing-box/adapter"
E "github.com/sagernet/sing/common/exceptions"
)
type UserManager struct {
access sync.Mutex
usersMap map[string]string
server adapter.ManagedSSMServer
trafficManager *TrafficManager
}
func NewUserManager(inbound adapter.ManagedSSMServer, trafficManager *TrafficManager) *UserManager {
return &UserManager{
usersMap: make(map[string]string),
server: inbound,
trafficManager: trafficManager,
}
}
func (m *UserManager) postUpdate(updated bool) error {
users := make([]string, 0, len(m.usersMap))
uPSKs := make([]string, 0, len(m.usersMap))
for username, password := range m.usersMap {
users = append(users, username)
uPSKs = append(uPSKs, password)
}
err := m.server.UpdateUsers(users, uPSKs)
if err != nil {
return err
}
if updated {
m.trafficManager.UpdateUsers(users)
}
return nil
}
func (m *UserManager) List() []*UserObject {
m.access.Lock()
defer m.access.Unlock()
users := make([]*UserObject, 0, len(m.usersMap))
for username, password := range m.usersMap {
users = append(users, &UserObject{
UserName: username,
Password: password,
})
}
return users
}
func (m *UserManager) Add(username string, password string) error {
m.access.Lock()
defer m.access.Unlock()
if _, found := m.usersMap[username]; found {
return E.New("user ", username, " already exists")
}
m.usersMap[username] = password
return m.postUpdate(true)
}
func (m *UserManager) Get(username string) (string, bool) {
m.access.Lock()
defer m.access.Unlock()
if password, found := m.usersMap[username]; found {
return password, true
}
return "", false
}
func (m *UserManager) Update(username string, password string) error {
m.access.Lock()
defer m.access.Unlock()
m.usersMap[username] = password
return m.postUpdate(true)
}
func (m *UserManager) Delete(username string) error {
m.access.Lock()
defer m.access.Unlock()
delete(m.usersMap, username)
return m.postUpdate(true)
}