diff --git a/service/ccm/service.go b/service/ccm/service.go index 944bedae..34c38824 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -281,11 +281,11 @@ func (s *Service) getAccessToken() (string, error) { return newCredentials.AccessToken, nil } -func detectContextWindow(betaHeader string, inputTokens int64) int { - if inputTokens > premiumContextThreshold { +func detectContextWindow(betaHeader string, totalInputTokens int64) int { + if totalInputTokens > premiumContextThreshold { features := strings.Split(betaHeader, ",") for _, feature := range features { - if strings.TrimSpace(feature) == "context-1m" { + if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { return contextWindowPremium } } @@ -454,7 +454,8 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if usage.InputTokens > 0 || usage.OutputTokens > 0 { if responseModel != "" { - contextWindow := detectContextWindow(anthropicBetaHeader, usage.InputTokens) + totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens + contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) s.usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, @@ -554,7 +555,8 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 { if responseModel != "" { - contextWindow := detectContextWindow(anthropicBetaHeader, accumulatedUsage.InputTokens) + totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens + contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) s.usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, diff --git a/service/ocm/service.go b/service/ocm/service.go index 0c2e3430..8b66964a 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -507,8 +507,10 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons responseModel = requestModel } if responseModel != "" { + contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) s.usageTracker.AddUsageWithCycleHint( responseModel, + contextWindow, inputTokens, outputTokens, cachedTokens, @@ -616,8 +618,10 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if inputTokens > 0 || outputTokens > 0 { if responseModel != "" { + contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) s.usageTracker.AddUsageWithCycleHint( responseModel, + contextWindow, inputTokens, outputTokens, cachedTokens, diff --git a/service/ocm/service_usage.go b/service/ocm/service_usage.go index 95d401a4..589fd093 100644 --- a/service/ocm/service_usage.go +++ b/service/ocm/service_usage.go @@ -46,6 +46,7 @@ func (u *UsageStats) UnmarshalJSON(data []byte) error { type CostCombination struct { Model string `json:"model"` ServiceTier string `json:"service_tier,omitempty"` + ContextWindow int `json:"context_window"` WeekStartUnix int64 `json:"week_start_unix,omitempty"` Total UsageStats `json:"total"` ByUser map[string]UsageStats `json:"by_user"` @@ -74,6 +75,7 @@ type UsageStatsJSON struct { type CostCombinationJSON struct { Model string `json:"model"` ServiceTier string `json:"service_tier,omitempty"` + ContextWindow int `json:"context_window"` WeekStartUnix int64 `json:"week_start_unix,omitempty"` Total UsageStatsJSON `json:"total"` ByUser map[string]UsageStatsJSON `json:"by_user"` @@ -104,8 +106,9 @@ type ModelPricing struct { } type modelFamily struct { - pattern *regexp.Regexp - pricing ModelPricing + pattern *regexp.Regexp + pricing ModelPricing + premiumPricing *ModelPricing } const ( @@ -116,6 +119,12 @@ const ( serviceTierScale = "scale" ) +const ( + contextWindowStandard = 272000 + contextWindowPremium = 1050000 + premiumContextThreshold = 272000 +) + var ( gpt52Pricing = ModelPricing{ InputPrice: 1.75, @@ -159,6 +168,30 @@ var ( CachedInputPrice: 0.025, } + gpt54StandardPricing = ModelPricing{ + InputPrice: 2.5, + OutputPrice: 15.0, + CachedInputPrice: 0.25, + } + + gpt54PremiumPricing = ModelPricing{ + InputPrice: 5.0, + OutputPrice: 22.5, + CachedInputPrice: 0.5, + } + + gpt54ProPricing = ModelPricing{ + InputPrice: 30.0, + OutputPrice: 180.0, + CachedInputPrice: 30.0, + } + + gpt54ProPremiumPricing = ModelPricing{ + InputPrice: 60.0, + OutputPrice: 270.0, + CachedInputPrice: 60.0, + } + gpt52ProPricing = ModelPricing{ InputPrice: 21.0, OutputPrice: 168.0, @@ -171,6 +204,30 @@ var ( CachedInputPrice: 15.0, } + gpt54FlexPricing = ModelPricing{ + InputPrice: 1.25, + OutputPrice: 7.5, + CachedInputPrice: 0.125, + } + + gpt54PremiumFlexPricing = ModelPricing{ + InputPrice: 2.5, + OutputPrice: 11.25, + CachedInputPrice: 0.25, + } + + gpt54ProFlexPricing = ModelPricing{ + InputPrice: 15.0, + OutputPrice: 90.0, + CachedInputPrice: 15.0, + } + + gpt54ProPremiumFlexPricing = ModelPricing{ + InputPrice: 30.0, + OutputPrice: 135.0, + CachedInputPrice: 30.0, + } + gpt52FlexPricing = ModelPricing{ InputPrice: 0.875, OutputPrice: 7.0, @@ -195,6 +252,18 @@ var ( CachedInputPrice: 0.0025, } + gpt54PriorityPricing = ModelPricing{ + InputPrice: 5.0, + OutputPrice: 30.0, + CachedInputPrice: 0.5, + } + + gpt54PremiumPriorityPricing = ModelPricing{ + InputPrice: 10.0, + OutputPrice: 45.0, + CachedInputPrice: 1.0, + } + gpt52PriorityPricing = ModelPricing{ InputPrice: 3.5, OutputPrice: 28.0, @@ -382,6 +451,16 @@ var ( } standardModelFamilies = []modelFamily{ + { + pattern: regexp.MustCompile(`^gpt-5\.4-pro(?:$|-)`), + pricing: gpt54ProPricing, + premiumPricing: &gpt54ProPremiumPricing, + }, + { + pattern: regexp.MustCompile(`^gpt-5\.4(?:$|-)`), + pricing: gpt54StandardPricing, + premiumPricing: &gpt54PremiumPricing, + }, { pattern: regexp.MustCompile(`^gpt-5\.3-codex(?:$|-)`), pricing: gpt52CodexPricing, @@ -525,6 +604,16 @@ var ( } flexModelFamilies = []modelFamily{ + { + pattern: regexp.MustCompile(`^gpt-5\.4-pro(?:$|-)`), + pricing: gpt54ProFlexPricing, + premiumPricing: &gpt54ProPremiumFlexPricing, + }, + { + pattern: regexp.MustCompile(`^gpt-5\.4(?:$|-)`), + pricing: gpt54FlexPricing, + premiumPricing: &gpt54PremiumFlexPricing, + }, { pattern: regexp.MustCompile(`^gpt-5-mini(?:$|-)`), pricing: gpt5MiniFlexPricing, @@ -556,6 +645,11 @@ var ( } priorityModelFamilies = []modelFamily{ + { + pattern: regexp.MustCompile(`^gpt-5\.4(?:$|-)`), + pricing: gpt54PriorityPricing, + premiumPricing: &gpt54PremiumPriorityPricing, + }, { pattern: regexp.MustCompile(`^gpt-5\.3-codex(?:$|-)`), pricing: gpt52CodexPriorityPricing, @@ -638,15 +732,28 @@ func modelFamiliesForTier(serviceTier string) []modelFamily { } } -func findPricingInFamilies(model string, modelFamilies []modelFamily) (ModelPricing, bool) { +func findPricingInFamilies(model string, contextWindow int, modelFamilies []modelFamily) (ModelPricing, bool) { + isPremium := contextWindow >= contextWindowPremium for _, family := range modelFamilies { if family.pattern.MatchString(model) { + if isPremium && family.premiumPricing != nil { + return *family.premiumPricing, true + } return family.pricing, true } } return ModelPricing{}, false } +func hasPremiumPricingInFamilies(model string, modelFamilies []modelFamily) bool { + for _, family := range modelFamilies { + if family.pattern.MatchString(model) { + return family.premiumPricing != nil + } + } + return false +} + func normalizeServiceTier(serviceTier string) string { switch strings.ToLower(strings.TrimSpace(serviceTier)) { case "", serviceTierAuto, serviceTierDefault: @@ -663,27 +770,27 @@ func normalizeServiceTier(serviceTier string) string { } } -func getPricing(model string, serviceTier string) ModelPricing { +func getPricing(model string, serviceTier string, contextWindow int) ModelPricing { normalizedServiceTier := normalizeServiceTier(serviceTier) - modelFamilies := modelFamiliesForTier(normalizedServiceTier) + families := modelFamiliesForTier(normalizedServiceTier) - if pricing, found := findPricingInFamilies(model, modelFamilies); found { + if pricing, found := findPricingInFamilies(model, contextWindow, families); found { return pricing } normalizedModel := normalizeGPT5Model(model) if normalizedModel != model { - if pricing, found := findPricingInFamilies(normalizedModel, modelFamilies); found { + if pricing, found := findPricingInFamilies(normalizedModel, contextWindow, families); found { return pricing } } if normalizedServiceTier != serviceTierDefault { - if pricing, found := findPricingInFamilies(model, standardModelFamilies); found { + if pricing, found := findPricingInFamilies(model, contextWindow, standardModelFamilies); found { return pricing } if normalizedModel != model { - if pricing, found := findPricingInFamilies(normalizedModel, standardModelFamilies); found { + if pricing, found := findPricingInFamilies(normalizedModel, contextWindow, standardModelFamilies); found { return pricing } } @@ -692,6 +799,30 @@ func getPricing(model string, serviceTier string) ModelPricing { return gpt4oPricing } +func detectContextWindow(model string, serviceTier string, inputTokens int64) int { + if inputTokens <= premiumContextThreshold { + return contextWindowStandard + } + normalizedServiceTier := normalizeServiceTier(serviceTier) + families := modelFamiliesForTier(normalizedServiceTier) + if hasPremiumPricingInFamilies(model, families) { + return contextWindowPremium + } + normalizedModel := normalizeGPT5Model(model) + if normalizedModel != model && hasPremiumPricingInFamilies(normalizedModel, families) { + return contextWindowPremium + } + if normalizedServiceTier != serviceTierDefault { + if hasPremiumPricingInFamilies(model, standardModelFamilies) { + return contextWindowPremium + } + if normalizedModel != model && hasPremiumPricingInFamilies(normalizedModel, standardModelFamilies) { + return contextWindowPremium + } + } + return contextWindowStandard +} + func normalizeGPT5Model(model string) string { if !strings.HasPrefix(model, "gpt-5.") { return model @@ -707,18 +838,18 @@ func normalizeGPT5Model(model string) string { case strings.Contains(model, "-chat-latest"): return "gpt-5.2-chat-latest" case strings.Contains(model, "-pro"): - return "gpt-5.2-pro" + return "gpt-5.4-pro" case strings.Contains(model, "-mini"): return "gpt-5-mini" case strings.Contains(model, "-nano"): return "gpt-5-nano" default: - return "gpt-5.2" + return "gpt-5.4" } } -func calculateCost(stats UsageStats, model string, serviceTier string) float64 { - pricing := getPricing(model, serviceTier) +func calculateCost(stats UsageStats, model string, serviceTier string, contextWindow int) float64 { + pricing := getPricing(model, serviceTier, contextWindow) regularInputTokens := stats.InputTokens - stats.CachedTokens if regularInputTokens < 0 { @@ -739,13 +870,16 @@ func roundCost(cost float64) float64 { func normalizeCombinations(combinations []CostCombination) { for index := range combinations { combinations[index].ServiceTier = normalizeServiceTier(combinations[index].ServiceTier) + if combinations[index].ContextWindow <= 0 { + combinations[index].ContextWindow = contextWindowStandard + } if combinations[index].ByUser == nil { combinations[index].ByUser = make(map[string]UsageStats) } } } -func addUsageToCombinations(combinations *[]CostCombination, model string, serviceTier string, weekStartUnix int64, user string, inputTokens, outputTokens, cachedTokens int64) { +func addUsageToCombinations(combinations *[]CostCombination, model string, serviceTier string, contextWindow int, weekStartUnix int64, user string, inputTokens, outputTokens, cachedTokens int64) { var matchedCombination *CostCombination for index := range *combinations { combination := &(*combinations)[index] @@ -753,7 +887,7 @@ func addUsageToCombinations(combinations *[]CostCombination, model string, servi if combination.ServiceTier != combinationServiceTier { combination.ServiceTier = combinationServiceTier } - if combination.Model == model && combinationServiceTier == serviceTier && combination.WeekStartUnix == weekStartUnix { + if combination.Model == model && combinationServiceTier == serviceTier && combination.ContextWindow == contextWindow && combination.WeekStartUnix == weekStartUnix { matchedCombination = combination break } @@ -763,6 +897,7 @@ func addUsageToCombinations(combinations *[]CostCombination, model string, servi newCombination := CostCombination{ Model: model, ServiceTier: serviceTier, + ContextWindow: contextWindow, WeekStartUnix: weekStartUnix, Total: UsageStats{}, ByUser: make(map[string]UsageStats), @@ -791,12 +926,13 @@ func buildCombinationJSON(combinations []CostCombination, aggregateUserCosts map var totalCost float64 for index, combination := range combinations { - combinationTotalCost := calculateCost(combination.Total, combination.Model, combination.ServiceTier) + combinationTotalCost := calculateCost(combination.Total, combination.Model, combination.ServiceTier, combination.ContextWindow) totalCost += combinationTotalCost combinationJSON := CostCombinationJSON{ Model: combination.Model, ServiceTier: combination.ServiceTier, + ContextWindow: combination.ContextWindow, WeekStartUnix: combination.WeekStartUnix, Total: UsageStatsJSON{ RequestCount: combination.Total.RequestCount, @@ -809,7 +945,7 @@ func buildCombinationJSON(combinations []CostCombination, aggregateUserCosts map } for user, userStats := range combination.ByUser { - userCost := calculateCost(userStats, combination.Model, combination.ServiceTier) + userCost := calculateCost(userStats, combination.Model, combination.ServiceTier, combination.ContextWindow) if aggregateUserCosts != nil { aggregateUserCosts[user] += userCost } @@ -857,7 +993,7 @@ func buildByWeekCost(combinations []CostCombination) map[string]float64 { } weekStartAt := time.Unix(combination.WeekStartUnix, 0).UTC() weekKey := formatWeekStartKey(weekStartAt) - byWeek[weekKey] += calculateCost(combination.Total, combination.Model, combination.ServiceTier) + byWeek[weekKey] += calculateCost(combination.Total, combination.Model, combination.ServiceTier, combination.ContextWindow) } for weekKey, weekCost := range byWeek { byWeek[weekKey] = roundCost(weekCost) @@ -879,7 +1015,7 @@ func buildByUserAndWeekCost(combinations []CostCombination) map[string]map[strin userWeeks = make(map[string]float64) byUserAndWeek[user] = userWeeks } - userWeeks[weekKey] += calculateCost(userStats, combination.Model, combination.ServiceTier) + userWeeks[weekKey] += calculateCost(userStats, combination.Model, combination.ServiceTier, combination.ContextWindow) } } for _, weekCosts := range byUserAndWeek { @@ -987,14 +1123,17 @@ func (u *AggregatedUsage) Save() error { return err } -func (u *AggregatedUsage) AddUsage(model string, inputTokens, outputTokens, cachedTokens int64, serviceTier string, user string) error { - return u.AddUsageWithCycleHint(model, inputTokens, outputTokens, cachedTokens, serviceTier, user, time.Now(), nil) +func (u *AggregatedUsage) AddUsage(model string, contextWindow int, inputTokens, outputTokens, cachedTokens int64, serviceTier string, user string) error { + return u.AddUsageWithCycleHint(model, contextWindow, inputTokens, outputTokens, cachedTokens, serviceTier, user, time.Now(), nil) } -func (u *AggregatedUsage) AddUsageWithCycleHint(model string, inputTokens, outputTokens, cachedTokens int64, serviceTier string, user string, observedAt time.Time, cycleHint *WeeklyCycleHint) error { +func (u *AggregatedUsage) AddUsageWithCycleHint(model string, contextWindow int, inputTokens, outputTokens, cachedTokens int64, serviceTier string, user string, observedAt time.Time, cycleHint *WeeklyCycleHint) error { if model == "" { return E.New("model cannot be empty") } + if contextWindow <= 0 { + return E.New("contextWindow must be positive") + } normalizedServiceTier := normalizeServiceTier(serviceTier) if observedAt.IsZero() { @@ -1007,7 +1146,7 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(model string, inputTokens, outpu u.LastUpdated = observedAt weekStartUnix := deriveWeekStartUnix(cycleHint) - addUsageToCombinations(&u.Combinations, model, normalizedServiceTier, weekStartUnix, user, inputTokens, outputTokens, cachedTokens) + addUsageToCombinations(&u.Combinations, model, normalizedServiceTier, contextWindow, weekStartUnix, user, inputTokens, outputTokens, cachedTokens) go u.scheduleSave() diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index c2e6148d..d19f2df8 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -256,8 +256,10 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite responseModel = requestModel } if responseModel != "" { + contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) s.usageTracker.AddUsageWithCycleHint( responseModel, + contextWindow, inputTokens, outputTokens, cachedTokens,