diff --git a/experimental/clashapi/connections.go b/experimental/clashapi/connections.go index 999d5898..5074adf7 100644 --- a/experimental/clashapi/connections.go +++ b/experimental/clashapi/connections.go @@ -2,6 +2,7 @@ package clashapi import ( "bytes" + "context" "net/http" "strconv" "time" @@ -17,15 +18,15 @@ import ( "github.com/gofrs/uuid/v5" ) -func connectionRouter(router adapter.Router, trafficManager *trafficontrol.Manager) http.Handler { +func connectionRouter(ctx context.Context, router adapter.Router, trafficManager *trafficontrol.Manager) http.Handler { r := chi.NewRouter() - r.Get("/", getConnections(trafficManager)) + r.Get("/", getConnections(ctx, trafficManager)) r.Delete("/", closeAllConnections(router, trafficManager)) r.Delete("/{id}", closeConnection(trafficManager)) return r } -func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) { +func getConnections(ctx context.Context, trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Upgrade") != "websocket" { snapshot := trafficManager.Snapshot() @@ -67,7 +68,12 @@ func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseW tick := time.NewTicker(time.Millisecond * time.Duration(interval)) defer tick.Stop() - for range tick.C { + for { + select { + case <-ctx.Done(): + return + case <-tick.C: + } if err = sendSnapshot(); err != nil { break } diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index e71031dc..c3661182 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -116,12 +116,12 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op r.Use(authentication(options.Secret)) r.Get("/", hello(options.ExternalUI != "")) r.Get("/logs", getLogs(logFactory)) - r.Get("/traffic", traffic(trafficManager)) + r.Get("/traffic", traffic(s.ctx, trafficManager)) r.Get("/version", version) r.Mount("/configs", configRouter(s, logFactory)) r.Mount("/proxies", proxyRouter(s, s.router)) r.Mount("/rules", ruleRouter(s.router)) - r.Mount("/connections", connectionRouter(s.router, trafficManager)) + r.Mount("/connections", connectionRouter(s.ctx, s.router, trafficManager)) r.Mount("/providers/proxies", proxyProviderRouter()) r.Mount("/providers/rules", ruleProviderRouter()) r.Mount("/script", scriptRouter()) @@ -303,7 +303,7 @@ type Traffic struct { Down int64 `json:"down"` } -func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) { +func traffic(ctx context.Context, trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var conn net.Conn if r.Header.Get("Upgrade") == "websocket" { @@ -324,7 +324,12 @@ func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, defer tick.Stop() buf := &bytes.Buffer{} uploadTotal, downloadTotal := trafficManager.Total() - for range tick.C { + for { + select { + case <-ctx.Done(): + return + case <-tick.C: + } buf.Reset() uploadTotalNew, downloadTotalNew := trafficManager.Total() err := json.NewEncoder(buf).Encode(Traffic{