diff --git a/docs/configuration/service/ocm.md b/docs/configuration/service/ocm.md index 59dba7da..ce6d50fc 100644 --- a/docs/configuration/service/ocm.md +++ b/docs/configuration/service/ocm.md @@ -114,14 +114,18 @@ Add to `~/.codex/config.toml`: [model_providers.ocm] name = "OCM Proxy" base_url = "http://127.0.0.1:8080/v1" -wire_api = "responses" -requires_openai_auth = false +supports_websockets = true + +[profiles.ocm] +model_provider = "ocm" +# model = "gpt-5.4" # if the latest model is not yet publicly released +# model_reasoning_effort = "xhigh" ``` Then run: ```bash -codex --model-provider ocm +codex --profile ocm ``` ### Example with Authentication @@ -159,13 +163,17 @@ Add to `~/.codex/config.toml`: [model_providers.ocm] name = "OCM Proxy" base_url = "http://127.0.0.1:8080/v1" -wire_api = "responses" -requires_openai_auth = false +supports_websockets = true experimental_bearer_token = "sk-alice-secret-token" + +[profiles.ocm] +model_provider = "ocm" +# model = "gpt-5.4" # if the latest model is not yet publicly released +# model_reasoning_effort = "xhigh" ``` Then run: ```bash -codex --model-provider ocm +codex --profile ocm ``` diff --git a/docs/configuration/service/ocm.zh.md b/docs/configuration/service/ocm.zh.md index ee1d8510..016990ff 100644 --- a/docs/configuration/service/ocm.zh.md +++ b/docs/configuration/service/ocm.zh.md @@ -114,14 +114,18 @@ TLS 配置,参阅 [TLS](/zh/configuration/shared/tls/#inbound)。 [model_providers.ocm] name = "OCM Proxy" base_url = "http://127.0.0.1:8080/v1" -wire_api = "responses" -requires_openai_auth = false +supports_websockets = true + +[profiles.ocm] +model_provider = "ocm" +# model = "gpt-5.4" # 如果最新模型尚未公开发布 +# model_reasoning_effort = "xhigh" ``` 然后运行: ```bash -codex --model-provider ocm +codex --profile ocm ``` ### 带身份验证的示例 @@ -159,13 +163,17 @@ codex --model-provider ocm [model_providers.ocm] name = "OCM Proxy" base_url = "http://127.0.0.1:8080/v1" -wire_api = "responses" -requires_openai_auth = false +supports_websockets = true experimental_bearer_token = "sk-alice-secret-token" + +[profiles.ocm] +model_provider = "ocm" +# model = "gpt-5.4" # 如果最新模型尚未公开发布 +# model_reasoning_effort = "xhigh" ``` 然后运行: ```bash -codex --model-provider ocm +codex --profile ocm ``` diff --git a/go.mod b/go.mod index 0f6661ab..43751f29 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/metacubex/utls v1.8.4 github.com/mholt/acmez/v3 v3.1.6 github.com/miekg/dns v1.1.72 - github.com/openai/openai-go/v3 v3.24.0 + github.com/openai/openai-go/v3 v3.26.0 github.com/oschwald/maxminddb-golang v1.13.1 github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1 github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a diff --git a/go.sum b/go.sum index 5d5c3760..7cd04d80 100644 --- a/go.sum +++ b/go.sum @@ -138,8 +138,8 @@ github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= -github.com/openai/openai-go/v3 v3.24.0 h1:08x6GnYiB+AAejTo6yzPY8RkZMJQ8NpreiOyM5QfyYU= -github.com/openai/openai-go/v3 v3.24.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= +github.com/openai/openai-go/v3 v3.26.0 h1:bRt6H/ozMNt/dDkN4gobnLqaEGrRGBzmbVs0xxJEnQE= +github.com/openai/openai-go/v3 v3.26.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= diff --git a/service/ocm/service.go b/service/ocm/service.go index 2354d159..e05c9547 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -130,6 +130,7 @@ type Service struct { credentialPath string credentials *oauthCredentials users []option.OCMUser + dialer N.Dialer httpClient *http.Client httpHeaders http.Header listener *listener.Listener @@ -187,6 +188,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio logger: logger, credentialPath: options.CredentialPath, users: options.Users, + dialer: serviceDialer, httpClient: httpClient, httpHeaders: options.Headers.Build(), listener: listener.New(listener.Options{ @@ -356,6 +358,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { + s.handleWebSocket(w, r, proxyPath, username) + return + } + var requestModel string if s.usageTracker != nil && r.Body != nil { diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go new file mode 100644 index 00000000..5e8cb8bb --- /dev/null +++ b/service/ocm/service_websocket.go @@ -0,0 +1,255 @@ +package ocm + +import ( + "context" + stdTLS "crypto/tls" + "encoding/json" + "io" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/ws" + "github.com/sagernet/ws/wsutil" + + "github.com/openai/openai-go/v3/responses" +) + +func buildUpstreamWebSocketURL(baseURL string, proxyPath string) string { + upstreamURL := baseURL + if strings.HasPrefix(upstreamURL, "https://") { + upstreamURL = "wss://" + upstreamURL[len("https://"):] + } else if strings.HasPrefix(upstreamURL, "http://") { + upstreamURL = "ws://" + upstreamURL[len("http://"):] + } + return upstreamURL + proxyPath +} + +func isForwardableResponseHeader(key string) bool { + lowerKey := strings.ToLower(key) + switch { + case strings.HasPrefix(lowerKey, "x-codex-"): + return true + case strings.HasPrefix(lowerKey, "x-reasoning"): + return true + case lowerKey == "openai-model": + return true + case strings.Contains(lowerKey, "-secondary-"): + return true + default: + return false + } +} + +func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyPath string, username string) { + accessToken, err := s.getAccessToken() + if err != nil { + s.logger.Error("get access token for websocket: ", err) + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed") + return + } + + upstreamURL := buildUpstreamWebSocketURL(s.getBaseURL(), proxyPath) + if r.URL.RawQuery != "" { + upstreamURL += "?" + r.URL.RawQuery + } + + upstreamHeaders := make(http.Header) + forwardHeaders := []string{ + "OpenAI-Beta", + "X-Conversation-ID", + } + for _, headerKey := range forwardHeaders { + if value := r.Header.Get(headerKey); value != "" { + upstreamHeaders.Set(headerKey, value) + } + } + for key, values := range r.Header { + lowerKey := strings.ToLower(key) + if strings.HasPrefix(lowerKey, "x-codex-") || strings.HasPrefix(lowerKey, "x-responsesapi-") { + upstreamHeaders[key] = values + } + } + for key, values := range s.httpHeaders { + upstreamHeaders.Del(key) + upstreamHeaders[key] = values + } + upstreamHeaders.Set("Authorization", "Bearer "+accessToken) + if accountID := s.getAccountID(); accountID != "" { + upstreamHeaders.Set("ChatGPT-Account-Id", accountID) + } + + upstreamResponseHeaders := make(http.Header) + upstreamDialer := ws.Dialer{ + NetDial: func(_ context.Context, network, addr string) (net.Conn, error) { + return s.dialer.DialContext(s.ctx, network, M.ParseSocksaddr(addr)) + }, + TLSConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(s.ctx), + Time: ntp.TimeFuncFromContext(s.ctx), + }, + Header: ws.HandshakeHeaderHTTP(upstreamHeaders), + OnHeader: func(key, value []byte) error { + upstreamResponseHeaders.Add(string(key), string(value)) + return nil + }, + } + + upstreamConn, upstreamBufferedReader, _, err := upstreamDialer.Dial(r.Context(), upstreamURL) + if err != nil { + s.logger.Error("dial upstream websocket: ", err) + writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed") + return + } + + weeklyCycleHint := extractWeeklyCycleHint(upstreamResponseHeaders) + + clientResponseHeaders := make(http.Header) + for key, values := range upstreamResponseHeaders { + if isForwardableResponseHeader(key) { + clientResponseHeaders[key] = values + } + } + + clientUpgrader := ws.HTTPUpgrader{ + Header: clientResponseHeaders, + } + clientConn, _, _, err := clientUpgrader.Upgrade(r, w) + if err != nil { + s.logger.Error("upgrade client websocket: ", err) + upstreamConn.Close() + return + } + + var upstreamReadWriter io.ReadWriter + if upstreamBufferedReader != nil { + upstreamReadWriter = struct { + io.Reader + io.Writer + }{upstreamBufferedReader, upstreamConn} + } else { + upstreamReadWriter = upstreamConn + } + + modelChannel := make(chan string, 1) + var waitGroup sync.WaitGroup + var once sync.Once + closeAll := func() { + clientConn.Close() + upstreamConn.Close() + } + + waitGroup.Add(2) + go func() { + defer waitGroup.Done() + defer once.Do(closeAll) + s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, modelChannel) + }() + go func() { + defer waitGroup.Done() + defer once.Do(closeAll) + s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, modelChannel, username, weeklyCycleHint) + }() + waitGroup.Wait() +} + +func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, modelChannel chan<- string) { + for { + data, opCode, err := wsutil.ReadClientData(clientConn) + if err != nil { + if !E.IsClosedOrCanceled(err) { + s.logger.Debug("read client websocket: ", err) + } + return + } + + if opCode == ws.OpText && s.usageTracker != nil { + var request struct { + Type string `json:"type"` + Model string `json:"model"` + } + if json.Unmarshal(data, &request) == nil && request.Type == "response.create" && request.Model != "" { + select { + case modelChannel <- request.Model: + default: + } + } + } + + err = wsutil.WriteClientMessage(upstreamConn, opCode, data) + if err != nil { + if !E.IsClosedOrCanceled(err) { + s.logger.Debug("write upstream websocket: ", err) + } + return + } + } +} + +func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { + var requestModel string + for { + data, opCode, err := wsutil.ReadServerData(upstreamReadWriter) + if err != nil { + if !E.IsClosedOrCanceled(err) { + s.logger.Debug("read upstream websocket: ", err) + } + return + } + + if opCode == ws.OpText && s.usageTracker != nil { + select { + case model := <-modelChannel: + requestModel = model + default: + } + + var event struct { + Type string `json:"type"` + } + if json.Unmarshal(data, &event) == nil && event.Type == "response.completed" { + var streamEvent responses.ResponseStreamEventUnion + if json.Unmarshal(data, &streamEvent) == nil { + completedEvent := streamEvent.AsResponseCompleted() + responseModel := string(completedEvent.Response.Model) + serviceTier := string(completedEvent.Response.ServiceTier) + inputTokens := completedEvent.Response.Usage.InputTokens + outputTokens := completedEvent.Response.Usage.OutputTokens + cachedTokens := completedEvent.Response.Usage.InputTokensDetails.CachedTokens + + if inputTokens > 0 || outputTokens > 0 { + if responseModel == "" { + responseModel = requestModel + } + if responseModel != "" { + s.usageTracker.AddUsageWithCycleHint( + responseModel, + inputTokens, + outputTokens, + cachedTokens, + serviceTier, + username, + time.Now(), + weeklyCycleHint, + ) + } + } + } + } + } + + err = wsutil.WriteServerMessage(clientConn, opCode, data) + if err != nil { + if !E.IsClosedOrCanceled(err) { + s.logger.Debug("write client websocket: ", err) + } + return + } + } +}