Fix OCM websocket proxy lifecycle and headers
This commit is contained in:
@@ -139,7 +139,9 @@ type Service struct {
|
|||||||
userManager *UserManager
|
userManager *UserManager
|
||||||
accessMutex sync.RWMutex
|
accessMutex sync.RWMutex
|
||||||
usageTracker *AggregatedUsage
|
usageTracker *AggregatedUsage
|
||||||
trackingGroup sync.WaitGroup
|
webSocketMutex sync.Mutex
|
||||||
|
webSocketGroup sync.WaitGroup
|
||||||
|
webSocketConns map[*webSocketSession]struct{}
|
||||||
shuttingDown bool
|
shuttingDown bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -197,8 +199,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
|||||||
Network: []string{N.NetworkTCP},
|
Network: []string{N.NetworkTCP},
|
||||||
Listen: options.ListenOptions,
|
Listen: options.ListenOptions,
|
||||||
}),
|
}),
|
||||||
userManager: userManager,
|
userManager: userManager,
|
||||||
usageTracker: usageTracker,
|
usageTracker: usageTracker,
|
||||||
|
webSocketConns: make(map[*webSocketSession]struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.TLS != nil {
|
if options.TLS != nil {
|
||||||
@@ -631,11 +634,17 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Close() error {
|
func (s *Service) Close() error {
|
||||||
|
webSocketSessions := s.startWebSocketShutdown()
|
||||||
|
|
||||||
err := common.Close(
|
err := common.Close(
|
||||||
common.PtrOrNil(s.httpServer),
|
common.PtrOrNil(s.httpServer),
|
||||||
common.PtrOrNil(s.listener),
|
common.PtrOrNil(s.listener),
|
||||||
s.tlsConfig,
|
s.tlsConfig,
|
||||||
)
|
)
|
||||||
|
for _, session := range webSocketSessions {
|
||||||
|
session.Close()
|
||||||
|
}
|
||||||
|
s.webSocketGroup.Wait()
|
||||||
|
|
||||||
if s.usageTracker != nil {
|
if s.usageTracker != nil {
|
||||||
s.usageTracker.cancelPendingSave()
|
s.usageTracker.cancelPendingSave()
|
||||||
@@ -647,3 +656,48 @@ func (s *Service) Close() error {
|
|||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
|
||||||
|
s.webSocketMutex.Lock()
|
||||||
|
defer s.webSocketMutex.Unlock()
|
||||||
|
|
||||||
|
if s.shuttingDown {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s.webSocketConns[session] = struct{}{}
|
||||||
|
s.webSocketGroup.Add(1)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
|
||||||
|
s.webSocketMutex.Lock()
|
||||||
|
_, loaded := s.webSocketConns[session]
|
||||||
|
if loaded {
|
||||||
|
delete(s.webSocketConns, session)
|
||||||
|
}
|
||||||
|
s.webSocketMutex.Unlock()
|
||||||
|
|
||||||
|
if loaded {
|
||||||
|
s.webSocketGroup.Done()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) isShuttingDown() bool {
|
||||||
|
s.webSocketMutex.Lock()
|
||||||
|
defer s.webSocketMutex.Unlock()
|
||||||
|
return s.shuttingDown
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) startWebSocketShutdown() []*webSocketSession {
|
||||||
|
s.webSocketMutex.Lock()
|
||||||
|
defer s.webSocketMutex.Unlock()
|
||||||
|
|
||||||
|
s.shuttingDown = true
|
||||||
|
|
||||||
|
webSocketSessions := make([]*webSocketSession, 0, len(s.webSocketConns))
|
||||||
|
for session := range s.webSocketConns {
|
||||||
|
webSocketSessions = append(webSocketSessions, session)
|
||||||
|
}
|
||||||
|
return webSocketSessions
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,6 +21,19 @@ import (
|
|||||||
"github.com/openai/openai-go/v3/responses"
|
"github.com/openai/openai-go/v3/responses"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type webSocketSession struct {
|
||||||
|
clientConn net.Conn
|
||||||
|
upstreamConn net.Conn
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webSocketSession) Close() {
|
||||||
|
s.closeOnce.Do(func() {
|
||||||
|
s.clientConn.Close()
|
||||||
|
s.upstreamConn.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func buildUpstreamWebSocketURL(baseURL string, proxyPath string) string {
|
func buildUpstreamWebSocketURL(baseURL string, proxyPath string) string {
|
||||||
upstreamURL := baseURL
|
upstreamURL := baseURL
|
||||||
if strings.HasPrefix(upstreamURL, "https://") {
|
if strings.HasPrefix(upstreamURL, "https://") {
|
||||||
@@ -47,6 +60,22 @@ func isForwardableResponseHeader(key string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isForwardableWebSocketRequestHeader(key string) bool {
|
||||||
|
if isHopByHopHeader(key) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
lowerKey := strings.ToLower(key)
|
||||||
|
switch {
|
||||||
|
case lowerKey == "authorization":
|
||||||
|
return false
|
||||||
|
case strings.HasPrefix(lowerKey, "sec-websocket-"):
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyPath string, username string) {
|
func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyPath string, username string) {
|
||||||
accessToken, err := s.getAccessToken()
|
accessToken, err := s.getAccessToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -61,18 +90,8 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
|
|||||||
}
|
}
|
||||||
|
|
||||||
upstreamHeaders := make(http.Header)
|
upstreamHeaders := make(http.Header)
|
||||||
forwardHeaders := []string{
|
|
||||||
"OpenAI-Beta",
|
|
||||||
"X-Conversation-ID",
|
|
||||||
}
|
|
||||||
for _, headerKey := range forwardHeaders {
|
|
||||||
if value := r.Header.Get(headerKey); value != "" {
|
|
||||||
upstreamHeaders.Set(headerKey, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for key, values := range r.Header {
|
for key, values := range r.Header {
|
||||||
lowerKey := strings.ToLower(key)
|
if isForwardableWebSocketRequestHeader(key) {
|
||||||
if strings.HasPrefix(lowerKey, "x-codex-") || strings.HasPrefix(lowerKey, "x-responsesapi-") {
|
|
||||||
upstreamHeaders[key] = values
|
upstreamHeaders[key] = values
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -87,8 +106,8 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
|
|||||||
|
|
||||||
upstreamResponseHeaders := make(http.Header)
|
upstreamResponseHeaders := make(http.Header)
|
||||||
upstreamDialer := ws.Dialer{
|
upstreamDialer := ws.Dialer{
|
||||||
NetDial: func(_ context.Context, network, addr string) (net.Conn, error) {
|
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
return s.dialer.DialContext(s.ctx, network, M.ParseSocksaddr(addr))
|
return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||||
},
|
},
|
||||||
TLSConfig: &stdTLS.Config{
|
TLSConfig: &stdTLS.Config{
|
||||||
RootCAs: adapter.RootPoolFromContext(s.ctx),
|
RootCAs: adapter.RootPoolFromContext(s.ctx),
|
||||||
@@ -120,12 +139,26 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
|
|||||||
clientUpgrader := ws.HTTPUpgrader{
|
clientUpgrader := ws.HTTPUpgrader{
|
||||||
Header: clientResponseHeaders,
|
Header: clientResponseHeaders,
|
||||||
}
|
}
|
||||||
|
if s.isShuttingDown() {
|
||||||
|
upstreamConn.Close()
|
||||||
|
writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down")
|
||||||
|
return
|
||||||
|
}
|
||||||
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
|
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("upgrade client websocket: ", err)
|
s.logger.Error("upgrade client websocket: ", err)
|
||||||
upstreamConn.Close()
|
upstreamConn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
session := &webSocketSession{
|
||||||
|
clientConn: clientConn,
|
||||||
|
upstreamConn: upstreamConn,
|
||||||
|
}
|
||||||
|
if !s.registerWebSocketSession(session) {
|
||||||
|
session.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer s.unregisterWebSocketSession(session)
|
||||||
|
|
||||||
var upstreamReadWriter io.ReadWriter
|
var upstreamReadWriter io.ReadWriter
|
||||||
if upstreamBufferedReader != nil {
|
if upstreamBufferedReader != nil {
|
||||||
@@ -139,21 +172,16 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
|
|||||||
|
|
||||||
modelChannel := make(chan string, 1)
|
modelChannel := make(chan string, 1)
|
||||||
var waitGroup sync.WaitGroup
|
var waitGroup sync.WaitGroup
|
||||||
var once sync.Once
|
|
||||||
closeAll := func() {
|
|
||||||
clientConn.Close()
|
|
||||||
upstreamConn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
waitGroup.Add(2)
|
waitGroup.Add(2)
|
||||||
go func() {
|
go func() {
|
||||||
defer waitGroup.Done()
|
defer waitGroup.Done()
|
||||||
defer once.Do(closeAll)
|
defer session.Close()
|
||||||
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, modelChannel)
|
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, modelChannel)
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer waitGroup.Done()
|
defer waitGroup.Done()
|
||||||
defer once.Do(closeAll)
|
defer session.Close()
|
||||||
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, modelChannel, username, weeklyCycleHint)
|
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, modelChannel, username, weeklyCycleHint)
|
||||||
}()
|
}()
|
||||||
waitGroup.Wait()
|
waitGroup.Wait()
|
||||||
|
|||||||
Reference in New Issue
Block a user