Fix OCM websocket proxy lifecycle and headers
This commit is contained in:
@@ -21,6 +21,19 @@ import (
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
)
|
||||
|
||||
type webSocketSession struct {
|
||||
clientConn net.Conn
|
||||
upstreamConn net.Conn
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (s *webSocketSession) Close() {
|
||||
s.closeOnce.Do(func() {
|
||||
s.clientConn.Close()
|
||||
s.upstreamConn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func buildUpstreamWebSocketURL(baseURL string, proxyPath string) string {
|
||||
upstreamURL := baseURL
|
||||
if strings.HasPrefix(upstreamURL, "https://") {
|
||||
@@ -47,6 +60,22 @@ func isForwardableResponseHeader(key string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func isForwardableWebSocketRequestHeader(key string) bool {
|
||||
if isHopByHopHeader(key) {
|
||||
return false
|
||||
}
|
||||
|
||||
lowerKey := strings.ToLower(key)
|
||||
switch {
|
||||
case lowerKey == "authorization":
|
||||
return false
|
||||
case strings.HasPrefix(lowerKey, "sec-websocket-"):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyPath string, username string) {
|
||||
accessToken, err := s.getAccessToken()
|
||||
if err != nil {
|
||||
@@ -61,18 +90,8 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
|
||||
}
|
||||
|
||||
upstreamHeaders := make(http.Header)
|
||||
forwardHeaders := []string{
|
||||
"OpenAI-Beta",
|
||||
"X-Conversation-ID",
|
||||
}
|
||||
for _, headerKey := range forwardHeaders {
|
||||
if value := r.Header.Get(headerKey); value != "" {
|
||||
upstreamHeaders.Set(headerKey, value)
|
||||
}
|
||||
}
|
||||
for key, values := range r.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if strings.HasPrefix(lowerKey, "x-codex-") || strings.HasPrefix(lowerKey, "x-responsesapi-") {
|
||||
if isForwardableWebSocketRequestHeader(key) {
|
||||
upstreamHeaders[key] = values
|
||||
}
|
||||
}
|
||||
@@ -87,8 +106,8 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
|
||||
|
||||
upstreamResponseHeaders := make(http.Header)
|
||||
upstreamDialer := ws.Dialer{
|
||||
NetDial: func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||
return s.dialer.DialContext(s.ctx, network, M.ParseSocksaddr(addr))
|
||||
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
TLSConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(s.ctx),
|
||||
@@ -120,12 +139,26 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
|
||||
clientUpgrader := ws.HTTPUpgrader{
|
||||
Header: clientResponseHeaders,
|
||||
}
|
||||
if s.isShuttingDown() {
|
||||
upstreamConn.Close()
|
||||
writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down")
|
||||
return
|
||||
}
|
||||
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
|
||||
if err != nil {
|
||||
s.logger.Error("upgrade client websocket: ", err)
|
||||
upstreamConn.Close()
|
||||
return
|
||||
}
|
||||
session := &webSocketSession{
|
||||
clientConn: clientConn,
|
||||
upstreamConn: upstreamConn,
|
||||
}
|
||||
if !s.registerWebSocketSession(session) {
|
||||
session.Close()
|
||||
return
|
||||
}
|
||||
defer s.unregisterWebSocketSession(session)
|
||||
|
||||
var upstreamReadWriter io.ReadWriter
|
||||
if upstreamBufferedReader != nil {
|
||||
@@ -139,21 +172,16 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
|
||||
|
||||
modelChannel := make(chan string, 1)
|
||||
var waitGroup sync.WaitGroup
|
||||
var once sync.Once
|
||||
closeAll := func() {
|
||||
clientConn.Close()
|
||||
upstreamConn.Close()
|
||||
}
|
||||
|
||||
waitGroup.Add(2)
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer once.Do(closeAll)
|
||||
defer session.Close()
|
||||
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, modelChannel)
|
||||
}()
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer once.Do(closeAll)
|
||||
defer session.Close()
|
||||
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, modelChannel, username, weeklyCycleHint)
|
||||
}()
|
||||
waitGroup.Wait()
|
||||
|
||||
Reference in New Issue
Block a user