Fix websocket connection and goroutine leaks in Clash API
Co-authored-by: traitman <112139837+traitman@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@ package clashapi
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
@@ -27,7 +28,7 @@ func (s *Server) setupMetaAPI(r chi.Router) {
|
|||||||
})
|
})
|
||||||
r.Mount("/", middleware.Profiler())
|
r.Mount("/", middleware.Profiler())
|
||||||
}
|
}
|
||||||
r.Get("/memory", memory(s.trafficManager))
|
r.Get("/memory", memory(s.ctx, s.trafficManager))
|
||||||
r.Mount("/group", groupRouter(s))
|
r.Mount("/group", groupRouter(s))
|
||||||
r.Mount("/upgrade", upgradeRouter(s))
|
r.Mount("/upgrade", upgradeRouter(s))
|
||||||
}
|
}
|
||||||
@@ -37,7 +38,7 @@ type Memory struct {
|
|||||||
OSLimit uint64 `json:"oslimit"` // maybe we need it in the future
|
OSLimit uint64 `json:"oslimit"` // maybe we need it in the future
|
||||||
}
|
}
|
||||||
|
|
||||||
func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
|
func memory(ctx context.Context, trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
if r.Header.Get("Upgrade") == "websocket" {
|
if r.Header.Get("Upgrade") == "websocket" {
|
||||||
@@ -46,6 +47,7 @@ func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
@@ -58,7 +60,12 @@ func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r
|
|||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
var err error
|
var err error
|
||||||
first := true
|
first := true
|
||||||
for range tick.C {
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-tick.C:
|
||||||
|
}
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
|
|
||||||
inuse := trafficManager.Snapshot().Memory
|
inuse := trafficManager.Snapshot().Memory
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ func getConnections(ctx context.Context, trafficManager *trafficontrol.Manager)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
intervalStr := r.URL.Query().Get("interval")
|
intervalStr := r.URL.Query().Get("interval")
|
||||||
interval := 1000
|
interval := 1000
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op
|
|||||||
chiRouter.Group(func(r chi.Router) {
|
chiRouter.Group(func(r chi.Router) {
|
||||||
r.Use(authentication(options.Secret))
|
r.Use(authentication(options.Secret))
|
||||||
r.Get("/", hello(options.ExternalUI != ""))
|
r.Get("/", hello(options.ExternalUI != ""))
|
||||||
r.Get("/logs", getLogs(logFactory))
|
r.Get("/logs", getLogs(s.ctx, logFactory))
|
||||||
r.Get("/traffic", traffic(s.ctx, trafficManager))
|
r.Get("/traffic", traffic(s.ctx, trafficManager))
|
||||||
r.Get("/version", version)
|
r.Get("/version", version)
|
||||||
r.Mount("/configs", configRouter(s, logFactory))
|
r.Mount("/configs", configRouter(s, logFactory))
|
||||||
@@ -360,7 +360,7 @@ type Log struct {
|
|||||||
Payload string `json:"payload"`
|
Payload string `json:"payload"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *http.Request) {
|
func getLogs(ctx context.Context, logFactory log.ObservableFactory) func(w http.ResponseWriter, r *http.Request) {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
levelText := r.URL.Query().Get("level")
|
levelText := r.URL.Query().Get("level")
|
||||||
if levelText == "" {
|
if levelText == "" {
|
||||||
@@ -399,6 +399,8 @@ func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *ht
|
|||||||
var logEntry log.Entry
|
var logEntry log.Entry
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
case <-done:
|
case <-done:
|
||||||
return
|
return
|
||||||
case logEntry = <-subscription:
|
case logEntry = <-subscription:
|
||||||
|
|||||||
Reference in New Issue
Block a user