From 4d217b7481319f3d0612329cbc2f26e0102ce137 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 3 Mar 2026 20:59:13 +0800 Subject: [PATCH 01/96] Add MAC and hostname rule items --- adapter/inbound.go | 3 + adapter/neighbor.go | 13 + adapter/router.go | 2 + docs/configuration/dns/rule.md | 31 ++ docs/configuration/dns/rule.zh.md | 31 ++ docs/configuration/inbound/tun.md | 35 ++ docs/configuration/inbound/tun.zh.md | 35 ++ docs/configuration/route/index.md | 31 ++ docs/configuration/route/index.zh.md | 31 ++ docs/configuration/route/rule.md | 31 ++ docs/configuration/route/rule.zh.md | 31 ++ go.mod | 6 +- go.sum | 4 +- option/route.go | 2 + option/rule.go | 2 + option/rule_dns.go | 2 + option/tun.go | 2 + protocol/tun/inbound.go | 18 + route/neighbor_resolver_linux.go | 596 +++++++++++++++++++++ route/neighbor_resolver_stub.go | 14 + route/route.go | 17 + route/router.go | 39 ++ route/rule/rule_default.go | 10 + route/rule/rule_dns.go | 10 + route/rule/rule_item_source_hostname.go | 42 ++ route/rule/rule_item_source_mac_address.go | 48 ++ route/rule_conds.go | 8 + 27 files changed, 1089 insertions(+), 5 deletions(-) create mode 100644 adapter/neighbor.go create mode 100644 route/neighbor_resolver_linux.go create mode 100644 route/neighbor_resolver_stub.go create mode 100644 route/rule/rule_item_source_hostname.go create mode 100644 route/rule/rule_item_source_mac_address.go diff --git a/adapter/inbound.go b/adapter/inbound.go index b32e9f8278..acd6f4912c 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -2,6 +2,7 @@ package adapter import ( "context" + "net" "net/netip" "time" @@ -82,6 +83,8 @@ type InboundContext struct { SourceGeoIPCode string GeoIPCode string ProcessInfo *ConnectionOwner + SourceMACAddress net.HardwareAddr + SourceHostname string QueryType uint16 FakeIP bool diff --git a/adapter/neighbor.go b/adapter/neighbor.go new file mode 100644 index 0000000000..920398f674 --- /dev/null +++ b/adapter/neighbor.go @@ -0,0 +1,13 @@ +package adapter + +import ( + "net" + "net/netip" +) + +type NeighborResolver interface { + LookupMAC(address netip.Addr) (net.HardwareAddr, bool) + LookupHostname(address netip.Addr) (string, bool) + Start() error + Close() error +} diff --git a/adapter/router.go b/adapter/router.go index 3d5310c4ee..82e6881a60 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -26,6 +26,8 @@ type Router interface { RuleSet(tag string) (RuleSet, bool) Rules() []Rule NeedFindProcess() bool + NeedFindNeighbor() bool + NeighborResolver() NeighborResolver AppendTracker(tracker ConnectionTracker) ResetNetwork() } diff --git a/docs/configuration/dns/rule.md b/docs/configuration/dns/rule.md index 6407e1bf60..262a23e629 100644 --- a/docs/configuration/dns/rule.md +++ b/docs/configuration/dns/rule.md @@ -2,6 +2,11 @@ icon: material/alert-decagram --- +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [source_mac_address](#source_mac_address) + :material-plus: [source_hostname](#source_hostname) + !!! quote "Changes in sing-box 1.13.0" :material-plus: [interface_address](#interface_address) @@ -149,6 +154,12 @@ icon: material/alert-decagram "default_interface_address": [ "2000::/3" ], + "source_mac_address": [ + "00:11:22:33:44:55" + ], + "source_hostname": [ + "my-device" + ], "wifi_ssid": [ "My WIFI" ], @@ -408,6 +419,26 @@ Matches network interface (same values as `network_type`) address. Match default interface address. +#### source_mac_address + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `route.find_neighbor` enabled. + +Match source device MAC address. + +#### source_hostname + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `route.find_neighbor` enabled. + +Match source device hostname from DHCP leases. + #### wifi_ssid !!! quote "" diff --git a/docs/configuration/dns/rule.zh.md b/docs/configuration/dns/rule.zh.md index 588e0736a4..4bf60b9862 100644 --- a/docs/configuration/dns/rule.zh.md +++ b/docs/configuration/dns/rule.zh.md @@ -2,6 +2,11 @@ icon: material/alert-decagram --- +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [source_mac_address](#source_mac_address) + :material-plus: [source_hostname](#source_hostname) + !!! quote "sing-box 1.13.0 中的更改" :material-plus: [interface_address](#interface_address) @@ -149,6 +154,12 @@ icon: material/alert-decagram "default_interface_address": [ "2000::/3" ], + "source_mac_address": [ + "00:11:22:33:44:55" + ], + "source_hostname": [ + "my-device" + ], "wifi_ssid": [ "My WIFI" ], @@ -407,6 +418,26 @@ Available values: `wifi`, `cellular`, `ethernet` and `other`. 匹配默认接口地址。 +#### source_mac_address + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + +匹配源设备 MAC 地址。 + +#### source_hostname + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + +匹配源设备从 DHCP 租约获取的主机名。 + #### wifi_ssid !!! quote "" diff --git a/docs/configuration/inbound/tun.md b/docs/configuration/inbound/tun.md index ed368a13a2..5a2f58d3db 100644 --- a/docs/configuration/inbound/tun.md +++ b/docs/configuration/inbound/tun.md @@ -2,6 +2,11 @@ icon: material/new-box --- +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [include_mac_address](#include_mac_address) + :material-plus: [exclude_mac_address](#exclude_mac_address) + !!! quote "Changes in sing-box 1.13.3" :material-alert: [strict_route](#strict_route) @@ -129,6 +134,12 @@ icon: material/new-box "exclude_package": [ "com.android.captiveportallogin" ], + "include_mac_address": [ + "00:11:22:33:44:55" + ], + "exclude_mac_address": [ + "66:77:88:99:aa:bb" + ], "platform": { "http_proxy": { "enabled": false, @@ -555,6 +566,30 @@ Limit android packages in route. Exclude android packages in route. +#### include_mac_address + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `auto_route` and `auto_redirect` enabled. + +Limit MAC addresses in route. Not limited by default. + +Conflict with `exclude_mac_address`. + +#### exclude_mac_address + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `auto_route` and `auto_redirect` enabled. + +Exclude MAC addresses in route. + +Conflict with `include_mac_address`. + #### platform Platform-specific settings, provided by client applications. diff --git a/docs/configuration/inbound/tun.zh.md b/docs/configuration/inbound/tun.zh.md index eaf5ff49c3..a41e5ae9ff 100644 --- a/docs/configuration/inbound/tun.zh.md +++ b/docs/configuration/inbound/tun.zh.md @@ -2,6 +2,11 @@ icon: material/new-box --- +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [include_mac_address](#include_mac_address) + :material-plus: [exclude_mac_address](#exclude_mac_address) + !!! quote "sing-box 1.13.3 中的更改" :material-alert: [strict_route](#strict_route) @@ -130,6 +135,12 @@ icon: material/new-box "exclude_package": [ "com.android.captiveportallogin" ], + "include_mac_address": [ + "00:11:22:33:44:55" + ], + "exclude_mac_address": [ + "66:77:88:99:aa:bb" + ], "platform": { "http_proxy": { "enabled": false, @@ -543,6 +554,30 @@ TCP/IP 栈。 排除路由的 Android 应用包名。 +#### include_mac_address + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `auto_route` 和 `auto_redirect` 已启用。 + +限制被路由的 MAC 地址。默认不限制。 + +与 `exclude_mac_address` 冲突。 + +#### exclude_mac_address + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `auto_route` 和 `auto_redirect` 已启用。 + +排除路由的 MAC 地址。 + +与 `include_mac_address` 冲突。 + #### platform 平台特定的设置,由客户端应用提供。 diff --git a/docs/configuration/route/index.md b/docs/configuration/route/index.md index 1fc9bfd231..01e405614e 100644 --- a/docs/configuration/route/index.md +++ b/docs/configuration/route/index.md @@ -4,6 +4,11 @@ icon: material/alert-decagram # Route +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [find_neighbor](#find_neighbor) + :material-plus: [dhcp_lease_files](#dhcp_lease_files) + !!! quote "Changes in sing-box 1.12.0" :material-plus: [default_domain_resolver](#default_domain_resolver) @@ -35,6 +40,8 @@ icon: material/alert-decagram "override_android_vpn": false, "default_interface": "", "default_mark": 0, + "find_neighbor": false, + "dhcp_lease_files": [], "default_domain_resolver": "", // or {} "default_network_strategy": "", "default_network_type": [], @@ -107,6 +114,30 @@ Set routing mark by default. Takes no effect if `outbound.routing_mark` is set. +#### find_neighbor + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux. + +Enable neighbor resolution for source MAC address and hostname lookup. + +Required for `source_mac_address` and `source_hostname` rule items. + +#### dhcp_lease_files + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux. + +Custom DHCP lease file paths for hostname and MAC address resolution. + +Automatically detected from common DHCP servers (dnsmasq, odhcpd, ISC dhcpd, Kea) if empty. + #### default_domain_resolver !!! question "Since sing-box 1.12.0" diff --git a/docs/configuration/route/index.zh.md b/docs/configuration/route/index.zh.md index fa50bfe7d9..84ce76723c 100644 --- a/docs/configuration/route/index.zh.md +++ b/docs/configuration/route/index.zh.md @@ -4,6 +4,11 @@ icon: material/alert-decagram # 路由 +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [find_neighbor](#find_neighbor) + :material-plus: [dhcp_lease_files](#dhcp_lease_files) + !!! quote "sing-box 1.12.0 中的更改" :material-plus: [default_domain_resolver](#default_domain_resolver) @@ -37,6 +42,8 @@ icon: material/alert-decagram "override_android_vpn": false, "default_interface": "", "default_mark": 0, + "find_neighbor": false, + "dhcp_lease_files": [], "default_network_strategy": "", "default_fallback_delay": "" } @@ -106,6 +113,30 @@ icon: material/alert-decagram 如果设置了 `outbound.routing_mark` 设置,则不生效。 +#### find_neighbor + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux。 + +启用邻居解析以查找源 MAC 地址和主机名。 + +`source_mac_address` 和 `source_hostname` 规则项需要此选项。 + +#### dhcp_lease_files + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux。 + +用于主机名和 MAC 地址解析的自定义 DHCP 租约文件路径。 + +为空时自动从常见 DHCP 服务器(dnsmasq、odhcpd、ISC dhcpd、Kea)检测。 + #### default_domain_resolver !!! question "自 sing-box 1.12.0 起" diff --git a/docs/configuration/route/rule.md b/docs/configuration/route/rule.md index 31f768fe23..d226571096 100644 --- a/docs/configuration/route/rule.md +++ b/docs/configuration/route/rule.md @@ -2,6 +2,11 @@ icon: material/new-box --- +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [source_mac_address](#source_mac_address) + :material-plus: [source_hostname](#source_hostname) + !!! quote "Changes in sing-box 1.13.0" :material-plus: [interface_address](#interface_address) @@ -159,6 +164,12 @@ icon: material/new-box "tailscale", "wireguard" ], + "source_mac_address": [ + "00:11:22:33:44:55" + ], + "source_hostname": [ + "my-device" + ], "rule_set": [ "geoip-cn", "geosite-cn" @@ -449,6 +460,26 @@ Match specified outbounds' preferred routes. | `tailscale` | Match MagicDNS domains and peers' allowed IPs | | `wireguard` | Match peers's allowed IPs | +#### source_mac_address + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `route.find_neighbor` enabled. + +Match source device MAC address. + +#### source_hostname + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `route.find_neighbor` enabled. + +Match source device hostname from DHCP leases. + #### rule_set !!! question "Since sing-box 1.8.0" diff --git a/docs/configuration/route/rule.zh.md b/docs/configuration/route/rule.zh.md index 1ffe57d688..597e655f6e 100644 --- a/docs/configuration/route/rule.zh.md +++ b/docs/configuration/route/rule.zh.md @@ -2,6 +2,11 @@ icon: material/new-box --- +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [source_mac_address](#source_mac_address) + :material-plus: [source_hostname](#source_hostname) + !!! quote "sing-box 1.13.0 中的更改" :material-plus: [interface_address](#interface_address) @@ -156,6 +161,12 @@ icon: material/new-box "tailscale", "wireguard" ], + "source_mac_address": [ + "00:11:22:33:44:55" + ], + "source_hostname": [ + "my-device" + ], "rule_set": [ "geoip-cn", "geosite-cn" @@ -446,6 +457,26 @@ icon: material/new-box | `tailscale` | 匹配 MagicDNS 域名和对端的 allowed IPs | | `wireguard` | 匹配对端的 allowed IPs | +#### source_mac_address + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + +匹配源设备 MAC 地址。 + +#### source_hostname + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + +匹配源设备从 DHCP 租约获取的主机名。 + #### rule_set !!! question "自 sing-box 1.8.0 起" diff --git a/go.mod b/go.mod index a394028de0..32aa7061d7 100644 --- a/go.mod +++ b/go.mod @@ -14,11 +14,13 @@ require ( github.com/godbus/dbus/v5 v5.2.2 github.com/gofrs/uuid/v5 v5.4.0 github.com/insomniacslk/dhcp v0.0.0-20260220084031-5adc3eb26f91 + github.com/jsimonetti/rtnetlink v1.4.0 github.com/keybase/go-keychain v0.0.1 github.com/libdns/acmedns v0.5.0 github.com/libdns/alidns v1.0.6 github.com/libdns/cloudflare v0.2.2 github.com/logrusorgru/aurora v2.0.3+incompatible + github.com/mdlayher/netlink v1.9.0 github.com/metacubex/utls v1.8.4 github.com/mholt/acmez/v3 v3.1.6 github.com/miekg/dns v1.1.72 @@ -39,7 +41,7 @@ require ( github.com/sagernet/sing-shadowsocks v0.2.8 github.com/sagernet/sing-shadowsocks2 v0.2.1 github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 - github.com/sagernet/sing-tun v0.8.3 + github.com/sagernet/sing-tun v0.8.4-0.20260315091454-bbe21100c226 github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1 github.com/sagernet/smux v1.5.50-sing-box-mod.1 github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.6.0.20260311131347-f88b27eeb76e @@ -92,11 +94,9 @@ require ( github.com/hashicorp/yamux v0.1.2 // indirect github.com/hdevalence/ed25519consensus v0.2.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/jsimonetti/rtnetlink v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/libdns/libdns v1.1.1 // indirect - github.com/mdlayher/netlink v1.9.0 // indirect github.com/mdlayher/socket v0.5.1 // indirect github.com/mitchellh/go-ps v1.0.0 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect diff --git a/go.sum b/go.sum index 76a680f872..dea19790c3 100644 --- a/go.sum +++ b/go.sum @@ -248,8 +248,8 @@ github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnq github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA= -github.com/sagernet/sing-tun v0.8.3 h1:mozxmuIoRhFdVHnheenLpBaammVj7bZPcnkApaYKDPY= -github.com/sagernet/sing-tun v0.8.3/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs= +github.com/sagernet/sing-tun v0.8.4-0.20260315091454-bbe21100c226 h1:Shy/fsm+pqVq6OkBAWPaOmOiPT/AwoRxQLiV1357Y0Y= +github.com/sagernet/sing-tun v0.8.4-0.20260315091454-bbe21100c226/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs= github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1 h1:aSwUNYUkVyVvdmBSufR8/nRFonwJeKSIROxHcm5br9o= github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1/go.mod h1:P11scgTxMxVVQ8dlM27yNm3Cro40mD0+gHbnqrNGDuY= github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1hzcbp6kSkkyQ478= diff --git a/option/route.go b/option/route.go index f4b6539156..0c3e576d13 100644 --- a/option/route.go +++ b/option/route.go @@ -9,6 +9,8 @@ type RouteOptions struct { RuleSet []RuleSet `json:"rule_set,omitempty"` Final string `json:"final,omitempty"` FindProcess bool `json:"find_process,omitempty"` + FindNeighbor bool `json:"find_neighbor,omitempty"` + DHCPLeaseFiles badoption.Listable[string] `json:"dhcp_lease_files,omitempty"` AutoDetectInterface bool `json:"auto_detect_interface,omitempty"` OverrideAndroidVPN bool `json:"override_android_vpn,omitempty"` DefaultInterface string `json:"default_interface,omitempty"` diff --git a/option/rule.go b/option/rule.go index 3e7fd8771b..b792ccf4b2 100644 --- a/option/rule.go +++ b/option/rule.go @@ -103,6 +103,8 @@ type RawDefaultRule struct { InterfaceAddress *badjson.TypedMap[string, badoption.Listable[*badoption.Prefixable]] `json:"interface_address,omitempty"` NetworkInterfaceAddress *badjson.TypedMap[InterfaceType, badoption.Listable[*badoption.Prefixable]] `json:"network_interface_address,omitempty"` DefaultInterfaceAddress badoption.Listable[*badoption.Prefixable] `json:"default_interface_address,omitempty"` + SourceMACAddress badoption.Listable[string] `json:"source_mac_address,omitempty"` + SourceHostname badoption.Listable[string] `json:"source_hostname,omitempty"` PreferredBy badoption.Listable[string] `json:"preferred_by,omitempty"` RuleSet badoption.Listable[string] `json:"rule_set,omitempty"` RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"` diff --git a/option/rule_dns.go b/option/rule_dns.go index dbc1657898..880b96ac54 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -106,6 +106,8 @@ type RawDefaultDNSRule struct { InterfaceAddress *badjson.TypedMap[string, badoption.Listable[*badoption.Prefixable]] `json:"interface_address,omitempty"` NetworkInterfaceAddress *badjson.TypedMap[InterfaceType, badoption.Listable[*badoption.Prefixable]] `json:"network_interface_address,omitempty"` DefaultInterfaceAddress badoption.Listable[*badoption.Prefixable] `json:"default_interface_address,omitempty"` + SourceMACAddress badoption.Listable[string] `json:"source_mac_address,omitempty"` + SourceHostname badoption.Listable[string] `json:"source_hostname,omitempty"` RuleSet badoption.Listable[string] `json:"rule_set,omitempty"` RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"` RuleSetIPCIDRAcceptEmpty bool `json:"rule_set_ip_cidr_accept_empty,omitempty"` diff --git a/option/tun.go b/option/tun.go index 72b6e456ba..fda028b69e 100644 --- a/option/tun.go +++ b/option/tun.go @@ -39,6 +39,8 @@ type TunInboundOptions struct { IncludeAndroidUser badoption.Listable[int] `json:"include_android_user,omitempty"` IncludePackage badoption.Listable[string] `json:"include_package,omitempty"` ExcludePackage badoption.Listable[string] `json:"exclude_package,omitempty"` + IncludeMACAddress badoption.Listable[string] `json:"include_mac_address,omitempty"` + ExcludeMACAddress badoption.Listable[string] `json:"exclude_mac_address,omitempty"` UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"` Stack string `json:"stack,omitempty"` Platform *TunPlatformOptions `json:"platform,omitempty"` diff --git a/protocol/tun/inbound.go b/protocol/tun/inbound.go index df9344b817..6f10849321 100644 --- a/protocol/tun/inbound.go +++ b/protocol/tun/inbound.go @@ -156,6 +156,22 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo if nfQueue == 0 { nfQueue = tun.DefaultAutoRedirectNFQueue } + var includeMACAddress []net.HardwareAddr + for i, macString := range options.IncludeMACAddress { + mac, macErr := net.ParseMAC(macString) + if macErr != nil { + return nil, E.Cause(macErr, "parse include_mac_address[", i, "]") + } + includeMACAddress = append(includeMACAddress, mac) + } + var excludeMACAddress []net.HardwareAddr + for i, macString := range options.ExcludeMACAddress { + mac, macErr := net.ParseMAC(macString) + if macErr != nil { + return nil, E.Cause(macErr, "parse exclude_mac_address[", i, "]") + } + excludeMACAddress = append(excludeMACAddress, mac) + } networkManager := service.FromContext[adapter.NetworkManager](ctx) multiPendingPackets := C.IsDarwin && ((options.Stack == "gvisor" && tunMTU < 32768) || (options.Stack != "gvisor" && options.MTU <= 9000)) inbound := &Inbound{ @@ -193,6 +209,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo IncludeAndroidUser: options.IncludeAndroidUser, IncludePackage: options.IncludePackage, ExcludePackage: options.ExcludePackage, + IncludeMACAddress: includeMACAddress, + ExcludeMACAddress: excludeMACAddress, InterfaceMonitor: networkManager.InterfaceMonitor(), EXP_MultiPendingPackets: multiPendingPackets, }, diff --git a/route/neighbor_resolver_linux.go b/route/neighbor_resolver_linux.go new file mode 100644 index 0000000000..40db5766ad --- /dev/null +++ b/route/neighbor_resolver_linux.go @@ -0,0 +1,596 @@ +//go:build linux + +package route + +import ( + "bufio" + "encoding/binary" + "encoding/hex" + "net" + "net/netip" + "os" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/sagernet/fswatch" + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + + "github.com/jsimonetti/rtnetlink" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +var defaultLeaseFiles = []string{ + "/tmp/dhcp.leases", + "/var/lib/dhcp/dhcpd.leases", + "/var/lib/dhcpd/dhcpd.leases", + "/var/lib/kea/kea-leases4.csv", + "/var/lib/kea/kea-leases6.csv", +} + +type neighborResolver struct { + logger logger.ContextLogger + leaseFiles []string + access sync.RWMutex + neighborIPToMAC map[netip.Addr]net.HardwareAddr + leaseIPToMAC map[netip.Addr]net.HardwareAddr + ipToHostname map[netip.Addr]string + macToHostname map[string]string + watcher *fswatch.Watcher + done chan struct{} +} + +func newNeighborResolver(resolverLogger logger.ContextLogger, leaseFiles []string) (adapter.NeighborResolver, error) { + if len(leaseFiles) == 0 { + for _, path := range defaultLeaseFiles { + info, err := os.Stat(path) + if err == nil && info.Size() > 0 { + leaseFiles = append(leaseFiles, path) + } + } + } + return &neighborResolver{ + logger: resolverLogger, + leaseFiles: leaseFiles, + neighborIPToMAC: make(map[netip.Addr]net.HardwareAddr), + leaseIPToMAC: make(map[netip.Addr]net.HardwareAddr), + ipToHostname: make(map[netip.Addr]string), + macToHostname: make(map[string]string), + done: make(chan struct{}), + }, nil +} + +func (r *neighborResolver) Start() error { + err := r.loadNeighborTable() + if err != nil { + r.logger.Warn(E.Cause(err, "load neighbor table")) + } + r.reloadLeaseFiles() + go r.subscribeNeighborUpdates() + if len(r.leaseFiles) > 0 { + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: r.leaseFiles, + Logger: r.logger, + Callback: func(_ string) { + r.reloadLeaseFiles() + }, + }) + if err != nil { + r.logger.Warn(E.Cause(err, "create lease file watcher")) + } else { + r.watcher = watcher + err = watcher.Start() + if err != nil { + r.logger.Warn(E.Cause(err, "start lease file watcher")) + } + } + } + return nil +} + +func (r *neighborResolver) Close() error { + close(r.done) + if r.watcher != nil { + return r.watcher.Close() + } + return nil +} + +func (r *neighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) { + r.access.RLock() + defer r.access.RUnlock() + mac, found := r.neighborIPToMAC[address] + if found { + return mac, true + } + mac, found = r.leaseIPToMAC[address] + if found { + return mac, true + } + mac, found = extractMACFromEUI64(address) + if found { + return mac, true + } + return nil, false +} + +func (r *neighborResolver) LookupHostname(address netip.Addr) (string, bool) { + r.access.RLock() + defer r.access.RUnlock() + hostname, found := r.ipToHostname[address] + if found { + return hostname, true + } + mac, macFound := r.neighborIPToMAC[address] + if !macFound { + mac, macFound = r.leaseIPToMAC[address] + } + if !macFound { + mac, macFound = extractMACFromEUI64(address) + } + if macFound { + hostname, found = r.macToHostname[mac.String()] + if found { + return hostname, true + } + } + return "", false +} + +func (r *neighborResolver) loadNeighborTable() error { + connection, err := rtnetlink.Dial(nil) + if err != nil { + return E.Cause(err, "dial rtnetlink") + } + defer connection.Close() + neighbors, err := connection.Neigh.List() + if err != nil { + return E.Cause(err, "list neighbors") + } + r.access.Lock() + defer r.access.Unlock() + for _, neigh := range neighbors { + if neigh.Attributes == nil { + continue + } + if neigh.Attributes.LLAddress == nil || len(neigh.Attributes.Address) == 0 { + continue + } + address, ok := netip.AddrFromSlice(neigh.Attributes.Address) + if !ok { + continue + } + r.neighborIPToMAC[address] = slices.Clone(neigh.Attributes.LLAddress) + } + return nil +} + +func (r *neighborResolver) subscribeNeighborUpdates() { + connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ + Groups: 1 << (unix.RTNLGRP_NEIGH - 1), + }) + if err != nil { + r.logger.Warn(E.Cause(err, "subscribe neighbor updates")) + return + } + defer connection.Close() + for { + select { + case <-r.done: + return + default: + } + err = connection.SetReadDeadline(time.Now().Add(3 * time.Second)) + if err != nil { + r.logger.Warn(E.Cause(err, "set netlink read deadline")) + return + } + messages, err := connection.Receive() + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + select { + case <-r.done: + return + default: + } + r.logger.Warn(E.Cause(err, "receive neighbor update")) + continue + } + for _, message := range messages { + switch message.Header.Type { + case unix.RTM_NEWNEIGH: + var neighMessage rtnetlink.NeighMessage + unmarshalErr := neighMessage.UnmarshalBinary(message.Data) + if unmarshalErr != nil { + continue + } + if neighMessage.Attributes == nil { + continue + } + if neighMessage.Attributes.LLAddress == nil || len(neighMessage.Attributes.Address) == 0 { + continue + } + address, ok := netip.AddrFromSlice(neighMessage.Attributes.Address) + if !ok { + continue + } + r.access.Lock() + r.neighborIPToMAC[address] = slices.Clone(neighMessage.Attributes.LLAddress) + r.access.Unlock() + case unix.RTM_DELNEIGH: + var neighMessage rtnetlink.NeighMessage + unmarshalErr := neighMessage.UnmarshalBinary(message.Data) + if unmarshalErr != nil { + continue + } + if neighMessage.Attributes == nil || len(neighMessage.Attributes.Address) == 0 { + continue + } + address, ok := netip.AddrFromSlice(neighMessage.Attributes.Address) + if !ok { + continue + } + r.access.Lock() + delete(r.neighborIPToMAC, address) + r.access.Unlock() + } + } + } +} + +func (r *neighborResolver) reloadLeaseFiles() { + leaseIPToMAC := make(map[netip.Addr]net.HardwareAddr) + ipToHostname := make(map[netip.Addr]string) + macToHostname := make(map[string]string) + for _, path := range r.leaseFiles { + r.parseLeaseFile(path, leaseIPToMAC, ipToHostname, macToHostname) + } + r.access.Lock() + r.leaseIPToMAC = leaseIPToMAC + r.ipToHostname = ipToHostname + r.macToHostname = macToHostname + r.access.Unlock() +} + +func (r *neighborResolver) parseLeaseFile(path string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + file, err := os.Open(path) + if err != nil { + return + } + defer file.Close() + if strings.HasSuffix(path, "kea-leases4.csv") { + r.parseKeaCSV4(file, ipToMAC, ipToHostname, macToHostname) + return + } + if strings.HasSuffix(path, "kea-leases6.csv") { + r.parseKeaCSV6(file, ipToMAC, ipToHostname, macToHostname) + return + } + if strings.HasSuffix(path, "dhcpd.leases") { + r.parseISCDhcpd(file, ipToMAC, ipToHostname, macToHostname) + return + } + r.parseDnsmasqOdhcpd(file, ipToMAC, ipToHostname, macToHostname) +} + +func (r *neighborResolver) parseDnsmasqOdhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + now := time.Now().Unix() + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "duid ") { + continue + } + if strings.HasPrefix(line, "# ") { + r.parseOdhcpdLine(line[2:], ipToMAC, ipToHostname, macToHostname) + continue + } + fields := strings.Fields(line) + if len(fields) < 4 { + continue + } + expiry, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + continue + } + if expiry != 0 && expiry < now { + continue + } + if strings.Contains(fields[1], ":") { + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) + if !addrOK { + continue + } + address = address.Unmap() + ipToMAC[address] = mac + hostname := fields[3] + if hostname != "*" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + } else { + var mac net.HardwareAddr + if len(fields) >= 5 { + duid, duidErr := parseDUID(fields[4]) + if duidErr == nil { + mac, _ = extractMACFromDUID(duid) + } + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) + if !addrOK { + continue + } + address = address.Unmap() + if mac != nil { + ipToMAC[address] = mac + } + hostname := fields[3] + if hostname != "*" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } + } +} + +func (r *neighborResolver) parseOdhcpdLine(line string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + fields := strings.Fields(line) + if len(fields) < 5 { + return + } + validTime, err := strconv.ParseInt(fields[4], 10, 64) + if err != nil { + return + } + if validTime == 0 { + return + } + if validTime > 0 && validTime < time.Now().Unix() { + return + } + hostname := fields[3] + if hostname == "-" || strings.HasPrefix(hostname, `broken\x20`) { + hostname = "" + } + if len(fields) >= 8 && fields[2] == "ipv4" { + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + return + } + addressField := fields[7] + slashIndex := strings.IndexByte(addressField, '/') + if slashIndex >= 0 { + addressField = addressField[:slashIndex] + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) + if !addrOK { + return + } + address = address.Unmap() + ipToMAC[address] = mac + if hostname != "" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + return + } + var mac net.HardwareAddr + duidHex := fields[1] + duidBytes, hexErr := hex.DecodeString(duidHex) + if hexErr == nil { + mac, _ = extractMACFromDUID(duidBytes) + } + for i := 7; i < len(fields); i++ { + addressField := fields[i] + slashIndex := strings.IndexByte(addressField, '/') + if slashIndex >= 0 { + addressField = addressField[:slashIndex] + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) + if !addrOK { + continue + } + address = address.Unmap() + if mac != nil { + ipToMAC[address] = mac + } + if hostname != "" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } +} + +func (r *neighborResolver) parseISCDhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + var currentIP netip.Addr + var currentMAC net.HardwareAddr + var currentHostname string + var currentActive bool + var inLease bool + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "lease ") && strings.HasSuffix(line, "{") { + ipString := strings.TrimSuffix(strings.TrimPrefix(line, "lease "), " {") + parsed, addrOK := netip.AddrFromSlice(net.ParseIP(ipString)) + if addrOK { + currentIP = parsed.Unmap() + inLease = true + currentMAC = nil + currentHostname = "" + currentActive = false + } + continue + } + if line == "}" && inLease { + if currentActive && currentMAC != nil { + ipToMAC[currentIP] = currentMAC + if currentHostname != "" { + ipToHostname[currentIP] = currentHostname + macToHostname[currentMAC.String()] = currentHostname + } + } else { + delete(ipToMAC, currentIP) + delete(ipToHostname, currentIP) + } + inLease = false + continue + } + if !inLease { + continue + } + if strings.HasPrefix(line, "hardware ethernet ") { + macString := strings.TrimSuffix(strings.TrimPrefix(line, "hardware ethernet "), ";") + parsed, macErr := net.ParseMAC(macString) + if macErr == nil { + currentMAC = parsed + } + } else if strings.HasPrefix(line, "client-hostname ") { + hostname := strings.TrimSuffix(strings.TrimPrefix(line, "client-hostname "), ";") + hostname = strings.Trim(hostname, "\"") + if hostname != "" { + currentHostname = hostname + } + } else if strings.HasPrefix(line, "binding state ") { + state := strings.TrimSuffix(strings.TrimPrefix(line, "binding state "), ";") + currentActive = state == "active" + } + } +} + +func (r *neighborResolver) parseKeaCSV4(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + firstLine := true + for scanner.Scan() { + if firstLine { + firstLine = false + continue + } + fields := strings.Split(scanner.Text(), ",") + if len(fields) < 10 { + continue + } + if fields[9] != "0" { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) + if !addrOK { + continue + } + address = address.Unmap() + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + continue + } + ipToMAC[address] = mac + hostname := "" + if len(fields) > 8 { + hostname = fields[8] + } + if hostname != "" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + } +} + +func (r *neighborResolver) parseKeaCSV6(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + firstLine := true + for scanner.Scan() { + if firstLine { + firstLine = false + continue + } + fields := strings.Split(scanner.Text(), ",") + if len(fields) < 14 { + continue + } + if fields[13] != "0" { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) + if !addrOK { + continue + } + address = address.Unmap() + var mac net.HardwareAddr + if fields[12] != "" { + mac, _ = net.ParseMAC(fields[12]) + } + if mac == nil { + duid, duidErr := hex.DecodeString(strings.ReplaceAll(fields[1], ":", "")) + if duidErr == nil { + mac, _ = extractMACFromDUID(duid) + } + } + hostname := "" + if len(fields) > 11 { + hostname = fields[11] + } + if mac != nil { + ipToMAC[address] = mac + } + if hostname != "" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } +} + +func extractMACFromDUID(duid []byte) (net.HardwareAddr, bool) { + if len(duid) < 4 { + return nil, false + } + duidType := binary.BigEndian.Uint16(duid[0:2]) + hwType := binary.BigEndian.Uint16(duid[2:4]) + if hwType != 1 { + return nil, false + } + switch duidType { + case 1: + if len(duid) < 14 { + return nil, false + } + return net.HardwareAddr(slices.Clone(duid[8:14])), true + case 3: + if len(duid) < 10 { + return nil, false + } + return net.HardwareAddr(slices.Clone(duid[4:10])), true + } + return nil, false +} + +func extractMACFromEUI64(address netip.Addr) (net.HardwareAddr, bool) { + if !address.Is6() { + return nil, false + } + b := address.As16() + if b[11] != 0xff || b[12] != 0xfe { + return nil, false + } + return net.HardwareAddr{b[8] ^ 0x02, b[9], b[10], b[13], b[14], b[15]}, true +} + +func parseDUID(s string) ([]byte, error) { + cleaned := strings.ReplaceAll(s, ":", "") + return hex.DecodeString(cleaned) +} diff --git a/route/neighbor_resolver_stub.go b/route/neighbor_resolver_stub.go new file mode 100644 index 0000000000..9288892a8d --- /dev/null +++ b/route/neighbor_resolver_stub.go @@ -0,0 +1,14 @@ +//go:build !linux + +package route + +import ( + "os" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/logger" +) + +func newNeighborResolver(_ logger.ContextLogger, _ []string) (adapter.NeighborResolver, error) { + return nil, os.ErrInvalid +} diff --git a/route/route.go b/route/route.go index cdd7ba2509..324b76829a 100644 --- a/route/route.go +++ b/route/route.go @@ -439,6 +439,23 @@ func (r *Router) matchRule( metadata.ProcessInfo = processInfo } } + if r.neighborResolver != nil && metadata.SourceMACAddress == nil && metadata.Source.Addr.IsValid() { + mac, macFound := r.neighborResolver.LookupMAC(metadata.Source.Addr) + if macFound { + metadata.SourceMACAddress = mac + } + hostname, hostnameFound := r.neighborResolver.LookupHostname(metadata.Source.Addr) + if hostnameFound { + metadata.SourceHostname = hostname + if macFound { + r.logger.InfoContext(ctx, "found neighbor: ", mac, ", hostname: ", hostname) + } else { + r.logger.InfoContext(ctx, "found neighbor hostname: ", hostname) + } + } else if macFound { + r.logger.InfoContext(ctx, "found neighbor: ", mac) + } + } if metadata.Destination.Addr.IsValid() && r.dnsTransport.FakeIP() != nil && r.dnsTransport.FakeIP().Store().Contains(metadata.Destination.Addr) { domain, loaded := r.dnsTransport.FakeIP().Store().Lookup(metadata.Destination.Addr) if !loaded { diff --git a/route/router.go b/route/router.go index 5c73cb1c9f..abc7ffa313 100644 --- a/route/router.go +++ b/route/router.go @@ -31,9 +31,12 @@ type Router struct { network adapter.NetworkManager rules []adapter.Rule needFindProcess bool + needFindNeighbor bool + leaseFiles []string ruleSets []adapter.RuleSet ruleSetMap map[string]adapter.RuleSet processSearcher process.Searcher + neighborResolver adapter.NeighborResolver pauseManager pause.Manager trackers []adapter.ConnectionTracker platformInterface adapter.PlatformInterface @@ -53,6 +56,8 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route rules: make([]adapter.Rule, 0, len(options.Rules)), ruleSetMap: make(map[string]adapter.RuleSet), needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess, + needFindNeighbor: hasRule(options.Rules, isNeighborRule) || hasDNSRule(dnsOptions.Rules, isNeighborDNSRule) || options.FindNeighbor, + leaseFiles: options.DHCPLeaseFiles, pauseManager: service.FromContext[pause.Manager](ctx), platformInterface: service.FromContext[adapter.PlatformInterface](ctx), } @@ -112,6 +117,7 @@ func (r *Router) Start(stage adapter.StartStage) error { } r.network.Initialize(r.ruleSets) needFindProcess := r.needFindProcess + needFindNeighbor := r.needFindNeighbor for _, ruleSet := range r.ruleSets { metadata := ruleSet.Metadata() if metadata.ContainsProcessRule { @@ -141,6 +147,24 @@ func (r *Router) Start(stage adapter.StartStage) error { } } } + r.needFindNeighbor = needFindNeighbor + if needFindNeighbor { + monitor.Start("initialize neighbor resolver") + resolver, err := newNeighborResolver(r.logger, r.leaseFiles) + monitor.Finish() + if err != nil { + if err != os.ErrInvalid { + r.logger.Warn(E.Cause(err, "create neighbor resolver")) + } + } else { + err = resolver.Start() + if err != nil { + r.logger.Warn(E.Cause(err, "start neighbor resolver")) + } else { + r.neighborResolver = resolver + } + } + } case adapter.StartStatePostStart: for i, rule := range r.rules { monitor.Start("initialize rule[", i, "]") @@ -172,6 +196,13 @@ func (r *Router) Start(stage adapter.StartStage) error { func (r *Router) Close() error { monitor := taskmonitor.New(r.logger, C.StopTimeout) var err error + if r.neighborResolver != nil { + monitor.Start("close neighbor resolver") + err = E.Append(err, r.neighborResolver.Close(), func(closeErr error) error { + return E.Cause(closeErr, "close neighbor resolver") + }) + monitor.Finish() + } for i, rule := range r.rules { monitor.Start("close rule[", i, "]") err = E.Append(err, rule.Close(), func(err error) error { @@ -206,6 +237,14 @@ func (r *Router) NeedFindProcess() bool { return r.needFindProcess } +func (r *Router) NeedFindNeighbor() bool { + return r.needFindNeighbor +} + +func (r *Router) NeighborResolver() adapter.NeighborResolver { + return r.neighborResolver +} + func (r *Router) ResetNetwork() { r.network.ResetNetwork() r.dns.ResetNetwork() diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index 202fb3b36d..7ffdd521cb 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -260,6 +260,16 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.SourceMACAddress) > 0 { + item := NewSourceMACAddressItem(options.SourceMACAddress) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourceHostname) > 0 { + item := NewSourceHostnameItem(options.SourceHostname) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.PreferredBy) > 0 { item := NewPreferredByItem(ctx, options.PreferredBy) rule.items = append(rule.items, item) diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index 9235dd6fd9..957df8747d 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -261,6 +261,16 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.SourceMACAddress) > 0 { + item := NewSourceMACAddressItem(options.SourceMACAddress) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourceHostname) > 0 { + item := NewSourceHostnameItem(options.SourceHostname) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.RuleSet) > 0 { //nolint:staticcheck if options.Deprecated_RulesetIPCIDRMatchSource { diff --git a/route/rule/rule_item_source_hostname.go b/route/rule/rule_item_source_hostname.go new file mode 100644 index 0000000000..0df11c8c8a --- /dev/null +++ b/route/rule/rule_item_source_hostname.go @@ -0,0 +1,42 @@ +package rule + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" +) + +var _ RuleItem = (*SourceHostnameItem)(nil) + +type SourceHostnameItem struct { + hostnames []string + hostnameMap map[string]bool +} + +func NewSourceHostnameItem(hostnameList []string) *SourceHostnameItem { + rule := &SourceHostnameItem{ + hostnames: hostnameList, + hostnameMap: make(map[string]bool), + } + for _, hostname := range hostnameList { + rule.hostnameMap[hostname] = true + } + return rule +} + +func (r *SourceHostnameItem) Match(metadata *adapter.InboundContext) bool { + if metadata.SourceHostname == "" { + return false + } + return r.hostnameMap[metadata.SourceHostname] +} + +func (r *SourceHostnameItem) String() string { + var description string + if len(r.hostnames) == 1 { + description = "source_hostname=" + r.hostnames[0] + } else { + description = "source_hostname=[" + strings.Join(r.hostnames, " ") + "]" + } + return description +} diff --git a/route/rule/rule_item_source_mac_address.go b/route/rule/rule_item_source_mac_address.go new file mode 100644 index 0000000000..feeadb1dbf --- /dev/null +++ b/route/rule/rule_item_source_mac_address.go @@ -0,0 +1,48 @@ +package rule + +import ( + "net" + "strings" + + "github.com/sagernet/sing-box/adapter" +) + +var _ RuleItem = (*SourceMACAddressItem)(nil) + +type SourceMACAddressItem struct { + addresses []string + addressMap map[string]bool +} + +func NewSourceMACAddressItem(addressList []string) *SourceMACAddressItem { + rule := &SourceMACAddressItem{ + addresses: addressList, + addressMap: make(map[string]bool), + } + for _, address := range addressList { + parsed, err := net.ParseMAC(address) + if err == nil { + rule.addressMap[parsed.String()] = true + } else { + rule.addressMap[address] = true + } + } + return rule +} + +func (r *SourceMACAddressItem) Match(metadata *adapter.InboundContext) bool { + if metadata.SourceMACAddress == nil { + return false + } + return r.addressMap[metadata.SourceMACAddress.String()] +} + +func (r *SourceMACAddressItem) String() string { + var description string + if len(r.addresses) == 1 { + description = "source_mac_address=" + r.addresses[0] + } else { + description = "source_mac_address=[" + strings.Join(r.addresses, " ") + "]" + } + return description +} diff --git a/route/rule_conds.go b/route/rule_conds.go index 55c4a058e2..22ce94fffd 100644 --- a/route/rule_conds.go +++ b/route/rule_conds.go @@ -45,6 +45,14 @@ func isProcessDNSRule(rule option.DefaultDNSRule) bool { return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.ProcessPathRegex) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 } +func isNeighborRule(rule option.DefaultRule) bool { + return len(rule.SourceMACAddress) > 0 || len(rule.SourceHostname) > 0 +} + +func isNeighborDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.SourceMACAddress) > 0 || len(rule.SourceHostname) > 0 +} + func isWIFIRule(rule option.DefaultRule) bool { return len(rule.WIFISSID) > 0 || len(rule.WIFIBSSID) > 0 } From f802668915df8503246d2b48cc1a87fd7b3b056d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 5 Mar 2026 00:15:37 +0800 Subject: [PATCH 02/96] Add Android support for MAC and hostname rule items --- adapter/neighbor.go | 10 ++ adapter/platform.go | 4 + experimental/libbox/config.go | 12 +++ experimental/libbox/neighbor.go | 135 +++++++++++++++++++++++++++ experimental/libbox/neighbor_stub.go | 24 +++++ experimental/libbox/platform.go | 6 ++ experimental/libbox/service.go | 37 ++++++++ route/neighbor_resolver_linux.go | 85 ++--------------- route/neighbor_resolver_parse.go | 50 ++++++++++ route/neighbor_resolver_platform.go | 84 +++++++++++++++++ route/neighbor_table_linux.go | 68 ++++++++++++++ route/router.go | 33 +++++-- 12 files changed, 462 insertions(+), 86 deletions(-) create mode 100644 experimental/libbox/neighbor.go create mode 100644 experimental/libbox/neighbor_stub.go create mode 100644 route/neighbor_resolver_parse.go create mode 100644 route/neighbor_resolver_platform.go create mode 100644 route/neighbor_table_linux.go diff --git a/adapter/neighbor.go b/adapter/neighbor.go index 920398f674..d917db5b7a 100644 --- a/adapter/neighbor.go +++ b/adapter/neighbor.go @@ -5,9 +5,19 @@ import ( "net/netip" ) +type NeighborEntry struct { + Address netip.Addr + MACAddress net.HardwareAddr + Hostname string +} + type NeighborResolver interface { LookupMAC(address netip.Addr) (net.HardwareAddr, bool) LookupHostname(address netip.Addr) (string, bool) Start() error Close() error } + +type NeighborUpdateListener interface { + UpdateNeighborTable(entries []NeighborEntry) +} diff --git a/adapter/platform.go b/adapter/platform.go index 95db93c646..12ab82a219 100644 --- a/adapter/platform.go +++ b/adapter/platform.go @@ -36,6 +36,10 @@ type PlatformInterface interface { UsePlatformNotification() bool SendNotification(notification *Notification) error + + UsePlatformNeighborResolver() bool + StartNeighborMonitor(listener NeighborUpdateListener) error + CloseNeighborMonitor(listener NeighborUpdateListener) error } type FindConnectionOwnerRequest struct { diff --git a/experimental/libbox/config.go b/experimental/libbox/config.go index 122425d293..54369bf770 100644 --- a/experimental/libbox/config.go +++ b/experimental/libbox/config.go @@ -144,6 +144,18 @@ func (s *platformInterfaceStub) SendNotification(notification *adapter.Notificat return nil } +func (s *platformInterfaceStub) UsePlatformNeighborResolver() bool { + return false +} + +func (s *platformInterfaceStub) StartNeighborMonitor(listener adapter.NeighborUpdateListener) error { + return os.ErrInvalid +} + +func (s *platformInterfaceStub) CloseNeighborMonitor(listener adapter.NeighborUpdateListener) error { + return nil +} + func (s *platformInterfaceStub) UsePlatformLocalDNSTransport() bool { return false } diff --git a/experimental/libbox/neighbor.go b/experimental/libbox/neighbor.go new file mode 100644 index 0000000000..b2ded5f7a1 --- /dev/null +++ b/experimental/libbox/neighbor.go @@ -0,0 +1,135 @@ +//go:build linux + +package libbox + +import ( + "net" + "net/netip" + "slices" + "time" + + "github.com/sagernet/sing-box/route" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type NeighborEntry struct { + Address string + MACAddress string + Hostname string +} + +type NeighborEntryIterator interface { + Next() *NeighborEntry + HasNext() bool +} + +type NeighborSubscription struct { + done chan struct{} +} + +func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { + entries, err := route.ReadNeighborEntries() + if err != nil { + return nil, E.Cause(err, "initial neighbor dump") + } + table := make(map[netip.Addr]net.HardwareAddr) + for _, entry := range entries { + table[entry.Address] = entry.MACAddress + } + listener.UpdateNeighborTable(tableToIterator(table)) + connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ + Groups: 1 << (unix.RTNLGRP_NEIGH - 1), + }) + if err != nil { + return nil, E.Cause(err, "subscribe neighbor updates") + } + subscription := &NeighborSubscription{ + done: make(chan struct{}), + } + go subscription.loop(listener, connection, table) + return subscription, nil +} + +func (s *NeighborSubscription) Close() { + close(s.done) +} + +func (s *NeighborSubscription) loop(listener NeighborUpdateListener, connection *netlink.Conn, table map[netip.Addr]net.HardwareAddr) { + defer connection.Close() + for { + select { + case <-s.done: + return + default: + } + err := connection.SetReadDeadline(time.Now().Add(3 * time.Second)) + if err != nil { + return + } + messages, err := connection.Receive() + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + select { + case <-s.done: + return + default: + } + continue + } + changed := false + for _, message := range messages { + address, mac, isDelete, ok := route.ParseNeighborMessage(message) + if !ok { + continue + } + if isDelete { + if _, exists := table[address]; exists { + delete(table, address) + changed = true + } + } else { + existing, exists := table[address] + if !exists || !slices.Equal(existing, mac) { + table[address] = mac + changed = true + } + } + } + if changed { + listener.UpdateNeighborTable(tableToIterator(table)) + } + } +} + +func tableToIterator(table map[netip.Addr]net.HardwareAddr) NeighborEntryIterator { + entries := make([]*NeighborEntry, 0, len(table)) + for address, mac := range table { + entries = append(entries, &NeighborEntry{ + Address: address.String(), + MACAddress: mac.String(), + }) + } + return &neighborEntryIterator{entries} +} + +type neighborEntryIterator struct { + entries []*NeighborEntry +} + +func (i *neighborEntryIterator) HasNext() bool { + return len(i.entries) > 0 +} + +func (i *neighborEntryIterator) Next() *NeighborEntry { + if len(i.entries) == 0 { + return nil + } + entry := i.entries[0] + i.entries = i.entries[1:] + return entry +} diff --git a/experimental/libbox/neighbor_stub.go b/experimental/libbox/neighbor_stub.go new file mode 100644 index 0000000000..95f6dc7d6f --- /dev/null +++ b/experimental/libbox/neighbor_stub.go @@ -0,0 +1,24 @@ +//go:build !linux + +package libbox + +import "os" + +type NeighborEntry struct { + Address string + MACAddress string + Hostname string +} + +type NeighborEntryIterator interface { + Next() *NeighborEntry + HasNext() bool +} + +type NeighborSubscription struct{} + +func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { + return nil, os.ErrInvalid +} + +func (s *NeighborSubscription) Close() {} diff --git a/experimental/libbox/platform.go b/experimental/libbox/platform.go index 63c54ccf2c..3b1b0f3204 100644 --- a/experimental/libbox/platform.go +++ b/experimental/libbox/platform.go @@ -21,6 +21,12 @@ type PlatformInterface interface { SystemCertificates() StringIterator ClearDNSCache() SendNotification(notification *Notification) error + StartNeighborMonitor(listener NeighborUpdateListener) error + CloseNeighborMonitor(listener NeighborUpdateListener) error +} + +type NeighborUpdateListener interface { + UpdateNeighborTable(entries NeighborEntryIterator) } type ConnectionOwner struct { diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index 3a13f6d169..458d0c66c5 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -220,6 +220,43 @@ func (w *platformInterfaceWrapper) SendNotification(notification *adapter.Notifi return w.iif.SendNotification((*Notification)(notification)) } +func (w *platformInterfaceWrapper) UsePlatformNeighborResolver() bool { + return true +} + +func (w *platformInterfaceWrapper) StartNeighborMonitor(listener adapter.NeighborUpdateListener) error { + return w.iif.StartNeighborMonitor(&neighborUpdateListenerWrapper{listener: listener}) +} + +func (w *platformInterfaceWrapper) CloseNeighborMonitor(listener adapter.NeighborUpdateListener) error { + return w.iif.CloseNeighborMonitor(nil) +} + +type neighborUpdateListenerWrapper struct { + listener adapter.NeighborUpdateListener +} + +func (w *neighborUpdateListenerWrapper) UpdateNeighborTable(entries NeighborEntryIterator) { + var result []adapter.NeighborEntry + for entries.HasNext() { + entry := entries.Next() + address, err := netip.ParseAddr(entry.Address) + if err != nil { + continue + } + macAddress, err := net.ParseMAC(entry.MACAddress) + if err != nil { + continue + } + result = append(result, adapter.NeighborEntry{ + Address: address, + MACAddress: macAddress, + Hostname: entry.Hostname, + }) + } + w.listener.UpdateNeighborTable(result) +} + func AvailablePort(startPort int32) (int32, error) { for port := int(startPort); ; port++ { if port > 65535 { diff --git a/route/neighbor_resolver_linux.go b/route/neighbor_resolver_linux.go index 40db5766ad..111cc6f040 100644 --- a/route/neighbor_resolver_linux.go +++ b/route/neighbor_resolver_linux.go @@ -4,7 +4,6 @@ package route import ( "bufio" - "encoding/binary" "encoding/hex" "net" "net/netip" @@ -204,43 +203,17 @@ func (r *neighborResolver) subscribeNeighborUpdates() { continue } for _, message := range messages { - switch message.Header.Type { - case unix.RTM_NEWNEIGH: - var neighMessage rtnetlink.NeighMessage - unmarshalErr := neighMessage.UnmarshalBinary(message.Data) - if unmarshalErr != nil { - continue - } - if neighMessage.Attributes == nil { - continue - } - if neighMessage.Attributes.LLAddress == nil || len(neighMessage.Attributes.Address) == 0 { - continue - } - address, ok := netip.AddrFromSlice(neighMessage.Attributes.Address) - if !ok { - continue - } - r.access.Lock() - r.neighborIPToMAC[address] = slices.Clone(neighMessage.Attributes.LLAddress) - r.access.Unlock() - case unix.RTM_DELNEIGH: - var neighMessage rtnetlink.NeighMessage - unmarshalErr := neighMessage.UnmarshalBinary(message.Data) - if unmarshalErr != nil { - continue - } - if neighMessage.Attributes == nil || len(neighMessage.Attributes.Address) == 0 { - continue - } - address, ok := netip.AddrFromSlice(neighMessage.Attributes.Address) - if !ok { - continue - } - r.access.Lock() + address, mac, isDelete, ok := ParseNeighborMessage(message) + if !ok { + continue + } + r.access.Lock() + if isDelete { delete(r.neighborIPToMAC, address) - r.access.Unlock() + } else { + r.neighborIPToMAC[address] = mac } + r.access.Unlock() } } } @@ -554,43 +527,3 @@ func (r *neighborResolver) parseKeaCSV6(file *os.File, ipToMAC map[netip.Addr]ne } } } - -func extractMACFromDUID(duid []byte) (net.HardwareAddr, bool) { - if len(duid) < 4 { - return nil, false - } - duidType := binary.BigEndian.Uint16(duid[0:2]) - hwType := binary.BigEndian.Uint16(duid[2:4]) - if hwType != 1 { - return nil, false - } - switch duidType { - case 1: - if len(duid) < 14 { - return nil, false - } - return net.HardwareAddr(slices.Clone(duid[8:14])), true - case 3: - if len(duid) < 10 { - return nil, false - } - return net.HardwareAddr(slices.Clone(duid[4:10])), true - } - return nil, false -} - -func extractMACFromEUI64(address netip.Addr) (net.HardwareAddr, bool) { - if !address.Is6() { - return nil, false - } - b := address.As16() - if b[11] != 0xff || b[12] != 0xfe { - return nil, false - } - return net.HardwareAddr{b[8] ^ 0x02, b[9], b[10], b[13], b[14], b[15]}, true -} - -func parseDUID(s string) ([]byte, error) { - cleaned := strings.ReplaceAll(s, ":", "") - return hex.DecodeString(cleaned) -} diff --git a/route/neighbor_resolver_parse.go b/route/neighbor_resolver_parse.go new file mode 100644 index 0000000000..1979b7eabc --- /dev/null +++ b/route/neighbor_resolver_parse.go @@ -0,0 +1,50 @@ +package route + +import ( + "encoding/binary" + "encoding/hex" + "net" + "net/netip" + "slices" + "strings" +) + +func extractMACFromDUID(duid []byte) (net.HardwareAddr, bool) { + if len(duid) < 4 { + return nil, false + } + duidType := binary.BigEndian.Uint16(duid[0:2]) + hwType := binary.BigEndian.Uint16(duid[2:4]) + if hwType != 1 { + return nil, false + } + switch duidType { + case 1: + if len(duid) < 14 { + return nil, false + } + return net.HardwareAddr(slices.Clone(duid[8:14])), true + case 3: + if len(duid) < 10 { + return nil, false + } + return net.HardwareAddr(slices.Clone(duid[4:10])), true + } + return nil, false +} + +func extractMACFromEUI64(address netip.Addr) (net.HardwareAddr, bool) { + if !address.Is6() { + return nil, false + } + b := address.As16() + if b[11] != 0xff || b[12] != 0xfe { + return nil, false + } + return net.HardwareAddr{b[8] ^ 0x02, b[9], b[10], b[13], b[14], b[15]}, true +} + +func parseDUID(s string) ([]byte, error) { + cleaned := strings.ReplaceAll(s, ":", "") + return hex.DecodeString(cleaned) +} diff --git a/route/neighbor_resolver_platform.go b/route/neighbor_resolver_platform.go new file mode 100644 index 0000000000..ddb9a99592 --- /dev/null +++ b/route/neighbor_resolver_platform.go @@ -0,0 +1,84 @@ +package route + +import ( + "net" + "net/netip" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/logger" +) + +type platformNeighborResolver struct { + logger logger.ContextLogger + platform adapter.PlatformInterface + access sync.RWMutex + ipToMAC map[netip.Addr]net.HardwareAddr + ipToHostname map[netip.Addr]string + macToHostname map[string]string +} + +func newPlatformNeighborResolver(resolverLogger logger.ContextLogger, platform adapter.PlatformInterface) adapter.NeighborResolver { + return &platformNeighborResolver{ + logger: resolverLogger, + platform: platform, + ipToMAC: make(map[netip.Addr]net.HardwareAddr), + ipToHostname: make(map[netip.Addr]string), + macToHostname: make(map[string]string), + } +} + +func (r *platformNeighborResolver) Start() error { + return r.platform.StartNeighborMonitor(r) +} + +func (r *platformNeighborResolver) Close() error { + return r.platform.CloseNeighborMonitor(r) +} + +func (r *platformNeighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) { + r.access.RLock() + defer r.access.RUnlock() + mac, found := r.ipToMAC[address] + if found { + return mac, true + } + return extractMACFromEUI64(address) +} + +func (r *platformNeighborResolver) LookupHostname(address netip.Addr) (string, bool) { + r.access.RLock() + defer r.access.RUnlock() + hostname, found := r.ipToHostname[address] + if found { + return hostname, true + } + mac, found := r.ipToMAC[address] + if !found { + mac, found = extractMACFromEUI64(address) + } + if !found { + return "", false + } + hostname, found = r.macToHostname[mac.String()] + return hostname, found +} + +func (r *platformNeighborResolver) UpdateNeighborTable(entries []adapter.NeighborEntry) { + ipToMAC := make(map[netip.Addr]net.HardwareAddr) + ipToHostname := make(map[netip.Addr]string) + macToHostname := make(map[string]string) + for _, entry := range entries { + ipToMAC[entry.Address] = entry.MACAddress + if entry.Hostname != "" { + ipToHostname[entry.Address] = entry.Hostname + macToHostname[entry.MACAddress.String()] = entry.Hostname + } + } + r.access.Lock() + r.ipToMAC = ipToMAC + r.ipToHostname = ipToHostname + r.macToHostname = macToHostname + r.access.Unlock() + r.logger.Info("updated neighbor table: ", len(entries), " entries") +} diff --git a/route/neighbor_table_linux.go b/route/neighbor_table_linux.go new file mode 100644 index 0000000000..61a214fd3a --- /dev/null +++ b/route/neighbor_table_linux.go @@ -0,0 +1,68 @@ +//go:build linux + +package route + +import ( + "net" + "net/netip" + "slices" + + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/jsimonetti/rtnetlink" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +func ReadNeighborEntries() ([]adapter.NeighborEntry, error) { + connection, err := rtnetlink.Dial(nil) + if err != nil { + return nil, E.Cause(err, "dial rtnetlink") + } + defer connection.Close() + neighbors, err := connection.Neigh.List() + if err != nil { + return nil, E.Cause(err, "list neighbors") + } + var entries []adapter.NeighborEntry + for _, neighbor := range neighbors { + if neighbor.Attributes == nil { + continue + } + if neighbor.Attributes.LLAddress == nil || len(neighbor.Attributes.Address) == 0 { + continue + } + address, ok := netip.AddrFromSlice(neighbor.Attributes.Address) + if !ok { + continue + } + entries = append(entries, adapter.NeighborEntry{ + Address: address, + MACAddress: slices.Clone(neighbor.Attributes.LLAddress), + }) + } + return entries, nil +} + +func ParseNeighborMessage(message netlink.Message) (address netip.Addr, macAddress net.HardwareAddr, isDelete bool, ok bool) { + var neighMessage rtnetlink.NeighMessage + err := neighMessage.UnmarshalBinary(message.Data) + if err != nil { + return + } + if neighMessage.Attributes == nil || len(neighMessage.Attributes.Address) == 0 { + return + } + address, ok = netip.AddrFromSlice(neighMessage.Attributes.Address) + if !ok { + return + } + isDelete = message.Header.Type == unix.RTM_DELNEIGH + if !isDelete && neighMessage.Attributes.LLAddress == nil { + ok = false + return + } + macAddress = slices.Clone(neighMessage.Attributes.LLAddress) + return +} diff --git a/route/router.go b/route/router.go index abc7ffa313..59eded3157 100644 --- a/route/router.go +++ b/route/router.go @@ -149,21 +149,34 @@ func (r *Router) Start(stage adapter.StartStage) error { } r.needFindNeighbor = needFindNeighbor if needFindNeighbor { - monitor.Start("initialize neighbor resolver") - resolver, err := newNeighborResolver(r.logger, r.leaseFiles) - monitor.Finish() - if err != nil { - if err != os.ErrInvalid { - r.logger.Warn(E.Cause(err, "create neighbor resolver")) - } - } else { - err = resolver.Start() + if r.platformInterface != nil && r.platformInterface.UsePlatformNeighborResolver() { + monitor.Start("initialize neighbor resolver") + resolver := newPlatformNeighborResolver(r.logger, r.platformInterface) + err := resolver.Start() + monitor.Finish() if err != nil { - r.logger.Warn(E.Cause(err, "start neighbor resolver")) + r.logger.Error(E.Cause(err, "start neighbor resolver")) } else { r.neighborResolver = resolver } } + if r.neighborResolver == nil { + monitor.Start("initialize neighbor resolver") + resolver, err := newNeighborResolver(r.logger, r.leaseFiles) + monitor.Finish() + if err != nil { + if err != os.ErrInvalid { + r.logger.Error(E.Cause(err, "create neighbor resolver")) + } + } else { + err = resolver.Start() + if err != nil { + r.logger.Error(E.Cause(err, "start neighbor resolver")) + } else { + r.neighborResolver = resolver + } + } + } } case adapter.StartStatePostStart: for i, rule := range r.rules { From 44d1c86b1bdf1812d73236eebaa3b9ddd043843f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 6 Mar 2026 08:47:37 +0800 Subject: [PATCH 03/96] Add macOS support for MAC and hostname rule items --- experimental/libbox/neighbor.go | 86 +----- experimental/libbox/neighbor_darwin.go | 123 ++++++++ experimental/libbox/neighbor_linux.go | 88 ++++++ experimental/libbox/neighbor_stub.go | 19 +- experimental/libbox/platform.go | 1 + experimental/libbox/service.go | 6 +- route/neighbor_resolver_darwin.go | 239 +++++++++++++++ route/neighbor_resolver_lease.go | 386 +++++++++++++++++++++++++ route/neighbor_resolver_linux.go | 313 +------------------- route/neighbor_resolver_stub.go | 2 +- route/neighbor_table_darwin.go | 104 +++++++ route/router.go | 3 +- 12 files changed, 956 insertions(+), 414 deletions(-) create mode 100644 experimental/libbox/neighbor_darwin.go create mode 100644 experimental/libbox/neighbor_linux.go create mode 100644 route/neighbor_resolver_darwin.go create mode 100644 route/neighbor_resolver_lease.go create mode 100644 route/neighbor_table_darwin.go diff --git a/experimental/libbox/neighbor.go b/experimental/libbox/neighbor.go index b2ded5f7a1..e38aa8023f 100644 --- a/experimental/libbox/neighbor.go +++ b/experimental/libbox/neighbor.go @@ -1,23 +1,13 @@ -//go:build linux - package libbox import ( "net" "net/netip" - "slices" - "time" - - "github.com/sagernet/sing-box/route" - E "github.com/sagernet/sing/common/exceptions" - - "github.com/mdlayher/netlink" - "golang.org/x/sys/unix" ) type NeighborEntry struct { Address string - MACAddress string + MacAddress string Hostname string } @@ -30,88 +20,16 @@ type NeighborSubscription struct { done chan struct{} } -func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { - entries, err := route.ReadNeighborEntries() - if err != nil { - return nil, E.Cause(err, "initial neighbor dump") - } - table := make(map[netip.Addr]net.HardwareAddr) - for _, entry := range entries { - table[entry.Address] = entry.MACAddress - } - listener.UpdateNeighborTable(tableToIterator(table)) - connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ - Groups: 1 << (unix.RTNLGRP_NEIGH - 1), - }) - if err != nil { - return nil, E.Cause(err, "subscribe neighbor updates") - } - subscription := &NeighborSubscription{ - done: make(chan struct{}), - } - go subscription.loop(listener, connection, table) - return subscription, nil -} - func (s *NeighborSubscription) Close() { close(s.done) } -func (s *NeighborSubscription) loop(listener NeighborUpdateListener, connection *netlink.Conn, table map[netip.Addr]net.HardwareAddr) { - defer connection.Close() - for { - select { - case <-s.done: - return - default: - } - err := connection.SetReadDeadline(time.Now().Add(3 * time.Second)) - if err != nil { - return - } - messages, err := connection.Receive() - if err != nil { - if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - continue - } - select { - case <-s.done: - return - default: - } - continue - } - changed := false - for _, message := range messages { - address, mac, isDelete, ok := route.ParseNeighborMessage(message) - if !ok { - continue - } - if isDelete { - if _, exists := table[address]; exists { - delete(table, address) - changed = true - } - } else { - existing, exists := table[address] - if !exists || !slices.Equal(existing, mac) { - table[address] = mac - changed = true - } - } - } - if changed { - listener.UpdateNeighborTable(tableToIterator(table)) - } - } -} - func tableToIterator(table map[netip.Addr]net.HardwareAddr) NeighborEntryIterator { entries := make([]*NeighborEntry, 0, len(table)) for address, mac := range table { entries = append(entries, &NeighborEntry{ Address: address.String(), - MACAddress: mac.String(), + MacAddress: mac.String(), }) } return &neighborEntryIterator{entries} diff --git a/experimental/libbox/neighbor_darwin.go b/experimental/libbox/neighbor_darwin.go new file mode 100644 index 0000000000..d7484a69b4 --- /dev/null +++ b/experimental/libbox/neighbor_darwin.go @@ -0,0 +1,123 @@ +//go:build darwin + +package libbox + +import ( + "net" + "net/netip" + "os" + "slices" + "time" + + "github.com/sagernet/sing-box/route" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + + xroute "golang.org/x/net/route" + "golang.org/x/sys/unix" +) + +func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { + entries, err := route.ReadNeighborEntries() + if err != nil { + return nil, E.Cause(err, "initial neighbor dump") + } + table := make(map[netip.Addr]net.HardwareAddr) + for _, entry := range entries { + table[entry.Address] = entry.MACAddress + } + listener.UpdateNeighborTable(tableToIterator(table)) + routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0) + if err != nil { + return nil, E.Cause(err, "open route socket") + } + err = unix.SetNonblock(routeSocket, true) + if err != nil { + unix.Close(routeSocket) + return nil, E.Cause(err, "set route socket nonblock") + } + subscription := &NeighborSubscription{ + done: make(chan struct{}), + } + go subscription.loop(listener, routeSocket, table) + return subscription, nil +} + +func (s *NeighborSubscription) loop(listener NeighborUpdateListener, routeSocket int, table map[netip.Addr]net.HardwareAddr) { + routeSocketFile := os.NewFile(uintptr(routeSocket), "route") + defer routeSocketFile.Close() + buffer := buf.NewPacket() + defer buffer.Release() + for { + select { + case <-s.done: + return + default: + } + tv := unix.NsecToTimeval(int64(3 * time.Second)) + _ = unix.SetsockoptTimeval(routeSocket, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv) + n, err := routeSocketFile.Read(buffer.FreeBytes()) + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + select { + case <-s.done: + return + default: + } + continue + } + messages, err := xroute.ParseRIB(xroute.RIBTypeRoute, buffer.FreeBytes()[:n]) + if err != nil { + continue + } + changed := false + for _, message := range messages { + routeMessage, isRouteMessage := message.(*xroute.RouteMessage) + if !isRouteMessage { + continue + } + if routeMessage.Flags&unix.RTF_LLINFO == 0 { + continue + } + address, mac, isDelete, ok := route.ParseRouteNeighborMessage(routeMessage) + if !ok { + continue + } + if isDelete { + if _, exists := table[address]; exists { + delete(table, address) + changed = true + } + } else { + existing, exists := table[address] + if !exists || !slices.Equal(existing, mac) { + table[address] = mac + changed = true + } + } + } + if changed { + listener.UpdateNeighborTable(tableToIterator(table)) + } + } +} + +func ReadBootpdLeases() NeighborEntryIterator { + leaseIPToMAC, ipToHostname, macToHostname := route.ReloadLeaseFiles([]string{"/var/db/dhcpd_leases"}) + entries := make([]*NeighborEntry, 0, len(leaseIPToMAC)) + for address, mac := range leaseIPToMAC { + entry := &NeighborEntry{ + Address: address.String(), + MacAddress: mac.String(), + } + hostname, found := ipToHostname[address] + if !found { + hostname = macToHostname[mac.String()] + } + entry.Hostname = hostname + entries = append(entries, entry) + } + return &neighborEntryIterator{entries} +} diff --git a/experimental/libbox/neighbor_linux.go b/experimental/libbox/neighbor_linux.go new file mode 100644 index 0000000000..ae10bdd2ee --- /dev/null +++ b/experimental/libbox/neighbor_linux.go @@ -0,0 +1,88 @@ +//go:build linux + +package libbox + +import ( + "net" + "net/netip" + "slices" + "time" + + "github.com/sagernet/sing-box/route" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { + entries, err := route.ReadNeighborEntries() + if err != nil { + return nil, E.Cause(err, "initial neighbor dump") + } + table := make(map[netip.Addr]net.HardwareAddr) + for _, entry := range entries { + table[entry.Address] = entry.MACAddress + } + listener.UpdateNeighborTable(tableToIterator(table)) + connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ + Groups: 1 << (unix.RTNLGRP_NEIGH - 1), + }) + if err != nil { + return nil, E.Cause(err, "subscribe neighbor updates") + } + subscription := &NeighborSubscription{ + done: make(chan struct{}), + } + go subscription.loop(listener, connection, table) + return subscription, nil +} + +func (s *NeighborSubscription) loop(listener NeighborUpdateListener, connection *netlink.Conn, table map[netip.Addr]net.HardwareAddr) { + defer connection.Close() + for { + select { + case <-s.done: + return + default: + } + err := connection.SetReadDeadline(time.Now().Add(3 * time.Second)) + if err != nil { + return + } + messages, err := connection.Receive() + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + select { + case <-s.done: + return + default: + } + continue + } + changed := false + for _, message := range messages { + address, mac, isDelete, ok := route.ParseNeighborMessage(message) + if !ok { + continue + } + if isDelete { + if _, exists := table[address]; exists { + delete(table, address) + changed = true + } + } else { + existing, exists := table[address] + if !exists || !slices.Equal(existing, mac) { + table[address] = mac + changed = true + } + } + } + if changed { + listener.UpdateNeighborTable(tableToIterator(table)) + } + } +} diff --git a/experimental/libbox/neighbor_stub.go b/experimental/libbox/neighbor_stub.go index 95f6dc7d6f..d465bc7bb0 100644 --- a/experimental/libbox/neighbor_stub.go +++ b/experimental/libbox/neighbor_stub.go @@ -1,24 +1,9 @@ -//go:build !linux +//go:build !linux && !darwin package libbox import "os" -type NeighborEntry struct { - Address string - MACAddress string - Hostname string -} - -type NeighborEntryIterator interface { - Next() *NeighborEntry - HasNext() bool -} - -type NeighborSubscription struct{} - -func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { +func SubscribeNeighborTable(_ NeighborUpdateListener) (*NeighborSubscription, error) { return nil, os.ErrInvalid } - -func (s *NeighborSubscription) Close() {} diff --git a/experimental/libbox/platform.go b/experimental/libbox/platform.go index 3b1b0f3204..d2cac4cf68 100644 --- a/experimental/libbox/platform.go +++ b/experimental/libbox/platform.go @@ -23,6 +23,7 @@ type PlatformInterface interface { SendNotification(notification *Notification) error StartNeighborMonitor(listener NeighborUpdateListener) error CloseNeighborMonitor(listener NeighborUpdateListener) error + RegisterMyInterface(name string) } type NeighborUpdateListener interface { diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index 458d0c66c5..b521f0f8e9 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -78,6 +78,7 @@ func (w *platformInterfaceWrapper) OpenInterface(options *tun.Options, platformO } options.FileDescriptor = dupFd w.myTunName = options.Name + w.iif.RegisterMyInterface(options.Name) return tun.New(*options) } @@ -240,11 +241,14 @@ func (w *neighborUpdateListenerWrapper) UpdateNeighborTable(entries NeighborEntr var result []adapter.NeighborEntry for entries.HasNext() { entry := entries.Next() + if entry == nil { + continue + } address, err := netip.ParseAddr(entry.Address) if err != nil { continue } - macAddress, err := net.ParseMAC(entry.MACAddress) + macAddress, err := net.ParseMAC(entry.MacAddress) if err != nil { continue } diff --git a/route/neighbor_resolver_darwin.go b/route/neighbor_resolver_darwin.go new file mode 100644 index 0000000000..a8884ae628 --- /dev/null +++ b/route/neighbor_resolver_darwin.go @@ -0,0 +1,239 @@ +//go:build darwin + +package route + +import ( + "net" + "net/netip" + "os" + "sync" + "time" + + "github.com/sagernet/fswatch" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + + "golang.org/x/net/route" + "golang.org/x/sys/unix" +) + +var defaultLeaseFiles = []string{ + "/var/db/dhcpd_leases", + "/tmp/dhcp.leases", +} + +type neighborResolver struct { + logger logger.ContextLogger + leaseFiles []string + access sync.RWMutex + neighborIPToMAC map[netip.Addr]net.HardwareAddr + leaseIPToMAC map[netip.Addr]net.HardwareAddr + ipToHostname map[netip.Addr]string + macToHostname map[string]string + watcher *fswatch.Watcher + done chan struct{} +} + +func newNeighborResolver(resolverLogger logger.ContextLogger, leaseFiles []string) (adapter.NeighborResolver, error) { + if len(leaseFiles) == 0 { + for _, path := range defaultLeaseFiles { + info, err := os.Stat(path) + if err == nil && info.Size() > 0 { + leaseFiles = append(leaseFiles, path) + } + } + } + return &neighborResolver{ + logger: resolverLogger, + leaseFiles: leaseFiles, + neighborIPToMAC: make(map[netip.Addr]net.HardwareAddr), + leaseIPToMAC: make(map[netip.Addr]net.HardwareAddr), + ipToHostname: make(map[netip.Addr]string), + macToHostname: make(map[string]string), + done: make(chan struct{}), + }, nil +} + +func (r *neighborResolver) Start() error { + err := r.loadNeighborTable() + if err != nil { + r.logger.Warn(E.Cause(err, "load neighbor table")) + } + r.doReloadLeaseFiles() + go r.subscribeNeighborUpdates() + if len(r.leaseFiles) > 0 { + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: r.leaseFiles, + Logger: r.logger, + Callback: func(_ string) { + r.doReloadLeaseFiles() + }, + }) + if err != nil { + r.logger.Warn(E.Cause(err, "create lease file watcher")) + } else { + r.watcher = watcher + err = watcher.Start() + if err != nil { + r.logger.Warn(E.Cause(err, "start lease file watcher")) + } + } + } + return nil +} + +func (r *neighborResolver) Close() error { + close(r.done) + if r.watcher != nil { + return r.watcher.Close() + } + return nil +} + +func (r *neighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) { + r.access.RLock() + defer r.access.RUnlock() + mac, found := r.neighborIPToMAC[address] + if found { + return mac, true + } + mac, found = r.leaseIPToMAC[address] + if found { + return mac, true + } + mac, found = extractMACFromEUI64(address) + if found { + return mac, true + } + return nil, false +} + +func (r *neighborResolver) LookupHostname(address netip.Addr) (string, bool) { + r.access.RLock() + defer r.access.RUnlock() + hostname, found := r.ipToHostname[address] + if found { + return hostname, true + } + mac, macFound := r.neighborIPToMAC[address] + if !macFound { + mac, macFound = r.leaseIPToMAC[address] + } + if !macFound { + mac, macFound = extractMACFromEUI64(address) + } + if macFound { + hostname, found = r.macToHostname[mac.String()] + if found { + return hostname, true + } + } + return "", false +} + +func (r *neighborResolver) loadNeighborTable() error { + entries, err := ReadNeighborEntries() + if err != nil { + return err + } + r.access.Lock() + defer r.access.Unlock() + for _, entry := range entries { + r.neighborIPToMAC[entry.Address] = entry.MACAddress + } + return nil +} + +func (r *neighborResolver) subscribeNeighborUpdates() { + routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0) + if err != nil { + r.logger.Warn(E.Cause(err, "subscribe neighbor updates")) + return + } + err = unix.SetNonblock(routeSocket, true) + if err != nil { + unix.Close(routeSocket) + r.logger.Warn(E.Cause(err, "set route socket nonblock")) + return + } + routeSocketFile := os.NewFile(uintptr(routeSocket), "route") + defer routeSocketFile.Close() + buffer := buf.NewPacket() + defer buffer.Release() + for { + select { + case <-r.done: + return + default: + } + err = setReadDeadline(routeSocketFile, 3*time.Second) + if err != nil { + r.logger.Warn(E.Cause(err, "set route socket read deadline")) + return + } + n, err := routeSocketFile.Read(buffer.FreeBytes()) + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + select { + case <-r.done: + return + default: + } + r.logger.Warn(E.Cause(err, "receive neighbor update")) + continue + } + messages, err := route.ParseRIB(route.RIBTypeRoute, buffer.FreeBytes()[:n]) + if err != nil { + continue + } + for _, message := range messages { + routeMessage, isRouteMessage := message.(*route.RouteMessage) + if !isRouteMessage { + continue + } + if routeMessage.Flags&unix.RTF_LLINFO == 0 { + continue + } + address, mac, isDelete, ok := ParseRouteNeighborMessage(routeMessage) + if !ok { + continue + } + r.access.Lock() + if isDelete { + delete(r.neighborIPToMAC, address) + } else { + r.neighborIPToMAC[address] = mac + } + r.access.Unlock() + } + } +} + +func (r *neighborResolver) doReloadLeaseFiles() { + leaseIPToMAC, ipToHostname, macToHostname := ReloadLeaseFiles(r.leaseFiles) + r.access.Lock() + r.leaseIPToMAC = leaseIPToMAC + r.ipToHostname = ipToHostname + r.macToHostname = macToHostname + r.access.Unlock() +} + +func setReadDeadline(file *os.File, timeout time.Duration) error { + rawConn, err := file.SyscallConn() + if err != nil { + return err + } + var controlErr error + err = rawConn.Control(func(fd uintptr) { + tv := unix.NsecToTimeval(int64(timeout)) + controlErr = unix.SetsockoptTimeval(int(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv) + }) + if err != nil { + return err + } + return controlErr +} diff --git a/route/neighbor_resolver_lease.go b/route/neighbor_resolver_lease.go new file mode 100644 index 0000000000..e3f9c0b464 --- /dev/null +++ b/route/neighbor_resolver_lease.go @@ -0,0 +1,386 @@ +package route + +import ( + "bufio" + "encoding/hex" + "net" + "net/netip" + "os" + "strconv" + "strings" + "time" +) + +func parseLeaseFile(path string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + file, err := os.Open(path) + if err != nil { + return + } + defer file.Close() + if strings.HasSuffix(path, "dhcpd_leases") { + parseBootpdLeases(file, ipToMAC, ipToHostname, macToHostname) + return + } + if strings.HasSuffix(path, "kea-leases4.csv") { + parseKeaCSV4(file, ipToMAC, ipToHostname, macToHostname) + return + } + if strings.HasSuffix(path, "kea-leases6.csv") { + parseKeaCSV6(file, ipToMAC, ipToHostname, macToHostname) + return + } + if strings.HasSuffix(path, "dhcpd.leases") { + parseISCDhcpd(file, ipToMAC, ipToHostname, macToHostname) + return + } + parseDnsmasqOdhcpd(file, ipToMAC, ipToHostname, macToHostname) +} + +func ReloadLeaseFiles(leaseFiles []string) (leaseIPToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + leaseIPToMAC = make(map[netip.Addr]net.HardwareAddr) + ipToHostname = make(map[netip.Addr]string) + macToHostname = make(map[string]string) + for _, path := range leaseFiles { + parseLeaseFile(path, leaseIPToMAC, ipToHostname, macToHostname) + } + return +} + +func parseDnsmasqOdhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + now := time.Now().Unix() + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "duid ") { + continue + } + if strings.HasPrefix(line, "# ") { + parseOdhcpdLine(line[2:], ipToMAC, ipToHostname, macToHostname) + continue + } + fields := strings.Fields(line) + if len(fields) < 4 { + continue + } + expiry, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + continue + } + if expiry != 0 && expiry < now { + continue + } + if strings.Contains(fields[1], ":") { + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) + if !addrOK { + continue + } + address = address.Unmap() + ipToMAC[address] = mac + hostname := fields[3] + if hostname != "*" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + } else { + var mac net.HardwareAddr + if len(fields) >= 5 { + duid, duidErr := parseDUID(fields[4]) + if duidErr == nil { + mac, _ = extractMACFromDUID(duid) + } + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) + if !addrOK { + continue + } + address = address.Unmap() + if mac != nil { + ipToMAC[address] = mac + } + hostname := fields[3] + if hostname != "*" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } + } +} + +func parseOdhcpdLine(line string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + fields := strings.Fields(line) + if len(fields) < 5 { + return + } + validTime, err := strconv.ParseInt(fields[4], 10, 64) + if err != nil { + return + } + if validTime == 0 { + return + } + if validTime > 0 && validTime < time.Now().Unix() { + return + } + hostname := fields[3] + if hostname == "-" || strings.HasPrefix(hostname, `broken\x20`) { + hostname = "" + } + if len(fields) >= 8 && fields[2] == "ipv4" { + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + return + } + addressField := fields[7] + slashIndex := strings.IndexByte(addressField, '/') + if slashIndex >= 0 { + addressField = addressField[:slashIndex] + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) + if !addrOK { + return + } + address = address.Unmap() + ipToMAC[address] = mac + if hostname != "" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + return + } + var mac net.HardwareAddr + duidHex := fields[1] + duidBytes, hexErr := hex.DecodeString(duidHex) + if hexErr == nil { + mac, _ = extractMACFromDUID(duidBytes) + } + for i := 7; i < len(fields); i++ { + addressField := fields[i] + slashIndex := strings.IndexByte(addressField, '/') + if slashIndex >= 0 { + addressField = addressField[:slashIndex] + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) + if !addrOK { + continue + } + address = address.Unmap() + if mac != nil { + ipToMAC[address] = mac + } + if hostname != "" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } +} + +func parseISCDhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + var currentIP netip.Addr + var currentMAC net.HardwareAddr + var currentHostname string + var currentActive bool + var inLease bool + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "lease ") && strings.HasSuffix(line, "{") { + ipString := strings.TrimSuffix(strings.TrimPrefix(line, "lease "), " {") + parsed, addrOK := netip.AddrFromSlice(net.ParseIP(ipString)) + if addrOK { + currentIP = parsed.Unmap() + inLease = true + currentMAC = nil + currentHostname = "" + currentActive = false + } + continue + } + if line == "}" && inLease { + if currentActive && currentMAC != nil { + ipToMAC[currentIP] = currentMAC + if currentHostname != "" { + ipToHostname[currentIP] = currentHostname + macToHostname[currentMAC.String()] = currentHostname + } + } else { + delete(ipToMAC, currentIP) + delete(ipToHostname, currentIP) + } + inLease = false + continue + } + if !inLease { + continue + } + if strings.HasPrefix(line, "hardware ethernet ") { + macString := strings.TrimSuffix(strings.TrimPrefix(line, "hardware ethernet "), ";") + parsed, macErr := net.ParseMAC(macString) + if macErr == nil { + currentMAC = parsed + } + } else if strings.HasPrefix(line, "client-hostname ") { + hostname := strings.TrimSuffix(strings.TrimPrefix(line, "client-hostname "), ";") + hostname = strings.Trim(hostname, "\"") + if hostname != "" { + currentHostname = hostname + } + } else if strings.HasPrefix(line, "binding state ") { + state := strings.TrimSuffix(strings.TrimPrefix(line, "binding state "), ";") + currentActive = state == "active" + } + } +} + +func parseKeaCSV4(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + firstLine := true + for scanner.Scan() { + if firstLine { + firstLine = false + continue + } + fields := strings.Split(scanner.Text(), ",") + if len(fields) < 10 { + continue + } + if fields[9] != "0" { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) + if !addrOK { + continue + } + address = address.Unmap() + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + continue + } + ipToMAC[address] = mac + hostname := "" + if len(fields) > 8 { + hostname = fields[8] + } + if hostname != "" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + } +} + +func parseKeaCSV6(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + firstLine := true + for scanner.Scan() { + if firstLine { + firstLine = false + continue + } + fields := strings.Split(scanner.Text(), ",") + if len(fields) < 14 { + continue + } + if fields[13] != "0" { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) + if !addrOK { + continue + } + address = address.Unmap() + var mac net.HardwareAddr + if fields[12] != "" { + mac, _ = net.ParseMAC(fields[12]) + } + if mac == nil { + duid, duidErr := hex.DecodeString(strings.ReplaceAll(fields[1], ":", "")) + if duidErr == nil { + mac, _ = extractMACFromDUID(duid) + } + } + hostname := "" + if len(fields) > 11 { + hostname = fields[11] + } + if mac != nil { + ipToMAC[address] = mac + } + if hostname != "" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } +} + +func parseBootpdLeases(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + now := time.Now().Unix() + scanner := bufio.NewScanner(file) + var currentName string + var currentIP netip.Addr + var currentMAC net.HardwareAddr + var currentLease int64 + var inBlock bool + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "{" { + inBlock = true + currentName = "" + currentIP = netip.Addr{} + currentMAC = nil + currentLease = 0 + continue + } + if line == "}" && inBlock { + if currentMAC != nil && currentIP.IsValid() { + if currentLease == 0 || currentLease >= now { + ipToMAC[currentIP] = currentMAC + if currentName != "" { + ipToHostname[currentIP] = currentName + macToHostname[currentMAC.String()] = currentName + } + } + } + inBlock = false + continue + } + if !inBlock { + continue + } + key, value, found := strings.Cut(line, "=") + if !found { + continue + } + switch key { + case "name": + currentName = value + case "ip_address": + parsed, addrOK := netip.AddrFromSlice(net.ParseIP(value)) + if addrOK { + currentIP = parsed.Unmap() + } + case "hw_address": + typeAndMAC, hasSep := strings.CutPrefix(value, "1,") + if hasSep { + mac, macErr := net.ParseMAC(typeAndMAC) + if macErr == nil { + currentMAC = mac + } + } + case "lease": + leaseHex := strings.TrimPrefix(value, "0x") + parsed, parseErr := strconv.ParseInt(leaseHex, 16, 64) + if parseErr == nil { + currentLease = parsed + } + } + } +} diff --git a/route/neighbor_resolver_linux.go b/route/neighbor_resolver_linux.go index 111cc6f040..b7991b4c89 100644 --- a/route/neighbor_resolver_linux.go +++ b/route/neighbor_resolver_linux.go @@ -3,14 +3,10 @@ package route import ( - "bufio" - "encoding/hex" "net" "net/netip" "os" "slices" - "strconv" - "strings" "sync" "time" @@ -69,14 +65,14 @@ func (r *neighborResolver) Start() error { if err != nil { r.logger.Warn(E.Cause(err, "load neighbor table")) } - r.reloadLeaseFiles() + r.doReloadLeaseFiles() go r.subscribeNeighborUpdates() if len(r.leaseFiles) > 0 { watcher, err := fswatch.NewWatcher(fswatch.Options{ Path: r.leaseFiles, Logger: r.logger, Callback: func(_ string) { - r.reloadLeaseFiles() + r.doReloadLeaseFiles() }, }) if err != nil { @@ -218,312 +214,11 @@ func (r *neighborResolver) subscribeNeighborUpdates() { } } -func (r *neighborResolver) reloadLeaseFiles() { - leaseIPToMAC := make(map[netip.Addr]net.HardwareAddr) - ipToHostname := make(map[netip.Addr]string) - macToHostname := make(map[string]string) - for _, path := range r.leaseFiles { - r.parseLeaseFile(path, leaseIPToMAC, ipToHostname, macToHostname) - } +func (r *neighborResolver) doReloadLeaseFiles() { + leaseIPToMAC, ipToHostname, macToHostname := ReloadLeaseFiles(r.leaseFiles) r.access.Lock() r.leaseIPToMAC = leaseIPToMAC r.ipToHostname = ipToHostname r.macToHostname = macToHostname r.access.Unlock() } - -func (r *neighborResolver) parseLeaseFile(path string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - file, err := os.Open(path) - if err != nil { - return - } - defer file.Close() - if strings.HasSuffix(path, "kea-leases4.csv") { - r.parseKeaCSV4(file, ipToMAC, ipToHostname, macToHostname) - return - } - if strings.HasSuffix(path, "kea-leases6.csv") { - r.parseKeaCSV6(file, ipToMAC, ipToHostname, macToHostname) - return - } - if strings.HasSuffix(path, "dhcpd.leases") { - r.parseISCDhcpd(file, ipToMAC, ipToHostname, macToHostname) - return - } - r.parseDnsmasqOdhcpd(file, ipToMAC, ipToHostname, macToHostname) -} - -func (r *neighborResolver) parseDnsmasqOdhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - now := time.Now().Unix() - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "duid ") { - continue - } - if strings.HasPrefix(line, "# ") { - r.parseOdhcpdLine(line[2:], ipToMAC, ipToHostname, macToHostname) - continue - } - fields := strings.Fields(line) - if len(fields) < 4 { - continue - } - expiry, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - continue - } - if expiry != 0 && expiry < now { - continue - } - if strings.Contains(fields[1], ":") { - mac, macErr := net.ParseMAC(fields[1]) - if macErr != nil { - continue - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) - if !addrOK { - continue - } - address = address.Unmap() - ipToMAC[address] = mac - hostname := fields[3] - if hostname != "*" { - ipToHostname[address] = hostname - macToHostname[mac.String()] = hostname - } - } else { - var mac net.HardwareAddr - if len(fields) >= 5 { - duid, duidErr := parseDUID(fields[4]) - if duidErr == nil { - mac, _ = extractMACFromDUID(duid) - } - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) - if !addrOK { - continue - } - address = address.Unmap() - if mac != nil { - ipToMAC[address] = mac - } - hostname := fields[3] - if hostname != "*" { - ipToHostname[address] = hostname - if mac != nil { - macToHostname[mac.String()] = hostname - } - } - } - } -} - -func (r *neighborResolver) parseOdhcpdLine(line string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - fields := strings.Fields(line) - if len(fields) < 5 { - return - } - validTime, err := strconv.ParseInt(fields[4], 10, 64) - if err != nil { - return - } - if validTime == 0 { - return - } - if validTime > 0 && validTime < time.Now().Unix() { - return - } - hostname := fields[3] - if hostname == "-" || strings.HasPrefix(hostname, `broken\x20`) { - hostname = "" - } - if len(fields) >= 8 && fields[2] == "ipv4" { - mac, macErr := net.ParseMAC(fields[1]) - if macErr != nil { - return - } - addressField := fields[7] - slashIndex := strings.IndexByte(addressField, '/') - if slashIndex >= 0 { - addressField = addressField[:slashIndex] - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) - if !addrOK { - return - } - address = address.Unmap() - ipToMAC[address] = mac - if hostname != "" { - ipToHostname[address] = hostname - macToHostname[mac.String()] = hostname - } - return - } - var mac net.HardwareAddr - duidHex := fields[1] - duidBytes, hexErr := hex.DecodeString(duidHex) - if hexErr == nil { - mac, _ = extractMACFromDUID(duidBytes) - } - for i := 7; i < len(fields); i++ { - addressField := fields[i] - slashIndex := strings.IndexByte(addressField, '/') - if slashIndex >= 0 { - addressField = addressField[:slashIndex] - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) - if !addrOK { - continue - } - address = address.Unmap() - if mac != nil { - ipToMAC[address] = mac - } - if hostname != "" { - ipToHostname[address] = hostname - if mac != nil { - macToHostname[mac.String()] = hostname - } - } - } -} - -func (r *neighborResolver) parseISCDhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - scanner := bufio.NewScanner(file) - var currentIP netip.Addr - var currentMAC net.HardwareAddr - var currentHostname string - var currentActive bool - var inLease bool - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if strings.HasPrefix(line, "lease ") && strings.HasSuffix(line, "{") { - ipString := strings.TrimSuffix(strings.TrimPrefix(line, "lease "), " {") - parsed, addrOK := netip.AddrFromSlice(net.ParseIP(ipString)) - if addrOK { - currentIP = parsed.Unmap() - inLease = true - currentMAC = nil - currentHostname = "" - currentActive = false - } - continue - } - if line == "}" && inLease { - if currentActive && currentMAC != nil { - ipToMAC[currentIP] = currentMAC - if currentHostname != "" { - ipToHostname[currentIP] = currentHostname - macToHostname[currentMAC.String()] = currentHostname - } - } else { - delete(ipToMAC, currentIP) - delete(ipToHostname, currentIP) - } - inLease = false - continue - } - if !inLease { - continue - } - if strings.HasPrefix(line, "hardware ethernet ") { - macString := strings.TrimSuffix(strings.TrimPrefix(line, "hardware ethernet "), ";") - parsed, macErr := net.ParseMAC(macString) - if macErr == nil { - currentMAC = parsed - } - } else if strings.HasPrefix(line, "client-hostname ") { - hostname := strings.TrimSuffix(strings.TrimPrefix(line, "client-hostname "), ";") - hostname = strings.Trim(hostname, "\"") - if hostname != "" { - currentHostname = hostname - } - } else if strings.HasPrefix(line, "binding state ") { - state := strings.TrimSuffix(strings.TrimPrefix(line, "binding state "), ";") - currentActive = state == "active" - } - } -} - -func (r *neighborResolver) parseKeaCSV4(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - scanner := bufio.NewScanner(file) - firstLine := true - for scanner.Scan() { - if firstLine { - firstLine = false - continue - } - fields := strings.Split(scanner.Text(), ",") - if len(fields) < 10 { - continue - } - if fields[9] != "0" { - continue - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) - if !addrOK { - continue - } - address = address.Unmap() - mac, macErr := net.ParseMAC(fields[1]) - if macErr != nil { - continue - } - ipToMAC[address] = mac - hostname := "" - if len(fields) > 8 { - hostname = fields[8] - } - if hostname != "" { - ipToHostname[address] = hostname - macToHostname[mac.String()] = hostname - } - } -} - -func (r *neighborResolver) parseKeaCSV6(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - scanner := bufio.NewScanner(file) - firstLine := true - for scanner.Scan() { - if firstLine { - firstLine = false - continue - } - fields := strings.Split(scanner.Text(), ",") - if len(fields) < 14 { - continue - } - if fields[13] != "0" { - continue - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) - if !addrOK { - continue - } - address = address.Unmap() - var mac net.HardwareAddr - if fields[12] != "" { - mac, _ = net.ParseMAC(fields[12]) - } - if mac == nil { - duid, duidErr := hex.DecodeString(strings.ReplaceAll(fields[1], ":", "")) - if duidErr == nil { - mac, _ = extractMACFromDUID(duid) - } - } - hostname := "" - if len(fields) > 11 { - hostname = fields[11] - } - if mac != nil { - ipToMAC[address] = mac - } - if hostname != "" { - ipToHostname[address] = hostname - if mac != nil { - macToHostname[mac.String()] = hostname - } - } - } -} diff --git a/route/neighbor_resolver_stub.go b/route/neighbor_resolver_stub.go index 9288892a8d..177a1fccbc 100644 --- a/route/neighbor_resolver_stub.go +++ b/route/neighbor_resolver_stub.go @@ -1,4 +1,4 @@ -//go:build !linux +//go:build !linux && !darwin package route diff --git a/route/neighbor_table_darwin.go b/route/neighbor_table_darwin.go new file mode 100644 index 0000000000..8ca2d0f0b7 --- /dev/null +++ b/route/neighbor_table_darwin.go @@ -0,0 +1,104 @@ +//go:build darwin + +package route + +import ( + "net" + "net/netip" + "syscall" + + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" + + "golang.org/x/net/route" + "golang.org/x/sys/unix" +) + +func ReadNeighborEntries() ([]adapter.NeighborEntry, error) { + var entries []adapter.NeighborEntry + ipv4Entries, err := readNeighborEntriesAF(syscall.AF_INET) + if err != nil { + return nil, E.Cause(err, "read IPv4 neighbors") + } + entries = append(entries, ipv4Entries...) + ipv6Entries, err := readNeighborEntriesAF(syscall.AF_INET6) + if err != nil { + return nil, E.Cause(err, "read IPv6 neighbors") + } + entries = append(entries, ipv6Entries...) + return entries, nil +} + +func readNeighborEntriesAF(addressFamily int) ([]adapter.NeighborEntry, error) { + rib, err := route.FetchRIB(addressFamily, route.RIBType(syscall.NET_RT_FLAGS), syscall.RTF_LLINFO) + if err != nil { + return nil, err + } + messages, err := route.ParseRIB(route.RIBType(syscall.NET_RT_FLAGS), rib) + if err != nil { + return nil, err + } + var entries []adapter.NeighborEntry + for _, message := range messages { + routeMessage, isRouteMessage := message.(*route.RouteMessage) + if !isRouteMessage { + continue + } + address, macAddress, ok := parseRouteNeighborEntry(routeMessage) + if !ok { + continue + } + entries = append(entries, adapter.NeighborEntry{ + Address: address, + MACAddress: macAddress, + }) + } + return entries, nil +} + +func parseRouteNeighborEntry(message *route.RouteMessage) (address netip.Addr, macAddress net.HardwareAddr, ok bool) { + if len(message.Addrs) <= unix.RTAX_GATEWAY { + return + } + gateway, isLinkAddr := message.Addrs[unix.RTAX_GATEWAY].(*route.LinkAddr) + if !isLinkAddr || len(gateway.Addr) < 6 { + return + } + switch destination := message.Addrs[unix.RTAX_DST].(type) { + case *route.Inet4Addr: + address = netip.AddrFrom4(destination.IP) + case *route.Inet6Addr: + address = netip.AddrFrom16(destination.IP) + default: + return + } + macAddress = net.HardwareAddr(make([]byte, len(gateway.Addr))) + copy(macAddress, gateway.Addr) + ok = true + return +} + +func ParseRouteNeighborMessage(message *route.RouteMessage) (address netip.Addr, macAddress net.HardwareAddr, isDelete bool, ok bool) { + isDelete = message.Type == unix.RTM_DELETE + if len(message.Addrs) <= unix.RTAX_GATEWAY { + return + } + switch destination := message.Addrs[unix.RTAX_DST].(type) { + case *route.Inet4Addr: + address = netip.AddrFrom4(destination.IP) + case *route.Inet6Addr: + address = netip.AddrFrom16(destination.IP) + default: + return + } + if !isDelete { + gateway, isLinkAddr := message.Addrs[unix.RTAX_GATEWAY].(*route.LinkAddr) + if !isLinkAddr || len(gateway.Addr) < 6 { + return + } + macAddress = net.HardwareAddr(make([]byte, len(gateway.Addr))) + copy(macAddress, gateway.Addr) + } + ok = true + return +} diff --git a/route/router.go b/route/router.go index 59eded3157..c141581d01 100644 --- a/route/router.go +++ b/route/router.go @@ -159,8 +159,7 @@ func (r *Router) Start(stage adapter.StartStage) error { } else { r.neighborResolver = resolver } - } - if r.neighborResolver == nil { + } else { monitor.Start("initialize neighbor resolver") resolver, err := newNeighborResolver(r.logger, r.leaseFiles) monitor.Finish() From 9462b1deeb4cd4ed2a1b3ab4a5d0502eed19a9f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 6 Mar 2026 21:43:21 +0800 Subject: [PATCH 04/96] documentation: Update descriptions for neighbor rules --- docs/configuration/dns/rule.md | 4 +- docs/configuration/dns/rule.zh.md | 4 +- docs/configuration/route/index.md | 17 ++++++-- docs/configuration/route/index.zh.md | 17 ++++++-- docs/configuration/route/rule.md | 4 +- docs/configuration/route/rule.zh.md | 4 +- docs/configuration/shared/neighbor.md | 49 ++++++++++++++++++++++++ docs/configuration/shared/neighbor.zh.md | 49 ++++++++++++++++++++++++ mkdocs.yml | 1 + 9 files changed, 133 insertions(+), 16 deletions(-) create mode 100644 docs/configuration/shared/neighbor.md create mode 100644 docs/configuration/shared/neighbor.zh.md diff --git a/docs/configuration/dns/rule.md b/docs/configuration/dns/rule.md index 262a23e629..97a4a7b3d5 100644 --- a/docs/configuration/dns/rule.md +++ b/docs/configuration/dns/rule.md @@ -425,7 +425,7 @@ Match default interface address. !!! quote "" - Only supported on Linux with `route.find_neighbor` enabled. + Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup. Match source device MAC address. @@ -435,7 +435,7 @@ Match source device MAC address. !!! quote "" - Only supported on Linux with `route.find_neighbor` enabled. + Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup. Match source device hostname from DHCP leases. diff --git a/docs/configuration/dns/rule.zh.md b/docs/configuration/dns/rule.zh.md index 4bf60b9862..e1288bb69e 100644 --- a/docs/configuration/dns/rule.zh.md +++ b/docs/configuration/dns/rule.zh.md @@ -424,7 +424,7 @@ Available values: `wifi`, `cellular`, `ethernet` and `other`. !!! quote "" - 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + 仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。 匹配源设备 MAC 地址。 @@ -434,7 +434,7 @@ Available values: `wifi`, `cellular`, `ethernet` and `other`. !!! quote "" - 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + 仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。 匹配源设备从 DHCP 租约获取的主机名。 diff --git a/docs/configuration/route/index.md b/docs/configuration/route/index.md index 01e405614e..40104b619e 100644 --- a/docs/configuration/route/index.md +++ b/docs/configuration/route/index.md @@ -40,6 +40,7 @@ icon: material/alert-decagram "override_android_vpn": false, "default_interface": "", "default_mark": 0, + "find_process": false, "find_neighbor": false, "dhcp_lease_files": [], "default_domain_resolver": "", // or {} @@ -114,17 +115,25 @@ Set routing mark by default. Takes no effect if `outbound.routing_mark` is set. +#### find_process + +!!! quote "" + + Only supported on Linux, Windows, and macOS. + +Enable process search for logging when no `process_name`, `process_path`, `package_name`, `user` or `user_id` rules exist. + #### find_neighbor !!! question "Since sing-box 1.14.0" !!! quote "" - Only supported on Linux. + Only supported on Linux and macOS. -Enable neighbor resolution for source MAC address and hostname lookup. +Enable neighbor resolution for logging when no `source_mac_address` or `source_hostname` rules exist. -Required for `source_mac_address` and `source_hostname` rule items. +See [Neighbor Resolution](/configuration/shared/neighbor/) for setup. #### dhcp_lease_files @@ -132,7 +141,7 @@ Required for `source_mac_address` and `source_hostname` rule items. !!! quote "" - Only supported on Linux. + Only supported on Linux and macOS. Custom DHCP lease file paths for hostname and MAC address resolution. diff --git a/docs/configuration/route/index.zh.md b/docs/configuration/route/index.zh.md index 84ce76723c..518830b835 100644 --- a/docs/configuration/route/index.zh.md +++ b/docs/configuration/route/index.zh.md @@ -42,6 +42,7 @@ icon: material/alert-decagram "override_android_vpn": false, "default_interface": "", "default_mark": 0, + "find_process": false, "find_neighbor": false, "dhcp_lease_files": [], "default_network_strategy": "", @@ -113,17 +114,25 @@ icon: material/alert-decagram 如果设置了 `outbound.routing_mark` 设置,则不生效。 +#### find_process + +!!! quote "" + + 仅支持 Linux、Windows 和 macOS。 + +在没有 `process_name`、`process_path`、`package_name`、`user` 或 `user_id` 规则时启用进程搜索以输出日志。 + #### find_neighbor !!! question "自 sing-box 1.14.0 起" !!! quote "" - 仅支持 Linux。 + 仅支持 Linux 和 macOS。 -启用邻居解析以查找源 MAC 地址和主机名。 +在没有 `source_mac_address` 或 `source_hostname` 规则时启用邻居解析以输出日志。 -`source_mac_address` 和 `source_hostname` 规则项需要此选项。 +参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。 #### dhcp_lease_files @@ -131,7 +140,7 @@ icon: material/alert-decagram !!! quote "" - 仅支持 Linux。 + 仅支持 Linux 和 macOS。 用于主机名和 MAC 地址解析的自定义 DHCP 租约文件路径。 diff --git a/docs/configuration/route/rule.md b/docs/configuration/route/rule.md index d226571096..767e9ef756 100644 --- a/docs/configuration/route/rule.md +++ b/docs/configuration/route/rule.md @@ -466,7 +466,7 @@ Match specified outbounds' preferred routes. !!! quote "" - Only supported on Linux with `route.find_neighbor` enabled. + Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup. Match source device MAC address. @@ -476,7 +476,7 @@ Match source device MAC address. !!! quote "" - Only supported on Linux with `route.find_neighbor` enabled. + Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup. Match source device hostname from DHCP leases. diff --git a/docs/configuration/route/rule.zh.md b/docs/configuration/route/rule.zh.md index 597e655f6e..e581ae995d 100644 --- a/docs/configuration/route/rule.zh.md +++ b/docs/configuration/route/rule.zh.md @@ -463,7 +463,7 @@ icon: material/new-box !!! quote "" - 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + 仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。 匹配源设备 MAC 地址。 @@ -473,7 +473,7 @@ icon: material/new-box !!! quote "" - 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + 仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。 匹配源设备从 DHCP 租约获取的主机名。 diff --git a/docs/configuration/shared/neighbor.md b/docs/configuration/shared/neighbor.md new file mode 100644 index 0000000000..c67d995ebe --- /dev/null +++ b/docs/configuration/shared/neighbor.md @@ -0,0 +1,49 @@ +--- +icon: material/lan +--- + +# Neighbor Resolution + +Match LAN devices by MAC address and hostname using +[`source_mac_address`](/configuration/route/rule/#source_mac_address) and +[`source_hostname`](/configuration/route/rule/#source_hostname) rule items. + +Neighbor resolution is automatically enabled when these rule items exist. +Use [`route.find_neighbor`](/configuration/route/#find_neighbor) to force enable it for logging without rules. + +## Linux + +Works natively. No special setup required. + +Hostname resolution requires DHCP lease files, +automatically detected from common DHCP servers (dnsmasq, odhcpd, ISC dhcpd, Kea). +Custom paths can be set via [`route.dhcp_lease_files`](/configuration/route/#dhcp_lease_files). + +## Android + +!!! quote "" + + Only supported in graphical clients. + +Requires Android 11 or above and ROOT. + +Must use [VPNHotspot](https://github.com/Mygod/VPNHotspot) to share the VPN connection. +ROM built-in features like "Use VPN for connected devices" can share VPN +but cannot provide MAC address or hostname information. + +Set **IP Masquerade Mode** to **None** in VPNHotspot settings. + +Only route/DNS rules are supported. TUN include/exclude routes are not supported. + +### Hostname Visibility + +Hostname is only visible in sing-box if it is visible in VPNHotspot. +For Apple devices, change **Private Wi-Fi Address** from **Rotating** to **Fixed** in the Wi-Fi settings +of the connected network. Non-Apple devices are always visible. + +## macOS + +Requires the standalone version (macOS system extension). +The App Store version can share the VPN as a hotspot but does not support MAC address or hostname reading. + +See [VPN Hotspot](/manual/misc/vpn-hotspot/#macos) for Internet Sharing setup. diff --git a/docs/configuration/shared/neighbor.zh.md b/docs/configuration/shared/neighbor.zh.md new file mode 100644 index 0000000000..96297fcb57 --- /dev/null +++ b/docs/configuration/shared/neighbor.zh.md @@ -0,0 +1,49 @@ +--- +icon: material/lan +--- + +# 邻居解析 + +通过 +[`source_mac_address`](/configuration/route/rule/#source_mac_address) 和 +[`source_hostname`](/configuration/route/rule/#source_hostname) 规则项匹配局域网设备的 MAC 地址和主机名。 + +当这些规则项存在时,邻居解析自动启用。 +使用 [`route.find_neighbor`](/configuration/route/#find_neighbor) 可在没有规则时强制启用以输出日志。 + +## Linux + +原生支持,无需特殊设置。 + +主机名解析需要 DHCP 租约文件, +自动从常见 DHCP 服务器(dnsmasq、odhcpd、ISC dhcpd、Kea)检测。 +可通过 [`route.dhcp_lease_files`](/configuration/route/#dhcp_lease_files) 设置自定义路径。 + +## Android + +!!! quote "" + + 仅在图形客户端中支持。 + +需要 Android 11 或以上版本和 ROOT。 + +必须使用 [VPNHotspot](https://github.com/Mygod/VPNHotspot) 共享 VPN 连接。 +ROM 自带的「通过 VPN 共享连接」等功能可以共享 VPN, +但无法提供 MAC 地址或主机名信息。 + +在 VPNHotspot 设置中将 **IP 遮掩模式** 设为 **无**。 + +仅支持路由/DNS 规则。不支持 TUN 的 include/exclude 路由。 + +### 设备可见性 + +MAC 地址和主机名仅在 VPNHotspot 中可见时 sing-box 才能读取。 +对于 Apple 设备,需要在所连接网络的 Wi-Fi 设置中将**私有无线局域网地址**从**轮替**改为**固定**。 +非 Apple 设备始终可见。 + +## macOS + +需要独立版本(macOS 系统扩展)。 +App Store 版本可以共享 VPN 热点但不支持 MAC 地址或主机名读取。 + +参阅 [VPN 热点](/manual/misc/vpn-hotspot/#macos) 了解互联网共享设置。 diff --git a/mkdocs.yml b/mkdocs.yml index 081ba3aa18..70edfaac43 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -129,6 +129,7 @@ nav: - UDP over TCP: configuration/shared/udp-over-tcp.md - TCP Brutal: configuration/shared/tcp-brutal.md - Wi-Fi State: configuration/shared/wifi-state.md + - Neighbor Resolution: configuration/shared/neighbor.md - Endpoint: - configuration/endpoint/index.md - WireGuard: configuration/endpoint/wireguard.md From bd0fb83d2d447287b7c2ecc7a627d437223d3733 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 11 Mar 2026 18:35:51 +0800 Subject: [PATCH 05/96] cronet-go: Update chromium to 145.0.7632.159 --- .github/CRONET_GO_VERSION | 2 +- go.mod | 4 ++-- go.sum | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/CRONET_GO_VERSION b/.github/CRONET_GO_VERSION index 47b09f9b6b..40dfcd0d14 100644 --- a/.github/CRONET_GO_VERSION +++ b/.github/CRONET_GO_VERSION @@ -1 +1 @@ -2fef65f9dba90ddb89a87d00a6eb6165487c10c1 +ea7cd33752aed62603775af3df946c1b83f4b0b3 diff --git a/go.mod b/go.mod index 32aa7061d7..dec9155f5a 100644 --- a/go.mod +++ b/go.mod @@ -29,8 +29,8 @@ require ( github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1 github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a github.com/sagernet/cors v1.2.1 - github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9 - github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9 + github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc + github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc github.com/sagernet/fswatch v0.1.1 github.com/sagernet/gomobile v0.1.12 github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1 diff --git a/go.sum b/go.sum index dea19790c3..378117ec91 100644 --- a/go.sum +++ b/go.sum @@ -162,10 +162,10 @@ github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a h1:+NkI2670SQpQWvkk github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a/go.mod h1:63s7jpZqcDAIpj8oI/1v4Izok+npJOHACFCU6+huCkM= github.com/sagernet/cors v1.2.1 h1:Cv5Z8y9YSD6Gm+qSpNrL3LO4lD3eQVvbFYJSG7JCMHQ= github.com/sagernet/cors v1.2.1/go.mod h1:O64VyOjjhrkLmQIjF4KGRrJO/5dVXFdpEmCW/eISRAI= -github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9 h1:xq5Yr10jXEppD3cnGjE3WENaB6D0YsZu6KptZ8d3054= -github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9/go.mod h1:hwFHBEjjthyEquDULbr4c4ucMedp8Drb6Jvm2kt/0Bw= -github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9 h1:uxQyy6Y/boOuecVA66tf79JgtoRGfeDJcfYZZLKVA5E= -github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9/go.mod h1:Xm6cCvs0/twozC1JYNq0sVlOVmcSGzV7YON1XGcD97w= +github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc h1:YK7PwJT0irRAEui9ASdXSxcE2BOVQipWMF/A1Ogt+7c= +github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc/go.mod h1:hwFHBEjjthyEquDULbr4c4ucMedp8Drb6Jvm2kt/0Bw= +github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc h1:EJPHOqk23IuBsTjXK9OXqkNxPbKOBWKRmviQoCcriAs= +github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc/go.mod h1:8aty0RW96DrJSMWXO6bRPMBJEjuqq5JWiOIi4bCRzFA= github.com/sagernet/cronet-go/lib/android_386 v0.0.0-20260309101654-0cbdcfddded9 h1:Qi0IKBpoPP3qZqIXuOKMsT2dv+l/MLWMyBHDMLRw2EA= github.com/sagernet/cronet-go/lib/android_386 v0.0.0-20260309101654-0cbdcfddded9/go.mod h1:XXDwdjX/T8xftoeJxQmbBoYXZp8MAPFR2CwbFuTpEtw= github.com/sagernet/cronet-go/lib/android_amd64 v0.0.0-20260309101654-0cbdcfddded9 h1:p+wCMjOhj46SpSD/AJeTGgkCcbyA76FyH631XZatyU8= From a11cd1e0c63b5eef119948c266ffecfbd4919451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 7 Mar 2026 16:40:34 +0800 Subject: [PATCH 06/96] Bump version --- docs/changelog.md | 57 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 9aaba89474..b13966d616 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,10 @@ icon: material/alert-decagram --- +#### 1.14.0-alpha.3 + +* Fixes and improvements + #### 1.13.3 * Add OpenWrt and Alpine APK packages to release **1** @@ -26,6 +30,59 @@ from [SagerNet/go](https://github.com/SagerNet/go). See [OCM](/configuration/service/ocm). +#### 1.12.24 + +* Fixes and improvements + +#### 1.14.0-alpha.2 + +* Add OpenWrt and Alpine APK packages to release **1** +* Backport to macOS 10.13 High Sierra **2** +* OCM service: Add WebSocket support for Responses API **3** +* Fixes and improvements + +**1**: + +Alpine APK files use `linux` in the filename to distinguish from OpenWrt APKs which use the `openwrt` prefix: + +- OpenWrt: `sing-box_{version}_openwrt_{architecture}.apk` +- Alpine: `sing-box_{version}_linux_{architecture}.apk` + +**2**: + +Legacy macOS binaries (with `-legacy-macos-10.13` suffix) now support +macOS 10.13 High Sierra, built using Go 1.25 with patches +from [SagerNet/go](https://github.com/SagerNet/go). + +**3**: + +See [OCM](/configuration/service/ocm). + +#### 1.14.0-alpha.1 + +* Add `source_mac_address` and `source_hostname` rule items **1** +* Add `include_mac_address` and `exclude_mac_address` TUN options **2** +* Update NaiveProxy to 145.0.7632.159 **3** +* Fixes and improvements + +**1**: + +New rule items for matching LAN devices by MAC address and hostname via neighbor resolution. +Supported on Linux, macOS, or in graphical clients on Android and macOS. + +See [Route Rule](/configuration/route/rule/#source_mac_address), [DNS Rule](/configuration/dns/rule/#source_mac_address) and [Neighbor Resolution](/configuration/shared/neighbor/). + +**2**: + +Limit or exclude devices from TUN routing by MAC address. +Only supported on Linux with `auto_route` and `auto_redirect` enabled. + +See [TUN](/configuration/inbound/tun/#include_mac_address). + +**3**: + +This is not an official update from NaiveProxy. Instead, it's a Chromium codebase update maintained by Project S. + #### 1.13.2 * Fixes and improvements From 2801bce815b04460dd0731f3745ecd950cdba73e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 11 Mar 2026 01:01:04 +0800 Subject: [PATCH 07/96] ccm/ocm: Add multi-credential support with balancer and fallback strategies --- docs/configuration/service/ccm.md | 130 +++- docs/configuration/service/ccm.zh.md | 130 +++- docs/configuration/service/ocm.md | 130 +++- docs/configuration/service/ocm.zh.md | 130 +++- go.mod | 2 +- go.sum | 4 +- log/format.go | 6 +- option/ccm.go | 77 +- option/ocm.go | 77 +- service/ccm/credential.go | 52 ++ service/ccm/credential_state.go | 997 +++++++++++++++++++++++++ service/ccm/service.go | 383 ++++++---- service/ocm/credential.go | 4 + service/ocm/credential_state.go | 1022 ++++++++++++++++++++++++++ service/ocm/service.go | 412 +++++++---- service/ocm/service_websocket.go | 164 +++-- 16 files changed, 3373 insertions(+), 347 deletions(-) create mode 100644 service/ccm/credential_state.go create mode 100644 service/ocm/credential_state.go diff --git a/docs/configuration/service/ccm.md b/docs/configuration/service/ccm.md index 337cacb10b..59ef5f7c0d 100644 --- a/docs/configuration/service/ccm.md +++ b/docs/configuration/service/ccm.md @@ -10,6 +10,11 @@ CCM (Claude Code Multiplexer) service is a multiplexing service that allows you It handles OAuth authentication with Claude's API on your local machine while allowing remote Claude Code to authenticate using Auth Tokens via the `ANTHROPIC_AUTH_TOKEN` environment variable. +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [credentials](#credentials) + :material-alert: [users](#users) + ### Structure ```json @@ -19,6 +24,7 @@ It handles OAuth authentication with Claude's API on your local machine while al ... // Listen Fields "credential_path": "", + "credentials": [], "usages_path": "", "users": [], "headers": {}, @@ -45,6 +51,73 @@ On macOS, credentials are read from the system keychain first, then fall back to Refreshed tokens are automatically written back to the same location. +Conflict with `credentials`. + +#### credentials + +!!! question "Since sing-box 1.14.0" + +List of credential configurations for multi-credential mode. + +When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag. + +Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a required `tag` field. + +##### Default Credential + +```json +{ + "tag": "a", + "credential_path": "/path/to/.credentials.json", + "usages_path": "/path/to/usages.json", + "detour": "", + "reserve_5h": 20, + "reserve_weekly": 20 +} +``` + +A single OAuth credential file. The `type` field can be omitted (defaults to `default`). + +- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`. +- `usages_path`: Optional usage tracking file for this credential. +- `detour`: Outbound tag for connecting to the Claude API with this credential. +- `reserve_5h`: Reserve threshold (1-99) for 5-hour window. Credential pauses at (100-N)% utilization. +- `reserve_weekly`: Reserve threshold (1-99) for weekly window. Credential pauses at (100-N)% utilization. + +##### Balancer Credential + +```json +{ + "tag": "pool", + "type": "balancer", + "strategy": "", + "credentials": ["a", "b"], + "poll_interval": "60s" +} +``` + +Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit. + +- `strategy`: Selection strategy. One of `least_used` `round_robin` `random`. `least_used` will be used by default. +- `credentials`: ==Required== List of default credential tags. +- `poll_interval`: How often to poll upstream usage API. Default `60s`. + +##### Fallback Credential + +```json +{ + "tag": "backup", + "type": "fallback", + "credentials": ["a", "b"], + "poll_interval": "30s" +} +``` + +Uses credentials in order. Falls through to the next when the current one is exhausted. + +- `credentials`: ==Required== Ordered list of default credential tags. +- `poll_interval`: How often to poll upstream usage API. Default `60s`. + #### usages_path Path to the file for storing aggregated API usage statistics. @@ -60,6 +133,8 @@ Statistics are organized by model, context window (200k standard vs 1M premium), The statistics file is automatically saved every minute and upon service shutdown. +Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials. + #### users List of authorized users for token authentication. @@ -71,7 +146,8 @@ Object format: ```json { "name": "", - "token": "" + "token": "", + "credential": "" } ``` @@ -79,6 +155,7 @@ Object fields: - `name`: Username identifier for tracking purposes. - `token`: Bearer token for authentication. Claude Code authenticates by setting the `ANTHROPIC_AUTH_TOKEN` environment variable to their token value. +- `credential`: Credential tag to use for this user. ==Required== when `credentials` is set. #### headers @@ -90,6 +167,8 @@ These headers will override any existing headers with the same name. Outbound tag for connecting to the Claude API. +Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials. + #### tls TLS configuration, see [TLS](/configuration/shared/tls/#inbound). @@ -129,3 +208,52 @@ export ANTHROPIC_AUTH_TOKEN="ak-ccm-hello-world" claude ``` + +### Example with Multiple Credentials + +#### Server + +```json +{ + "services": [ + { + "type": "ccm", + "listen": "0.0.0.0", + "listen_port": 8080, + "credentials": [ + { + "tag": "a", + "credential_path": "/home/user/.claude-a/.credentials.json", + "usages_path": "/data/usages-a.json", + "reserve_5h": 20, + "reserve_weekly": 20 + }, + { + "tag": "b", + "credential_path": "/home/user/.claude-b/.credentials.json", + "reserve_5h": 10, + "reserve_weekly": 10 + }, + { + "tag": "pool", + "type": "balancer", + "poll_interval": "60s", + "credentials": ["a", "b"] + } + ], + "users": [ + { + "name": "alice", + "token": "ak-ccm-hello-world", + "credential": "pool" + }, + { + "name": "bob", + "token": "ak-ccm-hello-bob", + "credential": "a" + } + ] + } + ] +} +``` diff --git a/docs/configuration/service/ccm.zh.md b/docs/configuration/service/ccm.zh.md index 7bba322c77..d9496986a7 100644 --- a/docs/configuration/service/ccm.zh.md +++ b/docs/configuration/service/ccm.zh.md @@ -10,6 +10,11 @@ CCM(Claude Code 多路复用器)服务是一个多路复用服务,允许 它在本地机器上处理与 Claude API 的 OAuth 身份验证,同时允许远程 Claude Code 通过 `ANTHROPIC_AUTH_TOKEN` 环境变量使用认证令牌进行身份验证。 +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [credentials](#credentials) + :material-alert: [users](#users) + ### 结构 ```json @@ -19,6 +24,7 @@ CCM(Claude Code 多路复用器)服务是一个多路复用服务,允许 ... // 监听字段 "credential_path": "", + "credentials": [], "usages_path": "", "users": [], "headers": {}, @@ -45,6 +51,73 @@ Claude Code OAuth 凭据文件的路径。 刷新的令牌会自动写回相同位置。 +与 `credentials` 冲突。 + +#### credentials + +!!! question "自 sing-box 1.14.0 起" + +多凭据模式的凭据配置列表。 + +设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。 + +每个凭据有一个 `type` 字段(`default`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。 + +##### 默认凭据 + +```json +{ + "tag": "a", + "credential_path": "/path/to/.credentials.json", + "usages_path": "/path/to/usages.json", + "detour": "", + "reserve_5h": 20, + "reserve_weekly": 20 +} +``` + +单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。 + +- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。 +- `usages_path`:此凭据的可选使用跟踪文件。 +- `detour`:此凭据用于连接 Claude API 的出站标签。 +- `reserve_5h`:5 小时窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 +- `reserve_weekly`:每周窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 + +##### 均衡凭据 + +```json +{ + "tag": "pool", + "type": "balancer", + "strategy": "", + "credentials": ["a", "b"], + "poll_interval": "60s" +} +``` + +根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。 + +- `strategy`:选择策略。可选值:`least_used` `round_robin` `random`。默认使用 `least_used`。 +- `credentials`:==必填== 默认凭据标签列表。 +- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 + +##### 回退凭据 + +```json +{ + "tag": "backup", + "type": "fallback", + "credentials": ["a", "b"], + "poll_interval": "30s" +} +``` + +按顺序使用凭据。当前凭据耗尽后切换到下一个。 + +- `credentials`:==必填== 有序的默认凭据标签列表。 +- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 + #### usages_path 用于存储聚合 API 使用统计信息的文件路径。 @@ -60,6 +133,8 @@ Claude Code OAuth 凭据文件的路径。 统计文件每分钟自动保存一次,并在服务关闭时保存。 +与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`。 + #### users 用于令牌身份验证的授权用户列表。 @@ -71,7 +146,8 @@ Claude Code OAuth 凭据文件的路径。 ```json { "name": "", - "token": "" + "token": "", + "credential": "" } ``` @@ -79,6 +155,7 @@ Claude Code OAuth 凭据文件的路径。 - `name`:用于跟踪的用户名标识符。 - `token`:用于身份验证的 Bearer 令牌。Claude Code 通过设置 `ANTHROPIC_AUTH_TOKEN` 环境变量为其令牌值进行身份验证。 +- `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。 #### headers @@ -90,6 +167,8 @@ Claude Code OAuth 凭据文件的路径。 用于连接 Claude API 的出站标签。 +与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`。 + #### tls TLS 配置,参阅 [TLS](/zh/configuration/shared/tls/#inbound)。 @@ -129,3 +208,52 @@ export ANTHROPIC_AUTH_TOKEN="ak-ccm-hello-world" claude ``` + +### 多凭据示例 + +#### 服务端 + +```json +{ + "services": [ + { + "type": "ccm", + "listen": "0.0.0.0", + "listen_port": 8080, + "credentials": [ + { + "tag": "a", + "credential_path": "/home/user/.claude-a/.credentials.json", + "usages_path": "/data/usages-a.json", + "reserve_5h": 20, + "reserve_weekly": 20 + }, + { + "tag": "b", + "credential_path": "/home/user/.claude-b/.credentials.json", + "reserve_5h": 10, + "reserve_weekly": 10 + }, + { + "tag": "pool", + "type": "balancer", + "poll_interval": "60s", + "credentials": ["a", "b"] + } + ], + "users": [ + { + "name": "alice", + "token": "ak-ccm-hello-world", + "credential": "pool" + }, + { + "name": "bob", + "token": "ak-ccm-hello-bob", + "credential": "a" + } + ] + } + ] +} +``` diff --git a/docs/configuration/service/ocm.md b/docs/configuration/service/ocm.md index 5fdf2b6b42..8dfd0e99ed 100644 --- a/docs/configuration/service/ocm.md +++ b/docs/configuration/service/ocm.md @@ -10,6 +10,11 @@ OCM (OpenAI Codex Multiplexer) service is a multiplexing service that allows you It handles OAuth authentication with OpenAI's API on your local machine while allowing remote clients to authenticate using custom tokens. +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [credentials](#credentials) + :material-alert: [users](#users) + ### Structure ```json @@ -19,6 +24,7 @@ It handles OAuth authentication with OpenAI's API on your local machine while al ... // Listen Fields "credential_path": "", + "credentials": [], "usages_path": "", "users": [], "headers": {}, @@ -43,6 +49,73 @@ If not specified, defaults to: Refreshed tokens are automatically written back to the same location. +Conflict with `credentials`. + +#### credentials + +!!! question "Since sing-box 1.14.0" + +List of credential configurations for multi-credential mode. + +When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag. + +Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a required `tag` field. + +##### Default Credential + +```json +{ + "tag": "a", + "credential_path": "/path/to/auth.json", + "usages_path": "/path/to/usages.json", + "detour": "", + "reserve_5h": 20, + "reserve_weekly": 20 +} +``` + +A single OAuth credential file. The `type` field can be omitted (defaults to `default`). + +- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`. +- `usages_path`: Optional usage tracking file for this credential. +- `detour`: Outbound tag for connecting to the OpenAI API with this credential. +- `reserve_5h`: Reserve threshold (1-99) for primary rate limit window. Credential pauses at (100-N)% utilization. +- `reserve_weekly`: Reserve threshold (1-99) for secondary (weekly) rate limit window. Credential pauses at (100-N)% utilization. + +##### Balancer Credential + +```json +{ + "tag": "pool", + "type": "balancer", + "strategy": "", + "credentials": ["a", "b"], + "poll_interval": "60s" +} +``` + +Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit. + +- `strategy`: Selection strategy. One of `least_used` `round_robin` `random`. `least_used` will be used by default. +- `credentials`: ==Required== List of default credential tags. +- `poll_interval`: How often to poll upstream usage API. Default `60s`. + +##### Fallback Credential + +```json +{ + "tag": "backup", + "type": "fallback", + "credentials": ["a", "b"], + "poll_interval": "30s" +} +``` + +Uses credentials in order. Falls through to the next when the current one is exhausted. + +- `credentials`: ==Required== Ordered list of default credential tags. +- `poll_interval`: How often to poll upstream usage API. Default `60s`. + #### usages_path Path to the file for storing aggregated API usage statistics. @@ -58,6 +131,8 @@ Statistics are organized by model and optionally by user when authentication is The statistics file is automatically saved every minute and upon service shutdown. +Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials. + #### users List of authorized users for token authentication. @@ -69,7 +144,8 @@ Object format: ```json { "name": "", - "token": "" + "token": "", + "credential": "" } ``` @@ -77,6 +153,7 @@ Object fields: - `name`: Username identifier for tracking purposes. - `token`: Bearer token for authentication. Clients authenticate by setting the `Authorization: Bearer ` header. +- `credential`: Credential tag to use for this user. ==Required== when `credentials` is set. #### headers @@ -88,6 +165,8 @@ These headers will override any existing headers with the same name. Outbound tag for connecting to the OpenAI API. +Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials. + #### tls TLS configuration, see [TLS](/configuration/shared/tls/#inbound). @@ -183,3 +262,52 @@ Then run: ```bash codex --profile ocm ``` + +### Example with Multiple Credentials + +#### Server + +```json +{ + "services": [ + { + "type": "ocm", + "listen": "0.0.0.0", + "listen_port": 8080, + "credentials": [ + { + "tag": "a", + "credential_path": "/home/user/.codex-a/auth.json", + "usages_path": "/data/usages-a.json", + "reserve_5h": 20, + "reserve_weekly": 20 + }, + { + "tag": "b", + "credential_path": "/home/user/.codex-b/auth.json", + "reserve_5h": 10, + "reserve_weekly": 10 + }, + { + "tag": "pool", + "type": "balancer", + "poll_interval": "60s", + "credentials": ["a", "b"] + } + ], + "users": [ + { + "name": "alice", + "token": "sk-ocm-hello-world", + "credential": "pool" + }, + { + "name": "bob", + "token": "sk-ocm-hello-bob", + "credential": "a" + } + ] + } + ] +} +``` diff --git a/docs/configuration/service/ocm.zh.md b/docs/configuration/service/ocm.zh.md index 2e02dc558b..ee4ffa633c 100644 --- a/docs/configuration/service/ocm.zh.md +++ b/docs/configuration/service/ocm.zh.md @@ -10,6 +10,11 @@ OCM(OpenAI Codex 多路复用器)服务是一个多路复用服务,允许 它在本地机器上处理与 OpenAI API 的 OAuth 身份验证,同时允许远程客户端使用自定义令牌进行身份验证。 +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [credentials](#credentials) + :material-alert: [users](#users) + ### 结构 ```json @@ -19,6 +24,7 @@ OCM(OpenAI Codex 多路复用器)服务是一个多路复用服务,允许 ... // 监听字段 "credential_path": "", + "credentials": [], "usages_path": "", "users": [], "headers": {}, @@ -43,6 +49,73 @@ OpenAI OAuth 凭据文件的路径。 刷新的令牌会自动写回相同位置。 +与 `credentials` 冲突。 + +#### credentials + +!!! question "自 sing-box 1.14.0 起" + +多凭据模式的凭据配置列表。 + +设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。 + +每个凭据有一个 `type` 字段(`default`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。 + +##### 默认凭据 + +```json +{ + "tag": "a", + "credential_path": "/path/to/auth.json", + "usages_path": "/path/to/usages.json", + "detour": "", + "reserve_5h": 20, + "reserve_weekly": 20 +} +``` + +单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。 + +- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。 +- `usages_path`:此凭据的可选使用跟踪文件。 +- `detour`:此凭据用于连接 OpenAI API 的出站标签。 +- `reserve_5h`:主要速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 +- `reserve_weekly`:次要(每周)速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 + +##### 均衡凭据 + +```json +{ + "tag": "pool", + "type": "balancer", + "strategy": "", + "credentials": ["a", "b"], + "poll_interval": "60s" +} +``` + +根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。 + +- `strategy`:选择策略。可选值:`least_used` `round_robin` `random`。默认使用 `least_used`。 +- `credentials`:==必填== 默认凭据标签列表。 +- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 + +##### 回退凭据 + +```json +{ + "tag": "backup", + "type": "fallback", + "credentials": ["a", "b"], + "poll_interval": "30s" +} +``` + +按顺序使用凭据。当前凭据耗尽后切换到下一个。 + +- `credentials`:==必填== 有序的默认凭据标签列表。 +- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 + #### usages_path 用于存储聚合 API 使用统计信息的文件路径。 @@ -58,6 +131,8 @@ OpenAI OAuth 凭据文件的路径。 统计文件每分钟自动保存一次,并在服务关闭时保存。 +与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`。 + #### users 用于令牌身份验证的授权用户列表。 @@ -69,7 +144,8 @@ OpenAI OAuth 凭据文件的路径。 ```json { "name": "", - "token": "" + "token": "", + "credential": "" } ``` @@ -77,6 +153,7 @@ OpenAI OAuth 凭据文件的路径。 - `name`:用于跟踪的用户名标识符。 - `token`:用于身份验证的 Bearer 令牌。客户端通过设置 `Authorization: Bearer ` 头进行身份验证。 +- `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。 #### headers @@ -88,6 +165,8 @@ OpenAI OAuth 凭据文件的路径。 用于连接 OpenAI API 的出站标签。 +与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`。 + #### tls TLS 配置,参阅 [TLS](/zh/configuration/shared/tls/#inbound)。 @@ -184,3 +263,52 @@ model_provider = "ocm" ```bash codex --profile ocm ``` + +### 多凭据示例 + +#### 服务端 + +```json +{ + "services": [ + { + "type": "ocm", + "listen": "0.0.0.0", + "listen_port": 8080, + "credentials": [ + { + "tag": "a", + "credential_path": "/home/user/.codex-a/auth.json", + "usages_path": "/data/usages-a.json", + "reserve_5h": 20, + "reserve_weekly": 20 + }, + { + "tag": "b", + "credential_path": "/home/user/.codex-b/auth.json", + "reserve_5h": 10, + "reserve_weekly": 10 + }, + { + "tag": "pool", + "type": "balancer", + "poll_interval": "60s", + "credentials": ["a", "b"] + } + ], + "users": [ + { + "name": "alice", + "token": "sk-ocm-hello-world", + "credential": "pool" + }, + { + "name": "bob", + "token": "sk-ocm-hello-bob", + "credential": "a" + } + ] + } + ] +} +``` diff --git a/go.mod b/go.mod index dec9155f5a..b0a52832e8 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/sagernet/gomobile v0.1.12 github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1 github.com/sagernet/quic-go v0.59.0-sing-box-mod.4 - github.com/sagernet/sing v0.8.2 + github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69 github.com/sagernet/sing-mux v0.3.4 github.com/sagernet/sing-quic v0.6.0 github.com/sagernet/sing-shadowsocks v0.2.8 diff --git a/go.sum b/go.sum index 378117ec91..b3ac58548a 100644 --- a/go.sum +++ b/go.sum @@ -236,8 +236,8 @@ github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNen github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= github.com/sagernet/quic-go v0.59.0-sing-box-mod.4 h1:6qvrUW79S+CrPwWz6cMePXohgjHoKxLo3c+MDhNwc3o= github.com/sagernet/quic-go v0.59.0-sing-box-mod.4/go.mod h1:OqILvS182CyOol5zNNo6bguvOGgXzV459+chpRaUC+4= -github.com/sagernet/sing v0.8.2 h1:kX1IH9SWJv4S0T9M8O+HNahWgbOuY1VauxbF7NU5lOg= -github.com/sagernet/sing v0.8.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69 h1:h6UF2emeydBQMAso99Nr3APV6YustOs+JszVuCkcFy0= +github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-mux v0.3.4 h1:ZQplKl8MNXutjzbMVtWvWG31fohhgOfCuUZR4dVQ8+s= github.com/sagernet/sing-mux v0.3.4/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk= github.com/sagernet/sing-quic v0.6.0 h1:dhrFnP45wgVKEOT1EvtsToxdzRnHIDIAgj6WHV9pLyM= diff --git a/log/format.go b/log/format.go index 6f4347b12a..d2aaa27548 100644 --- a/log/format.go +++ b/log/format.go @@ -168,7 +168,11 @@ func FormatDuration(duration time.Duration) string { return F.ToString(duration.Milliseconds(), "ms") } else if duration < time.Minute { return F.ToString(int64(duration.Seconds()), ".", int64(duration.Seconds()*100)%100, "s") - } else { + } else if duration < time.Hour { return F.ToString(int64(duration.Minutes()), "m", int64(duration.Seconds())%60, "s") + } else if duration < 24*time.Hour { + return F.ToString(int64(duration.Hours()), "h", int64(duration.Minutes())%60, "m") + } else { + return F.ToString(int64(duration.Hours())/24, "d", int64(duration.Hours())%24, "h") } } diff --git a/option/ccm.go b/option/ccm.go index c916aaf221..edfe2e417b 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -1,6 +1,9 @@ package option import ( + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" "github.com/sagernet/sing/common/json/badoption" ) @@ -8,6 +11,7 @@ type CCMServiceOptions struct { ListenOptions InboundTLSOptionsContainer CredentialPath string `json:"credential_path,omitempty"` + Credentials []CCMCredential `json:"credentials,omitempty"` Users []CCMUser `json:"users,omitempty"` Headers badoption.HTTPHeader `json:"headers,omitempty"` Detour string `json:"detour,omitempty"` @@ -15,6 +19,75 @@ type CCMServiceOptions struct { } type CCMUser struct { - Name string `json:"name,omitempty"` - Token string `json:"token,omitempty"` + Name string `json:"name,omitempty"` + Token string `json:"token,omitempty"` + Credential string `json:"credential,omitempty"` +} + +type _CCMCredential struct { + Type string `json:"type,omitempty"` + Tag string `json:"tag"` + DefaultOptions CCMDefaultCredentialOptions `json:"-"` + BalancerOptions CCMBalancerCredentialOptions `json:"-"` + FallbackOptions CCMFallbackCredentialOptions `json:"-"` +} + +type CCMCredential _CCMCredential + +func (c CCMCredential) MarshalJSON() ([]byte, error) { + var v any + switch c.Type { + case "", "default": + c.Type = "" + v = c.DefaultOptions + case "balancer": + v = c.BalancerOptions + case "fallback": + v = c.FallbackOptions + default: + return nil, E.New("unknown credential type: ", c.Type) + } + return badjson.MarshallObjects((_CCMCredential)(c), v) +} + +func (c *CCMCredential) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_CCMCredential)(c)) + if err != nil { + return err + } + if c.Tag == "" { + return E.New("missing credential tag") + } + var v any + switch c.Type { + case "", "default": + c.Type = "default" + v = &c.DefaultOptions + case "balancer": + v = &c.BalancerOptions + case "fallback": + v = &c.FallbackOptions + default: + return E.New("unknown credential type: ", c.Type) + } + return badjson.UnmarshallExcluded(bytes, (*_CCMCredential)(c), v) +} + +type CCMDefaultCredentialOptions struct { + CredentialPath string `json:"credential_path,omitempty"` + UsagesPath string `json:"usages_path,omitempty"` + Detour string `json:"detour,omitempty"` + Reserve5h uint8 `json:"reserve_5h"` + ReserveWeekly uint8 `json:"reserve_weekly"` +} + +type CCMBalancerCredentialOptions struct { + Strategy string `json:"strategy,omitempty"` + Credentials badoption.Listable[string] `json:"credentials"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` +} + +type CCMFallbackCredentialOptions struct { + Credentials badoption.Listable[string] `json:"credentials"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/option/ocm.go b/option/ocm.go index c13a1c1f53..832b455288 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -1,6 +1,9 @@ package option import ( + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" "github.com/sagernet/sing/common/json/badoption" ) @@ -8,6 +11,7 @@ type OCMServiceOptions struct { ListenOptions InboundTLSOptionsContainer CredentialPath string `json:"credential_path,omitempty"` + Credentials []OCMCredential `json:"credentials,omitempty"` Users []OCMUser `json:"users,omitempty"` Headers badoption.HTTPHeader `json:"headers,omitempty"` Detour string `json:"detour,omitempty"` @@ -15,6 +19,75 @@ type OCMServiceOptions struct { } type OCMUser struct { - Name string `json:"name,omitempty"` - Token string `json:"token,omitempty"` + Name string `json:"name,omitempty"` + Token string `json:"token,omitempty"` + Credential string `json:"credential,omitempty"` +} + +type _OCMCredential struct { + Type string `json:"type,omitempty"` + Tag string `json:"tag"` + DefaultOptions OCMDefaultCredentialOptions `json:"-"` + BalancerOptions OCMBalancerCredentialOptions `json:"-"` + FallbackOptions OCMFallbackCredentialOptions `json:"-"` +} + +type OCMCredential _OCMCredential + +func (c OCMCredential) MarshalJSON() ([]byte, error) { + var v any + switch c.Type { + case "", "default": + c.Type = "" + v = c.DefaultOptions + case "balancer": + v = c.BalancerOptions + case "fallback": + v = c.FallbackOptions + default: + return nil, E.New("unknown credential type: ", c.Type) + } + return badjson.MarshallObjects((_OCMCredential)(c), v) +} + +func (c *OCMCredential) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_OCMCredential)(c)) + if err != nil { + return err + } + if c.Tag == "" { + return E.New("missing credential tag") + } + var v any + switch c.Type { + case "", "default": + c.Type = "default" + v = &c.DefaultOptions + case "balancer": + v = &c.BalancerOptions + case "fallback": + v = &c.FallbackOptions + default: + return E.New("unknown credential type: ", c.Type) + } + return badjson.UnmarshallExcluded(bytes, (*_OCMCredential)(c), v) +} + +type OCMDefaultCredentialOptions struct { + CredentialPath string `json:"credential_path,omitempty"` + UsagesPath string `json:"usages_path,omitempty"` + Detour string `json:"detour,omitempty"` + Reserve5h uint8 `json:"reserve_5h"` + ReserveWeekly uint8 `json:"reserve_weekly"` +} + +type OCMBalancerCredentialOptions struct { + Strategy string `json:"strategy,omitempty"` + Credentials badoption.Listable[string] `json:"credentials"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` +} + +type OCMFallbackCredentialOptions struct { + Credentials badoption.Listable[string] `json:"credentials"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 695efc7ae3..0fe5e2b970 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -8,8 +8,11 @@ import ( "os" "os/user" "path/filepath" + "runtime" + "sync" "time" + "github.com/sagernet/sing-box/log" E "github.com/sagernet/sing/common/exceptions" ) @@ -21,6 +24,50 @@ const ( anthropicBetaOAuthValue = "oauth-2025-04-20" ) +const ccmUserAgentFallback = "claude-code/2.1.72" + +var ( + ccmUserAgentOnce sync.Once + ccmUserAgentValue string +) + +func initCCMUserAgent(logger log.ContextLogger) { + ccmUserAgentOnce.Do(func() { + version, err := detectClaudeCodeVersion() + if err != nil { + logger.Error("detect Claude Code version: ", err) + ccmUserAgentValue = ccmUserAgentFallback + return + } + logger.Debug("detected Claude Code version: ", version) + ccmUserAgentValue = "claude-code/" + version + }) +} + +func detectClaudeCodeVersion() (string, error) { + userInfo, err := getRealUser() + if err != nil { + return "", E.Cause(err, "get user") + } + binaryName := "claude" + if runtime.GOOS == "windows" { + binaryName = "claude.exe" + } + linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName) + target, err := os.Readlink(linkPath) + if err != nil { + return "", E.Cause(err, "readlink ", linkPath) + } + if !filepath.IsAbs(target) { + target = filepath.Join(filepath.Dir(linkPath), target) + } + parent := filepath.Base(filepath.Dir(target)) + if parent != "versions" { + return "", E.New("unexpected symlink target: ", target) + } + return filepath.Base(target), nil +} + func getRealUser() (*user.User, error) { if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { sudoUserInfo, err := user.Lookup(sudoUser) @@ -106,6 +153,7 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut } request.Header.Set("Content-Type", "application/json") request.Header.Set("Accept", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) response, err := httpClient.Do(request) if err != nil { @@ -113,6 +161,10 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut } defer response.Body.Close() + if response.StatusCode == http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + } if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) return nil, E.New("refresh failed: ", response.Status, " ", string(body)) diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go new file mode 100644 index 0000000000..1681b08559 --- /dev/null +++ b/service/ccm/credential_state.go @@ -0,0 +1,997 @@ +package ccm + +import ( + "bytes" + "context" + stdTLS "crypto/tls" + "encoding/json" + "errors" + "io" + "math" + "math/rand/v2" + "net" + "net/http" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/ntp" +) + +const defaultPollInterval = 60 * time.Minute + +type credentialState struct { + fiveHourUtilization float64 + fiveHourReset time.Time + weeklyUtilization float64 + weeklyReset time.Time + hardRateLimited bool + rateLimitResetAt time.Time + accountType string + lastUpdated time.Time + consecutivePollFailures int +} + +type defaultCredential struct { + tag string + credentialPath string + credentials *oauthCredentials + accessMutex sync.RWMutex + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + reserve5h uint8 + reserveWeekly uint8 + usageTracker *AggregatedUsage + httpClient *http.Client + logger log.ContextLogger + + // Connection interruption + onBecameUnusable func() + interrupted bool + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex +} + +type credentialRequestContext struct { + context.Context + releaseOnce sync.Once + cancelOnce sync.Once + releaseFunc func() bool + cancelFunc context.CancelFunc +} + +func (c *credentialRequestContext) releaseCredentialInterrupt() { + c.releaseOnce.Do(func() { + c.releaseFunc() + }) +} + +func (c *credentialRequestContext) cancelRequest() { + c.releaseCredentialInterrupt() + c.cancelOnce.Do(c.cancelFunc) +} + +func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + reserve5h := options.Reserve5h + if reserve5h == 0 { + reserve5h = 1 + } + reserveWeekly := options.ReserveWeekly + if reserveWeekly == 0 { + reserveWeekly = 10 + } + requestContext, cancelRequests := context.WithCancel(context.Background()) + credential := &defaultCredential{ + tag: tag, + credentialPath: options.CredentialPath, + reserve5h: reserve5h, + reserveWeekly: reserveWeekly, + httpClient: httpClient, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + } + if options.UsagesPath != "" { + credential.usageTracker = &AggregatedUsage{ + LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + return credential, nil +} + +func (c *defaultCredential) start() error { + credentials, err := platformReadCredentials(c.credentialPath) + if err != nil { + return E.Cause(err, "read credentials for ", c.tag) + } + c.credentials = credentials + if credentials.SubscriptionType != "" { + c.state.accountType = credentials.SubscriptionType + } + if c.usageTracker != nil { + err = c.usageTracker.Load() + if err != nil { + c.logger.Warn("load usage statistics for ", c.tag, ": ", err) + } + } + return nil +} + +func (c *defaultCredential) getAccessToken() (string, error) { + c.accessMutex.RLock() + if !c.credentials.needsRefresh() { + token := c.credentials.AccessToken + c.accessMutex.RUnlock() + return token, nil + } + c.accessMutex.RUnlock() + + c.accessMutex.Lock() + defer c.accessMutex.Unlock() + + if !c.credentials.needsRefresh() { + return c.credentials.AccessToken, nil + } + + newCredentials, err := refreshToken(c.httpClient, c.credentials) + if err != nil { + return "", err + } + + c.credentials = newCredentials + if newCredentials.SubscriptionType != "" { + c.stateMutex.Lock() + c.state.accountType = newCredentials.SubscriptionType + c.stateMutex.Unlock() + } + + err = platformWriteCredentials(newCredentials, c.credentialPath) + if err != nil { + c.logger.Warn("persist refreshed token for ", c.tag, ": ", err) + } + + return newCredentials.AccessToken, nil +} + +func parseResetTimestamp(value string) (time.Time, error) { + if value == "" { + return time.Time{}, nil + } + unixEpoch, err := strconv.ParseInt(value, 10, 64) + if err == nil { + return time.Unix(unixEpoch, 0), nil + } + return time.Parse(time.RFC3339Nano, value) +} + +func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { + c.stateMutex.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + + if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { + value, err := strconv.ParseFloat(utilization, 64) + if err == nil { + newValue := math.Ceil(value * 100) + if newValue < c.state.fiveHourUtilization { + c.logger.Error("header 5h utilization for ", c.tag, " is lower than current: ", newValue, " < ", c.state.fiveHourUtilization) + } + c.state.fiveHourUtilization = newValue + } + } + if resetAt := headers.Get("anthropic-ratelimit-unified-5h-reset"); resetAt != "" { + value, err := parseResetTimestamp(resetAt) + if err == nil { + c.state.fiveHourReset = value + } + } + if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { + value, err := strconv.ParseFloat(utilization, 64) + if err == nil { + newValue := math.Ceil(value * 100) + if newValue < c.state.weeklyUtilization { + c.logger.Error("header weekly utilization for ", c.tag, " is lower than current: ", newValue, " < ", c.state.weeklyUtilization) + } + c.state.weeklyUtilization = newValue + } + } + if resetAt := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetAt != "" { + value, err := parseResetTimestamp(resetAt) + if err == nil { + c.state.weeklyReset = value + } + } + c.state.lastUpdated = time.Now() + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + } + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) markRateLimited(resetAt time.Time) { + c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) + c.stateMutex.Lock() + c.state.hardRateLimited = true + c.state.rateLimitResetAt = resetAt + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) isUsable() bool { + c.stateMutex.RLock() + if c.state.hardRateLimited { + if time.Now().Before(c.state.rateLimitResetAt) { + c.stateMutex.RUnlock() + return false + } + c.stateMutex.RUnlock() + c.stateMutex.Lock() + if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + usable := c.checkReservesLocked() + c.stateMutex.Unlock() + return usable + } + usable := c.checkReservesLocked() + c.stateMutex.RUnlock() + return usable +} + +func (c *defaultCredential) checkReservesLocked() bool { + if c.state.fiveHourUtilization >= float64(100-c.reserve5h) { + return false + } + if c.state.weeklyUtilization >= float64(100-c.reserveWeekly) { + return false + } + return true +} + +// checkTransitionLocked detects usable→unusable transition. +// Must be called with stateMutex write lock held. +func (c *defaultCredential) checkTransitionLocked() bool { + unusable := c.state.hardRateLimited || !c.checkReservesLocked() + if unusable && !c.interrupted { + c.interrupted = true + return true + } + if !unusable && c.interrupted { + c.interrupted = false + } + return false +} + +func (c *defaultCredential) interruptConnections() { + c.logger.Warn("interrupting connections for ", c.tag) + c.requestAccess.Lock() + c.cancelRequests() + c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) + c.requestAccess.Unlock() + if c.onBecameUnusable != nil { + c.onBecameUnusable() + } +} + +func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { + c.requestAccess.Lock() + credentialContext := c.requestContext + c.requestAccess.Unlock() + derived, cancel := context.WithCancel(parent) + stop := context.AfterFunc(credentialContext, func() { + cancel() + }) + return &credentialRequestContext{ + Context: derived, + releaseFunc: stop, + cancelFunc: cancel, + } +} + +func (c *defaultCredential) weeklyUtilization() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyUtilization +} + +func (c *defaultCredential) lastUpdatedTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.lastUpdated +} + +func (c *defaultCredential) markUsagePollAttempted() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.state.lastUpdated = time.Now() +} + +func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { + c.stateMutex.RLock() + failures := c.state.consecutivePollFailures + c.stateMutex.RUnlock() + if failures <= 0 { + return baseInterval + } + if failures > 4 { + failures = 4 + } + return baseInterval * time.Duration(1< p.credential.pollBackoff(defaultPollInterval) { + p.credential.pollUsage(ctx) + } +} + +func (p *singleCredentialProvider) allDefaults() []*defaultCredential { + return []*defaultCredential{p.credential} +} + +func (p *singleCredentialProvider) close() {} + +const sessionExpiry = 24 * time.Hour + +type sessionEntry struct { + tag string + createdAt time.Time +} + +// balancerProvider assigns sessions to credentials based on a configurable strategy. +type balancerProvider struct { + credentials []*defaultCredential + strategy string + roundRobinIndex atomic.Uint64 + pollInterval time.Duration + sessionMutex sync.RWMutex + sessions map[string]sessionEntry + logger log.ContextLogger +} + +func newBalancerProvider(credentials []*defaultCredential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &balancerProvider{ + credentials: credentials, + strategy: strategy, + pollInterval: pollInterval, + sessions: make(map[string]sessionEntry), + logger: logger, + } +} + +func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredential, bool, error) { + if sessionID != "" { + p.sessionMutex.RLock() + entry, exists := p.sessions[sessionID] + p.sessionMutex.RUnlock() + if exists { + for _, credential := range p.credentials { + if credential.tag == entry.tag && credential.isUsable() { + return credential, false, nil + } + } + p.sessionMutex.Lock() + delete(p.sessions, sessionID) + p.sessionMutex.Unlock() + } + } + + best := p.pickCredential() + if best == nil { + return nil, false, allCredentialsUnavailableError(p.credentials) + } + + isNew := sessionID != "" + if isNew { + p.sessionMutex.Lock() + p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessionMutex.Unlock() + } + return best, isNew, nil +} + +func (p *balancerProvider) onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential { + credential.markRateLimited(resetAt) + if sessionID != "" { + p.sessionMutex.Lock() + delete(p.sessions, sessionID) + p.sessionMutex.Unlock() + } + + best := p.pickCredential() + if best != nil && sessionID != "" { + p.sessionMutex.Lock() + p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessionMutex.Unlock() + } + return best +} + +func (p *balancerProvider) pickCredential() *defaultCredential { + switch p.strategy { + case "round_robin": + return p.pickRoundRobin() + case "random": + return p.pickRandom() + default: + return p.pickLeastUsed() + } +} + +func (p *balancerProvider) pickLeastUsed() *defaultCredential { + var best *defaultCredential + bestUtilization := float64(101) + for _, credential := range p.credentials { + if !credential.isUsable() { + continue + } + utilization := credential.weeklyUtilization() + if utilization < bestUtilization { + bestUtilization = utilization + best = credential + } + } + return best +} + +func (p *balancerProvider) pickRoundRobin() *defaultCredential { + start := int(p.roundRobinIndex.Add(1) - 1) + count := len(p.credentials) + for offset := range count { + candidate := p.credentials[(start+offset)%count] + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *balancerProvider) pickRandom() *defaultCredential { + var usable []*defaultCredential + for _, candidate := range p.credentials { + if candidate.isUsable() { + usable = append(usable, candidate) + } + } + if len(usable) == 0 { + return nil + } + return usable[rand.IntN(len(usable))] +} + +func (p *balancerProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionMutex.Lock() + for id, entry := range p.sessions { + if now.Sub(entry.createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionMutex.Unlock() + + for _, credential := range p.credentials { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage(ctx) + } + } +} + +func (p *balancerProvider) allDefaults() []*defaultCredential { + return p.credentials +} + +func (p *balancerProvider) close() {} + +// fallbackProvider tries credentials in order. +type fallbackProvider struct { + credentials []*defaultCredential + pollInterval time.Duration + logger log.ContextLogger +} + +func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &fallbackProvider{ + credentials: credentials, + pollInterval: pollInterval, + logger: logger, + } +} + +func (p *fallbackProvider) selectCredential(_ string) (*defaultCredential, bool, error) { + for _, credential := range p.credentials { + if credential.isUsable() { + return credential, false, nil + } + } + return nil, false, allCredentialsUnavailableError(p.credentials) +} + +func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential { + credential.markRateLimited(resetAt) + for _, candidate := range p.credentials { + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *fallbackProvider) pollIfStale(ctx context.Context) { + for _, credential := range p.credentials { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage(ctx) + } + } +} + +func (p *fallbackProvider) allDefaults() []*defaultCredential { + return p.credentials +} + +func (p *fallbackProvider) close() {} + +func allCredentialsUnavailableError(credentials []*defaultCredential) error { + var earliest time.Time + for _, credential := range credentials { + resetAt := credential.earliestReset() + if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { + earliest = resetAt + } + } + if earliest.IsZero() { + return E.New("all credentials rate-limited") + } + return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) +} + +func extractCCMSessionID(bodyBytes []byte) string { + var body struct { + Metadata struct { + UserID string `json:"user_id"` + } `json:"metadata"` + } + err := json.Unmarshal(bodyBytes, &body) + if err != nil { + return "" + } + userID := body.Metadata.UserID + sessionIndex := strings.LastIndex(userID, "_session_") + if sessionIndex < 0 { + return "" + } + return userID[sessionIndex+len("_session_"):] +} + +func buildCredentialProviders( + ctx context.Context, + options option.CCMServiceOptions, + logger log.ContextLogger, +) (map[string]credentialProvider, []*defaultCredential, error) { + defaultCredentials := make(map[string]*defaultCredential) + var allDefaults []*defaultCredential + providers := make(map[string]credentialProvider) + + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "default": + credential, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + if err != nil { + return nil, nil, err + } + defaultCredentials[credOpt.Tag] = credential + allDefaults = append(allDefaults, credential) + providers[credOpt.Tag] = &singleCredentialProvider{credential: credential} + } + } + + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "balancer": + subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, defaultCredentials, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger) + case "fallback": + subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, defaultCredentials, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newFallbackProvider(subCredentials, time.Duration(credOpt.FallbackOptions.PollInterval), logger) + } + } + + return providers, allDefaults, nil +} + +func resolveCredentialTags(tags []string, defaults map[string]*defaultCredential, parentTag string) ([]*defaultCredential, error) { + credentials := make([]*defaultCredential, 0, len(tags)) + for _, tag := range tags { + credential, exists := defaults[tag] + if !exists { + return nil, E.New("credential ", parentTag, " references unknown default credential: ", tag) + } + credentials = append(credentials, credential) + } + if len(credentials) == 0 { + return nil, E.New("credential ", parentTag, " has no sub-credentials") + } + return credentials, nil +} + +func parseRateLimitResetFromHeaders(headers http.Header) time.Time { + claim := headers.Get("anthropic-ratelimit-unified-representative-claim") + switch claim { + case "5h": + if resetStr := headers.Get("anthropic-ratelimit-unified-5h-reset"); resetStr != "" { + value, err := strconv.ParseInt(resetStr, 10, 64) + if err == nil { + return time.Unix(value, 0) + } + } + case "7d": + if resetStr := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetStr != "" { + value, err := strconv.ParseInt(resetStr, 10, 64) + if err == nil { + return time.Unix(value, 0) + } + } + } + if retryAfter := headers.Get("Retry-After"); retryAfter != "" { + seconds, err := strconv.ParseInt(retryAfter, 10, 64) + if err == nil { + return time.Now().Add(time.Duration(seconds) * time.Second) + } + } + return time.Now().Add(5 * time.Minute) +} + +func validateCCMOptions(options option.CCMServiceOptions) error { + hasCredentials := len(options.Credentials) > 0 + hasLegacyPath := options.CredentialPath != "" + hasLegacyUsages := options.UsagesPath != "" + hasLegacyDetour := options.Detour != "" + + if hasCredentials && hasLegacyPath { + return E.New("credential_path and credentials are mutually exclusive") + } + if hasCredentials && hasLegacyUsages { + return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") + } + if hasCredentials && hasLegacyDetour { + return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") + } + + if hasCredentials { + tags := make(map[string]bool) + for _, credential := range options.Credentials { + if tags[credential.Tag] { + return E.New("duplicate credential tag: ", credential.Tag) + } + tags[credential.Tag] = true + if credential.Type == "default" || credential.Type == "" { + if credential.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") + } + if credential.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") + } + } + if credential.Type == "balancer" { + switch credential.BalancerOptions.Strategy { + case "", "least_used", "round_robin", "random": + default: + return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) + } + } + } + + for _, user := range options.Users { + if user.Credential == "" { + return E.New("user ", user.Name, " must specify credential in multi-credential mode") + } + if !tags[user.Credential] { + return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + } + } + } + + return nil +} + +// retryRequestWithBody re-sends a buffered request body using a different credential. +func retryRequestWithBody( + ctx context.Context, + originalRequest *http.Request, + bodyBytes []byte, + credential *defaultCredential, + httpHeaders http.Header, +) (*http.Response, error) { + accessToken, err := credential.getAccessToken() + if err != nil { + return nil, E.Cause(err, "get access token for ", credential.tag) + } + + proxyURL := claudeAPIBaseURL + originalRequest.URL.RequestURI() + retryRequest, err := http.NewRequestWithContext(ctx, originalRequest.Method, proxyURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + + for key, values := range originalRequest.Header { + if !isHopByHopHeader(key) && key != "Authorization" { + retryRequest.Header[key] = values + } + } + + serviceOverridesAcceptEncoding := len(httpHeaders.Values("Accept-Encoding")) > 0 + if credential.usageTracker != nil && !serviceOverridesAcceptEncoding { + retryRequest.Header.Del("Accept-Encoding") + } + + anthropicBetaHeader := retryRequest.Header.Get("anthropic-beta") + if anthropicBetaHeader != "" { + retryRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) + } else { + retryRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue) + } + + for key, values := range httpHeaders { + retryRequest.Header.Del(key) + retryRequest.Header[key] = values + } + retryRequest.Header.Set("Authorization", "Bearer "+accessToken) + + return credential.httpClient.Do(retryRequest) +} + +// credentialForUser finds the credential provider for a user. +// In legacy mode, returns the single provider. +// In multi-credential mode, returns the provider mapped to the user's credential tag. +func credentialForUser( + userCredentialMap map[string]string, + providers map[string]credentialProvider, + legacyProvider credentialProvider, + username string, +) (credentialProvider, error) { + if legacyProvider != nil { + return legacyProvider, nil + } + tag, exists := userCredentialMap[username] + if !exists { + return nil, E.New("no credential mapping for user: ", username) + } + provider, exists := providers[tag] + if !exists { + return nil, E.New("unknown credential: ", tag) + } + return provider, nil +} + +// noUserCredentialProvider returns the single provider for legacy mode or the first credential in multi-credential mode (no auth). +func noUserCredentialProvider( + providers map[string]credentialProvider, + legacyProvider credentialProvider, + options option.CCMServiceOptions, +) credentialProvider { + if legacyProvider != nil { + return legacyProvider + } + if len(options.Credentials) > 0 { + tag := options.Credentials[0].Tag + return providers[tag] + } + return nil +} diff --git a/service/ccm/service.go b/service/ccm/service.go index 34c38824cd..ea81b1b762 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -3,12 +3,10 @@ package ccm import ( "bytes" "context" - stdTLS "crypto/tls" "encoding/json" "errors" "io" "mime" - "net" "net/http" "strconv" "strings" @@ -17,7 +15,6 @@ import ( "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" - "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/listener" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" @@ -26,9 +23,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/ntp" aTLS "github.com/sagernet/sing/common/tls" "github.com/anthropics/anthropic-sdk-go" @@ -40,6 +35,7 @@ const ( contextWindowStandard = 200000 contextWindowPremium = 1000000 premiumContextThreshold = 200000 + retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential" ) func RegisterService(registry *boxService.Registry) { @@ -60,7 +56,6 @@ type errorDetails struct { func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) - json.NewEncoder(w).Encode(errorResponse{ Type: "error", Error: errorDetails{ @@ -71,6 +66,50 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro }) } +func hasAlternativeCredential(provider credentialProvider, currentCredential *defaultCredential) bool { + if provider == nil || currentCredential == nil { + return false + } + for _, credential := range provider.allDefaults() { + if credential == currentCredential { + continue + } + if credential.isUsable() { + return true + } + } + return false +} + +func unavailableCredentialMessage(provider credentialProvider, fallback string) string { + if provider == nil { + return fallback + } + return allCredentialsUnavailableError(provider.allDefaults()).Error() +} + +func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) { + writeJSONError(w, r, http.StatusTooManyRequests, "rate_limit_error", retryableUsageMessage) +} + +func writeNonRetryableCredentialError(w http.ResponseWriter, r *http.Request, message string) { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", message) +} + +func writeCredentialUnavailableError( + w http.ResponseWriter, + r *http.Request, + provider credentialProvider, + currentCredential *defaultCredential, + fallback string, +) { + if hasAlternativeCredential(provider, currentCredential) { + writeRetryableUsageError(w, r) + return + } + writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback)) +} + func isHopByHopHeader(header string) bool { switch strings.ToLower(header) { case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host": @@ -111,78 +150,79 @@ func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { type Service struct { boxService.Adapter - ctx context.Context - logger log.ContextLogger - credentialPath string - credentials *oauthCredentials - users []option.CCMUser - httpClient *http.Client - httpHeaders http.Header - listener *listener.Listener - tlsConfig tls.ServerConfig - httpServer *http.Server - userManager *UserManager - accessMutex sync.RWMutex - usageTracker *AggregatedUsage - trackingGroup sync.WaitGroup - shuttingDown bool + ctx context.Context + logger log.ContextLogger + options option.CCMServiceOptions + httpHeaders http.Header + listener *listener.Listener + tlsConfig tls.ServerConfig + httpServer *http.Server + userManager *UserManager + trackingGroup sync.WaitGroup + shuttingDown bool + + // Legacy mode (single credential) + legacyCredential *defaultCredential + legacyProvider credentialProvider + + // Multi-credential mode + providers map[string]credentialProvider + allDefaults []*defaultCredential + userCredentialMap map[string]string } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) { - serviceDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) + initCCMUserAgent(logger) + + err := validateCCMOptions(options) if err != nil { - return nil, E.Cause(err, "create dialer") - } - - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSClientConfig: &stdTLS.Config{ - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, + return nil, E.Cause(err, "validate options") } userManager := &UserManager{ tokenMap: make(map[string]string), } - var usageTracker *AggregatedUsage - if options.UsagesPath != "" { - usageTracker = &AggregatedUsage{ - LastUpdated: time.Now(), - Combinations: make([]CostCombination, 0), - filePath: options.UsagesPath, - logger: logger, - } - } - service := &Service{ - Adapter: boxService.NewAdapter(C.TypeCCM, tag), - ctx: ctx, - logger: logger, - credentialPath: options.CredentialPath, - users: options.Users, - httpClient: httpClient, - httpHeaders: options.Headers.Build(), + Adapter: boxService.NewAdapter(C.TypeCCM, tag), + ctx: ctx, + logger: logger, + options: options, + httpHeaders: options.Headers.Build(), listener: listener.New(listener.Options{ Context: ctx, Logger: logger, Network: []string{N.NetworkTCP}, Listen: options.ListenOptions, }), - userManager: userManager, - usageTracker: usageTracker, + userManager: userManager, + } + + if len(options.Credentials) > 0 { + providers, allDefaults, err := buildCredentialProviders(ctx, options, logger) + if err != nil { + return nil, E.Cause(err, "build credential providers") + } + service.providers = providers + service.allDefaults = allDefaults + + userCredentialMap := make(map[string]string) + for _, user := range options.Users { + userCredentialMap[user.Name] = user.Credential + } + service.userCredentialMap = userCredentialMap + } else { + credential, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{ + CredentialPath: options.CredentialPath, + UsagesPath: options.UsagesPath, + Detour: options.Detour, + }, logger) + if err != nil { + return nil, err + } + service.legacyCredential = credential + service.legacyProvider = &singleCredentialProvider{credential: credential} + service.allDefaults = []*defaultCredential{credential} } if options.TLS != nil { @@ -201,18 +241,12 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } - s.userManager.UpdateUsers(s.users) - - credentials, err := platformReadCredentials(s.credentialPath) - if err != nil { - return E.Cause(err, "read credentials") - } - s.credentials = credentials + s.userManager.UpdateUsers(s.options.Users) - if s.usageTracker != nil { - err = s.usageTracker.Load() + for _, credential := range s.allDefaults { + err := credential.start() if err != nil { - s.logger.Warn("load usage statistics: ", err) + return err } } @@ -222,7 +256,7 @@ func (s *Service) Start(stage adapter.StartStage) error { s.httpServer = &http.Server{Handler: router} if s.tlsConfig != nil { - err = s.tlsConfig.Start() + err := s.tlsConfig.Start() if err != nil { return E.Cause(err, "create TLS config") } @@ -250,44 +284,19 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } -func (s *Service) getAccessToken() (string, error) { - s.accessMutex.RLock() - if !s.credentials.needsRefresh() { - token := s.credentials.AccessToken - s.accessMutex.RUnlock() - return token, nil - } - s.accessMutex.RUnlock() - - s.accessMutex.Lock() - defer s.accessMutex.Unlock() - - if !s.credentials.needsRefresh() { - return s.credentials.AccessToken, nil - } - - newCredentials, err := refreshToken(s.httpClient, s.credentials) - if err != nil { - return "", err - } - - s.credentials = newCredentials - - err = platformWriteCredentials(newCredentials, s.credentialPath) - if err != nil { - s.logger.Warn("persist refreshed token: ", err) +func isExtendedContextRequest(betaHeader string) bool { + for _, feature := range strings.Split(betaHeader, ",") { + if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { + return true + } } - - return newCredentials.AccessToken, nil + return false } func detectContextWindow(betaHeader string, totalInputTokens int64) int { if totalInputTokens > premiumContextThreshold { - features := strings.Split(betaHeader, ",") - for _, feature := range features { - if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { - return contextWindowPremium - } + if isExtendedContextRequest(betaHeader) { + return contextWindowPremium } } return contextWindowStandard @@ -300,7 +309,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } var username string - if len(s.users) > 0 { + if len(s.options.Users) > 0 { authHeader := r.Header.Get("Authorization") if authHeader == "" { s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") @@ -322,26 +331,78 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + // Always read body to extract model and session ID + var bodyBytes []byte var requestModel string var messagesCount int + var sessionID string - if s.usageTracker != nil && r.Body != nil { - bodyBytes, err := io.ReadAll(r.Body) + if r.Body != nil { + var err error + bodyBytes, err = io.ReadAll(r.Body) + if err != nil { + s.logger.Error("read request body: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") + return + } + + var request struct { + Model string `json:"model"` + Messages []anthropic.MessageParam `json:"messages"` + } + err = json.Unmarshal(bodyBytes, &request) if err == nil { - var request struct { - Model string `json:"model"` - Messages []anthropic.MessageParam `json:"messages"` - } - err := json.Unmarshal(bodyBytes, &request) - if err == nil { - requestModel = request.Model - messagesCount = len(request.Messages) - } - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + requestModel = request.Model + messagesCount = len(request.Messages) + } + + sessionID = extractCCMSessionID(bodyBytes) + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + // Resolve credential provider + var provider credentialProvider + if len(s.options.Users) > 0 { + var err error + provider, err = credentialForUser(s.userCredentialMap, s.providers, s.legacyProvider, username) + if err != nil { + s.logger.Error("resolve credential: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + } else { + provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + } + if provider == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") + return + } + + provider.pollIfStale(s.ctx) + + anthropicBetaHeader := r.Header.Get("anthropic-beta") + if isExtendedContextRequest(anthropicBetaHeader) { + if _, isSingle := provider.(*singleCredentialProvider); !isSingle { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "extended context (1m) requests will consume Extra usage, please use a default credential directly") + return + } + } + + credential, isNew, err := provider.selectCredential(sessionID) + if err != nil { + writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) + return + } + if isNew { + if username != "" { + s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID, " by user ", username) + } else { + s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID) } } - accessToken, err := s.getAccessToken() + accessToken, err := credential.getAccessToken() if err != nil { s.logger.Error("get access token: ", err) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed") @@ -349,7 +410,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } proxyURL := claudeAPIBaseURL + r.URL.RequestURI() - proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body) + requestContext := credential.wrapRequestContext(r.Context()) + defer func() { + requestContext.cancelRequest() + }() + proxyRequest, err := http.NewRequestWithContext(requestContext, r.Method, proxyURL, r.Body) if err != nil { s.logger.Error("create proxy request: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") @@ -362,14 +427,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + hasUsageTracker := credential.usageTracker != nil serviceOverridesAcceptEncoding := len(s.httpHeaders.Values("Accept-Encoding")) > 0 - if s.usageTracker != nil && !serviceOverridesAcceptEncoding { - // Strip Accept-Encoding so Go Transport adds it automatically - // and transparently decompresses the response for correct usage counting. + if hasUsageTracker && !serviceOverridesAcceptEncoding { proxyRequest.Header.Del("Accept-Encoding") } - anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta") if anthropicBetaHeader != "" { proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) } else { @@ -383,13 +446,65 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) - response, err := s.httpClient.Do(proxyRequest) + response, err := credential.httpClient.Do(proxyRequest) if err != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, credential, "credential became unavailable while processing the request") + return + } writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) return } + requestContext.releaseCredentialInterrupt() + + // Transparent 429 retry + for response.StatusCode == http.StatusTooManyRequests { + resetAt := parseRateLimitResetFromHeaders(response.Header) + nextCredential := provider.onRateLimited(sessionID, credential, resetAt) + credential.updateStateFromHeaders(response.Header) + if bodyBytes == nil || nextCredential == nil { + response.Body.Close() + writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited") + return + } + response.Body.Close() + s.logger.Info("retrying with credential ", nextCredential.tag, " after 429 from ", credential.tag) + requestContext.cancelRequest() + requestContext = nextCredential.wrapRequestContext(r.Context()) + retryResponse, retryErr := retryRequestWithBody(requestContext, r, bodyBytes, nextCredential, s.httpHeaders) + if retryErr != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, nextCredential, "credential became unavailable while retrying the request") + return + } + s.logger.Error("retry request: ", retryErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) + return + } + requestContext.releaseCredentialInterrupt() + response = retryResponse + credential = nextCredential + } defer response.Body.Close() + credential.updateStateFromHeaders(response.Header) + + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + s.logger.Error("upstream error from ", credential.tag, ": status ", response.StatusCode, " ", string(body)) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", + "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) + return + } + + hasUsageTracker = credential.usageTracker != nil + for key, values := range response.Header { if !isHopByHopHeader(key) { w.Header()[key] = values @@ -397,8 +512,8 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(response.StatusCode) - if s.usageTracker != nil && response.StatusCode == http.StatusOK { - s.handleResponseWithTracking(w, response, requestModel, anthropicBetaHeader, messagesCount, username) + if hasUsageTracker && response.StatusCode == http.StatusOK { + s.handleResponseWithTracking(w, response, credential.usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) } else { mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) if err == nil && mediaType != "text/event-stream" { @@ -428,7 +543,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { +func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { weeklyCycleHint := extractWeeklyCycleHint(response.Header) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) isStreaming := err == nil && mediaType == "text/event-stream" @@ -456,7 +571,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if responseModel != "" { totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) - s.usageTracker.AddUsageWithCycleHint( + usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, messagesCount, @@ -557,7 +672,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if responseModel != "" { totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) - s.usageTracker.AddUsageWithCycleHint( + usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, messagesCount, @@ -585,12 +700,8 @@ func (s *Service) Close() error { s.tlsConfig, ) - if s.usageTracker != nil { - s.usageTracker.cancelPendingSave() - saveErr := s.usageTracker.Save() - if saveErr != nil { - s.logger.Error("save usage statistics: ", saveErr) - } + for _, credential := range s.allDefaults { + credential.close() } return err diff --git a/service/ocm/credential.go b/service/ocm/credential.go index 76651a8e14..0cdbd63790 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -138,6 +138,10 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut } defer response.Body.Close() + if response.StatusCode == http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + } if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) return nil, E.New("refresh failed: ", response.Status, " ", string(body)) diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go new file mode 100644 index 0000000000..3c6cf4ed9c --- /dev/null +++ b/service/ocm/credential_state.go @@ -0,0 +1,1022 @@ +package ocm + +import ( + "bytes" + "context" + stdTLS "crypto/tls" + "encoding/json" + "errors" + "io" + "math/rand/v2" + "net" + "net/http" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" +) + +const defaultPollInterval = 60 * time.Minute + +type credentialState struct { + fiveHourUtilization float64 + fiveHourReset time.Time + weeklyUtilization float64 + weeklyReset time.Time + hardRateLimited bool + rateLimitResetAt time.Time + accountType string + lastUpdated time.Time + consecutivePollFailures int +} + +type defaultCredential struct { + tag string + credentialPath string + credentials *oauthCredentials + accessMutex sync.RWMutex + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + reserve5h uint8 + reserveWeekly uint8 + usageTracker *AggregatedUsage + dialer N.Dialer + httpClient *http.Client + logger log.ContextLogger + + // Connection interruption + onBecameUnusable func() + interrupted bool + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex +} + +type credentialRequestContext struct { + context.Context + releaseOnce sync.Once + cancelOnce sync.Once + releaseFunc func() bool + cancelFunc context.CancelFunc +} + +func (c *credentialRequestContext) releaseCredentialInterrupt() { + c.releaseOnce.Do(func() { + c.releaseFunc() + }) +} + +func (c *credentialRequestContext) cancelRequest() { + c.releaseCredentialInterrupt() + c.cancelOnce.Do(c.cancelFunc) +} + +func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + reserve5h := options.Reserve5h + if reserve5h == 0 { + reserve5h = 1 + } + reserveWeekly := options.ReserveWeekly + if reserveWeekly == 0 { + reserveWeekly = 10 + } + requestContext, cancelRequests := context.WithCancel(context.Background()) + credential := &defaultCredential{ + tag: tag, + credentialPath: options.CredentialPath, + reserve5h: reserve5h, + reserveWeekly: reserveWeekly, + dialer: credentialDialer, + httpClient: httpClient, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + } + if options.UsagesPath != "" { + credential.usageTracker = &AggregatedUsage{ + LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + return credential, nil +} + +func (c *defaultCredential) start() error { + credentials, err := platformReadCredentials(c.credentialPath) + if err != nil { + return E.Cause(err, "read credentials for ", c.tag) + } + c.credentials = credentials + if c.usageTracker != nil { + err = c.usageTracker.Load() + if err != nil { + c.logger.Warn("load usage statistics for ", c.tag, ": ", err) + } + } + return nil +} + +func (c *defaultCredential) getAccessToken() (string, error) { + c.accessMutex.RLock() + if !c.credentials.needsRefresh() { + token := c.credentials.getAccessToken() + c.accessMutex.RUnlock() + return token, nil + } + c.accessMutex.RUnlock() + + c.accessMutex.Lock() + defer c.accessMutex.Unlock() + + if !c.credentials.needsRefresh() { + return c.credentials.getAccessToken(), nil + } + + newCredentials, err := refreshToken(c.httpClient, c.credentials) + if err != nil { + return "", err + } + + c.credentials = newCredentials + + err = platformWriteCredentials(newCredentials, c.credentialPath) + if err != nil { + c.logger.Warn("persist refreshed token for ", c.tag, ": ", err) + } + + return newCredentials.getAccessToken(), nil +} + +func (c *defaultCredential) getAccountID() string { + c.accessMutex.RLock() + defer c.accessMutex.RUnlock() + return c.credentials.getAccountID() +} + +func (c *defaultCredential) isAPIKeyMode() bool { + c.accessMutex.RLock() + defer c.accessMutex.RUnlock() + return c.credentials.isAPIKeyMode() +} + +func (c *defaultCredential) getBaseURL() string { + if c.isAPIKeyMode() { + return openaiAPIBaseURL + } + return chatGPTBackendURL +} + +func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { + c.stateMutex.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier == "" { + activeLimitIdentifier = "codex" + } + + fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") + if fiveHourPercent != "" { + value, err := strconv.ParseFloat(fiveHourPercent, 64) + if err == nil { + c.state.fiveHourUtilization = value + } + } + fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at") + if fiveHourResetAt != "" { + value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) + if err == nil { + c.state.fiveHourReset = time.Unix(value, 0) + } + } + weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") + if weeklyPercent != "" { + value, err := strconv.ParseFloat(weeklyPercent, 64) + if err == nil { + c.state.weeklyUtilization = value + } + } + weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at") + if weeklyResetAt != "" { + value, err := strconv.ParseInt(weeklyResetAt, 10, 64) + if err == nil { + c.state.weeklyReset = time.Unix(value, 0) + } + } + c.state.lastUpdated = time.Now() + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + } + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) markRateLimited(resetAt time.Time) { + c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) + c.stateMutex.Lock() + c.state.hardRateLimited = true + c.state.rateLimitResetAt = resetAt + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) isUsable() bool { + c.stateMutex.RLock() + if c.state.hardRateLimited { + if time.Now().Before(c.state.rateLimitResetAt) { + c.stateMutex.RUnlock() + return false + } + c.stateMutex.RUnlock() + c.stateMutex.Lock() + if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + usable := c.checkReservesLocked() + c.stateMutex.Unlock() + return usable + } + usable := c.checkReservesLocked() + c.stateMutex.RUnlock() + return usable +} + +func (c *defaultCredential) checkReservesLocked() bool { + if c.state.fiveHourUtilization >= float64(100-c.reserve5h) { + return false + } + if c.state.weeklyUtilization >= float64(100-c.reserveWeekly) { + return false + } + return true +} + +// checkTransitionLocked detects usable→unusable transition. +// Must be called with stateMutex write lock held. +func (c *defaultCredential) checkTransitionLocked() bool { + unusable := c.state.hardRateLimited || !c.checkReservesLocked() + if unusable && !c.interrupted { + c.interrupted = true + return true + } + if !unusable && c.interrupted { + c.interrupted = false + } + return false +} + +func (c *defaultCredential) interruptConnections() { + c.logger.Warn("interrupting connections for ", c.tag) + c.requestAccess.Lock() + c.cancelRequests() + c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) + c.requestAccess.Unlock() + if c.onBecameUnusable != nil { + c.onBecameUnusable() + } +} + +func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { + c.requestAccess.Lock() + credentialContext := c.requestContext + c.requestAccess.Unlock() + derived, cancel := context.WithCancel(parent) + stop := context.AfterFunc(credentialContext, func() { + cancel() + }) + return &credentialRequestContext{ + Context: derived, + releaseFunc: stop, + cancelFunc: cancel, + } +} + +func (c *defaultCredential) weeklyUtilization() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyUtilization +} + +func (c *defaultCredential) lastUpdatedTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.lastUpdated +} + +func (c *defaultCredential) markUsagePollAttempted() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.state.lastUpdated = time.Now() +} + +func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { + c.stateMutex.RLock() + failures := c.state.consecutivePollFailures + c.stateMutex.RUnlock() + if failures <= 0 { + return baseInterval + } + if failures > 4 { + failures = 4 + } + return baseInterval * time.Duration(1< 0 { + c.state.fiveHourReset = time.Unix(w.ResetAt, 0) + } + } + if w := usageResponse.RateLimit.SecondaryWindow; w != nil { + c.state.weeklyUtilization = w.UsedPercent + if w.ResetAt > 0 { + c.state.weeklyReset = time.Unix(w.ResetAt, 0) + } + } + } + if usageResponse.PlanType != "" { + c.state.accountType = usageResponse.PlanType + } + if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + } + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) close() { + if c.usageTracker != nil { + c.usageTracker.cancelPendingSave() + err := c.usageTracker.Save() + if err != nil { + c.logger.Error("save usage statistics for ", c.tag, ": ", err) + } + } +} + +type credentialProvider interface { + selectCredential(sessionID string) (*defaultCredential, bool, error) + onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential + pollIfStale(ctx context.Context) + allDefaults() []*defaultCredential + close() +} + +type singleCredentialProvider struct { + credential *defaultCredential +} + +func (p *singleCredentialProvider) selectCredential(_ string) (*defaultCredential, bool, error) { + if !p.credential.isUsable() { + return nil, false, E.New("credential ", p.credential.tag, " is rate-limited") + } + return p.credential, false, nil +} + +func (p *singleCredentialProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential { + credential.markRateLimited(resetAt) + return nil +} + +func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { + if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) { + p.credential.pollUsage(ctx) + } +} + +func (p *singleCredentialProvider) allDefaults() []*defaultCredential { + return []*defaultCredential{p.credential} +} + +func (p *singleCredentialProvider) close() {} + +const sessionExpiry = 24 * time.Hour + +type sessionEntry struct { + tag string + createdAt time.Time +} + +type balancerProvider struct { + credentials []*defaultCredential + strategy string + roundRobinIndex atomic.Uint64 + pollInterval time.Duration + sessionMutex sync.RWMutex + sessions map[string]sessionEntry + logger log.ContextLogger +} + +func newBalancerProvider(credentials []*defaultCredential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &balancerProvider{ + credentials: credentials, + strategy: strategy, + pollInterval: pollInterval, + sessions: make(map[string]sessionEntry), + logger: logger, + } +} + +func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredential, bool, error) { + if sessionID != "" { + p.sessionMutex.RLock() + entry, exists := p.sessions[sessionID] + p.sessionMutex.RUnlock() + if exists { + for _, credential := range p.credentials { + if credential.tag == entry.tag && credential.isUsable() { + return credential, false, nil + } + } + p.sessionMutex.Lock() + delete(p.sessions, sessionID) + p.sessionMutex.Unlock() + } + } + + best := p.pickCredential() + if best == nil { + return nil, false, allRateLimitedError(p.credentials) + } + + isNew := sessionID != "" + if isNew { + p.sessionMutex.Lock() + p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessionMutex.Unlock() + } + return best, isNew, nil +} + +func (p *balancerProvider) onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential { + credential.markRateLimited(resetAt) + if sessionID != "" { + p.sessionMutex.Lock() + delete(p.sessions, sessionID) + p.sessionMutex.Unlock() + } + + best := p.pickCredential() + if best != nil && sessionID != "" { + p.sessionMutex.Lock() + p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessionMutex.Unlock() + } + return best +} + +func (p *balancerProvider) pickCredential() *defaultCredential { + switch p.strategy { + case "round_robin": + return p.pickRoundRobin() + case "random": + return p.pickRandom() + default: + return p.pickLeastUsed() + } +} + +func (p *balancerProvider) pickLeastUsed() *defaultCredential { + var best *defaultCredential + bestUtilization := float64(101) + for _, credential := range p.credentials { + if !credential.isUsable() { + continue + } + utilization := credential.weeklyUtilization() + if utilization < bestUtilization { + bestUtilization = utilization + best = credential + } + } + return best +} + +func (p *balancerProvider) pickRoundRobin() *defaultCredential { + start := int(p.roundRobinIndex.Add(1) - 1) + count := len(p.credentials) + for offset := range count { + candidate := p.credentials[(start+offset)%count] + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *balancerProvider) pickRandom() *defaultCredential { + var usable []*defaultCredential + for _, candidate := range p.credentials { + if candidate.isUsable() { + usable = append(usable, candidate) + } + } + if len(usable) == 0 { + return nil + } + return usable[rand.IntN(len(usable))] +} + +func (p *balancerProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionMutex.Lock() + for id, entry := range p.sessions { + if now.Sub(entry.createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionMutex.Unlock() + + for _, credential := range p.credentials { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage(ctx) + } + } +} + +func (p *balancerProvider) allDefaults() []*defaultCredential { + return p.credentials +} + +func (p *balancerProvider) close() {} + +type fallbackProvider struct { + credentials []*defaultCredential + pollInterval time.Duration + logger log.ContextLogger +} + +func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &fallbackProvider{ + credentials: credentials, + pollInterval: pollInterval, + logger: logger, + } +} + +func (p *fallbackProvider) selectCredential(_ string) (*defaultCredential, bool, error) { + for _, credential := range p.credentials { + if credential.isUsable() { + return credential, false, nil + } + } + return nil, false, allRateLimitedError(p.credentials) +} + +func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential { + credential.markRateLimited(resetAt) + for _, candidate := range p.credentials { + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *fallbackProvider) pollIfStale(ctx context.Context) { + for _, credential := range p.credentials { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage(ctx) + } + } +} + +func (p *fallbackProvider) allDefaults() []*defaultCredential { + return p.credentials +} + +func (p *fallbackProvider) close() {} + +func allRateLimitedError(credentials []*defaultCredential) error { + var earliest time.Time + for _, credential := range credentials { + resetAt := credential.earliestReset() + if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { + earliest = resetAt + } + } + if earliest.IsZero() { + return E.New("all credentials rate-limited") + } + return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) +} + +func buildOCMCredentialProviders( + ctx context.Context, + options option.OCMServiceOptions, + logger log.ContextLogger, +) (map[string]credentialProvider, []*defaultCredential, error) { + defaultCredentials := make(map[string]*defaultCredential) + var allDefaults []*defaultCredential + providers := make(map[string]credentialProvider) + + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "default": + credential, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + if err != nil { + return nil, nil, err + } + defaultCredentials[credOpt.Tag] = credential + allDefaults = append(allDefaults, credential) + providers[credOpt.Tag] = &singleCredentialProvider{credential: credential} + } + } + + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "balancer": + subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, defaultCredentials, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger) + case "fallback": + subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, defaultCredentials, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newFallbackProvider(subCredentials, time.Duration(credOpt.FallbackOptions.PollInterval), logger) + } + } + + return providers, allDefaults, nil +} + +func resolveCredentialTags(tags []string, defaults map[string]*defaultCredential, parentTag string) ([]*defaultCredential, error) { + credentials := make([]*defaultCredential, 0, len(tags)) + for _, tag := range tags { + credential, exists := defaults[tag] + if !exists { + return nil, E.New("credential ", parentTag, " references unknown default credential: ", tag) + } + credentials = append(credentials, credential) + } + if len(credentials) == 0 { + return nil, E.New("credential ", parentTag, " has no sub-credentials") + } + return credentials, nil +} + +func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time { + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier != "" { + resetHeader := "x-" + activeLimitIdentifier + "-primary-reset-at" + if resetStr := headers.Get(resetHeader); resetStr != "" { + value, err := strconv.ParseInt(resetStr, 10, 64) + if err == nil { + return time.Unix(value, 0) + } + } + } + if retryAfter := headers.Get("Retry-After"); retryAfter != "" { + seconds, err := strconv.ParseInt(retryAfter, 10, 64) + if err == nil { + return time.Now().Add(time.Duration(seconds) * time.Second) + } + } + return time.Now().Add(5 * time.Minute) +} + +func validateOCMOptions(options option.OCMServiceOptions) error { + hasCredentials := len(options.Credentials) > 0 + hasLegacyPath := options.CredentialPath != "" + hasLegacyUsages := options.UsagesPath != "" + hasLegacyDetour := options.Detour != "" + + if hasCredentials && hasLegacyPath { + return E.New("credential_path and credentials are mutually exclusive") + } + if hasCredentials && hasLegacyUsages { + return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") + } + if hasCredentials && hasLegacyDetour { + return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") + } + + if hasCredentials { + tags := make(map[string]bool) + for _, credential := range options.Credentials { + if tags[credential.Tag] { + return E.New("duplicate credential tag: ", credential.Tag) + } + tags[credential.Tag] = true + if credential.Type == "default" || credential.Type == "" { + if credential.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") + } + if credential.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") + } + } + if credential.Type == "balancer" { + switch credential.BalancerOptions.Strategy { + case "", "least_used", "round_robin", "random": + default: + return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) + } + } + } + + for _, user := range options.Users { + if user.Credential == "" { + return E.New("user ", user.Name, " must specify credential in multi-credential mode") + } + if !tags[user.Credential] { + return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + } + } + } + + return nil +} + +func validateOCMCompositeCredentialModes( + options option.OCMServiceOptions, + providers map[string]credentialProvider, +) error { + for _, credential := range options.Credentials { + if credential.Type != "balancer" && credential.Type != "fallback" { + continue + } + + provider, exists := providers[credential.Tag] + if !exists { + return E.New("unknown credential: ", credential.Tag) + } + + for _, subCredential := range provider.allDefaults() { + if subCredential.isAPIKeyMode() { + return E.New( + "credential ", credential.Tag, + " references API key default credential ", subCredential.tag, + "; balancer and fallback only support OAuth default credentials", + ) + } + } + } + + return nil +} + +func retryOCMRequestWithBody( + ctx context.Context, + originalRequest *http.Request, + bodyBytes []byte, + credential *defaultCredential, + httpHeaders http.Header, +) (*http.Response, error) { + accessToken, err := credential.getAccessToken() + if err != nil { + return nil, E.Cause(err, "get access token for ", credential.tag) + } + + baseURL := credential.getBaseURL() + path := originalRequest.URL.Path + var proxyPath string + if credential.isAPIKeyMode() { + proxyPath = path + } else { + proxyPath = strings.TrimPrefix(path, "/v1") + } + + proxyURL := baseURL + proxyPath + if originalRequest.URL.RawQuery != "" { + proxyURL += "?" + originalRequest.URL.RawQuery + } + + var body io.Reader + if bodyBytes != nil { + body = bytes.NewReader(bodyBytes) + } + retryRequest, err := http.NewRequestWithContext(ctx, originalRequest.Method, proxyURL, body) + if err != nil { + return nil, err + } + + for key, values := range originalRequest.Header { + if !isHopByHopHeader(key) && key != "Authorization" { + retryRequest.Header[key] = values + } + } + for key, values := range httpHeaders { + retryRequest.Header.Del(key) + retryRequest.Header[key] = values + } + retryRequest.Header.Set("Authorization", "Bearer "+accessToken) + if accountID := credential.getAccountID(); accountID != "" { + retryRequest.Header.Set("ChatGPT-Account-Id", accountID) + } + + return credential.httpClient.Do(retryRequest) +} + +func credentialForUser( + userCredentialMap map[string]string, + providers map[string]credentialProvider, + legacyProvider credentialProvider, + username string, +) (credentialProvider, error) { + if legacyProvider != nil { + return legacyProvider, nil + } + tag, exists := userCredentialMap[username] + if !exists { + return nil, E.New("no credential mapping for user: ", username) + } + provider, exists := providers[tag] + if !exists { + return nil, E.New("unknown credential: ", tag) + } + return provider, nil +} + +func noUserCredentialProvider( + providers map[string]credentialProvider, + legacyProvider credentialProvider, + options option.OCMServiceOptions, +) credentialProvider { + if legacyProvider != nil { + return legacyProvider + } + if len(options.Credentials) > 0 { + tag := options.Credentials[0].Tag + return providers[tag] + } + return nil +} diff --git a/service/ocm/service.go b/service/ocm/service.go index 8b66964a93..75f28f2c1a 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -3,12 +3,10 @@ package ocm import ( "bytes" "context" - stdTLS "crypto/tls" "encoding/json" "errors" "io" "mime" - "net" "net/http" "strconv" "strings" @@ -17,7 +15,6 @@ import ( "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" - "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/listener" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" @@ -26,9 +23,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/ntp" aTLS "github.com/sagernet/sing/common/tls" "github.com/go-chi/chi/v5" @@ -52,17 +47,77 @@ type errorDetails struct { } func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) { + writeJSONErrorWithCode(w, r, statusCode, errorType, "", message) +} + +func writeJSONErrorWithCode(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, errorCode string, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) json.NewEncoder(w).Encode(errorResponse{ Error: errorDetails{ Type: errorType, + Code: errorCode, Message: message, }, }) } +func writePlainTextError(w http.ResponseWriter, statusCode int, message string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(statusCode) + _, _ = io.WriteString(w, message) +} + +const ( + retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential" + retryableUsageCode = "credential_usage_exhausted" +) + +func hasAlternativeCredential(provider credentialProvider, currentCredential *defaultCredential) bool { + if provider == nil || currentCredential == nil { + return false + } + for _, credential := range provider.allDefaults() { + if credential == currentCredential { + continue + } + if credential.isUsable() { + return true + } + } + return false +} + +func unavailableCredentialMessage(provider credentialProvider, fallback string) string { + if provider == nil { + return fallback + } + return allRateLimitedError(provider.allDefaults()).Error() +} + +func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) { + writeJSONErrorWithCode(w, r, http.StatusServiceUnavailable, "server_error", retryableUsageCode, retryableUsageMessage) +} + +func writeNonRetryableCredentialError(w http.ResponseWriter, message string) { + writePlainTextError(w, http.StatusBadRequest, message) +} + +func writeCredentialUnavailableError( + w http.ResponseWriter, + r *http.Request, + provider credentialProvider, + currentCredential *defaultCredential, + fallback string, +) { + if hasAlternativeCredential(provider, currentCredential) { + writeRetryableUsageError(w, r) + return + } + writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback)) +} + func isHopByHopHeader(header string) bool { switch strings.ToLower(header) { case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host": @@ -127,72 +182,43 @@ type Service struct { boxService.Adapter ctx context.Context logger log.ContextLogger - credentialPath string - credentials *oauthCredentials - users []option.OCMUser - dialer N.Dialer - httpClient *http.Client + options option.OCMServiceOptions httpHeaders http.Header listener *listener.Listener tlsConfig tls.ServerConfig httpServer *http.Server userManager *UserManager - accessMutex sync.RWMutex - usageTracker *AggregatedUsage webSocketMutex sync.Mutex webSocketGroup sync.WaitGroup webSocketConns map[*webSocketSession]struct{} shuttingDown bool + + // Legacy mode + legacyCredential *defaultCredential + legacyProvider credentialProvider + + // Multi-credential mode + providers map[string]credentialProvider + allDefaults []*defaultCredential + userCredentialMap map[string]string } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) { - serviceDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) + err := validateOCMOptions(options) if err != nil { - return nil, E.Cause(err, "create dialer") - } - - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSClientConfig: &stdTLS.Config{ - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, + return nil, E.Cause(err, "validate options") } userManager := &UserManager{ tokenMap: make(map[string]string), } - var usageTracker *AggregatedUsage - if options.UsagesPath != "" { - usageTracker = &AggregatedUsage{ - LastUpdated: time.Now(), - Combinations: make([]CostCombination, 0), - filePath: options.UsagesPath, - logger: logger, - } - } - service := &Service{ - Adapter: boxService.NewAdapter(C.TypeOCM, tag), - ctx: ctx, - logger: logger, - credentialPath: options.CredentialPath, - users: options.Users, - dialer: serviceDialer, - httpClient: httpClient, - httpHeaders: options.Headers.Build(), + Adapter: boxService.NewAdapter(C.TypeOCM, tag), + ctx: ctx, + logger: logger, + options: options, + httpHeaders: options.Headers.Build(), listener: listener.New(listener.Options{ Context: ctx, Logger: logger, @@ -200,10 +226,36 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio Listen: options.ListenOptions, }), userManager: userManager, - usageTracker: usageTracker, webSocketConns: make(map[*webSocketSession]struct{}), } + if len(options.Credentials) > 0 { + providers, allDefaults, err := buildOCMCredentialProviders(ctx, options, logger) + if err != nil { + return nil, E.Cause(err, "build credential providers") + } + service.providers = providers + service.allDefaults = allDefaults + + userCredentialMap := make(map[string]string) + for _, user := range options.Users { + userCredentialMap[user.Name] = user.Credential + } + service.userCredentialMap = userCredentialMap + } else { + credential, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{ + CredentialPath: options.CredentialPath, + UsagesPath: options.UsagesPath, + Detour: options.Detour, + }, logger) + if err != nil { + return nil, err + } + service.legacyCredential = credential + service.legacyProvider = &singleCredentialProvider{credential: credential} + service.allDefaults = []*defaultCredential{credential} + } + if options.TLS != nil { tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) if err != nil { @@ -220,18 +272,22 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } - s.userManager.UpdateUsers(s.users) + s.userManager.UpdateUsers(s.options.Users) - credentials, err := platformReadCredentials(s.credentialPath) - if err != nil { - return E.Cause(err, "read credentials") + for _, credential := range s.allDefaults { + err := credential.start() + if err != nil { + return err + } + tag := credential.tag + credential.onBecameUnusable = func() { + s.interruptWebSocketSessionsForCredential(tag) + } } - s.credentials = credentials - - if s.usageTracker != nil { - err = s.usageTracker.Load() + if len(s.options.Credentials) > 0 { + err := validateOCMCompositeCredentialModes(s.options, s.providers) if err != nil { - s.logger.Warn("load usage statistics: ", err) + return E.Cause(err, "validate loaded credentials") } } @@ -241,7 +297,7 @@ func (s *Service) Start(stage adapter.StartStage) error { s.httpServer = &http.Server{Handler: router} if s.tlsConfig != nil { - err = s.tlsConfig.Start() + err := s.tlsConfig.Start() if err != nil { return E.Cause(err, "create TLS config") } @@ -269,54 +325,15 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } -func (s *Service) getAccessToken() (string, error) { - s.accessMutex.RLock() - if !s.credentials.needsRefresh() { - token := s.credentials.getAccessToken() - s.accessMutex.RUnlock() - return token, nil - } - s.accessMutex.RUnlock() - - s.accessMutex.Lock() - defer s.accessMutex.Unlock() - - if !s.credentials.needsRefresh() { - return s.credentials.getAccessToken(), nil +func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) { + if len(s.options.Users) > 0 { + return credentialForUser(s.userCredentialMap, s.providers, s.legacyProvider, username) } - - newCredentials, err := refreshToken(s.httpClient, s.credentials) - if err != nil { - return "", err + provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + if provider == nil { + return nil, E.New("no credential available") } - - s.credentials = newCredentials - - err = platformWriteCredentials(newCredentials, s.credentialPath) - if err != nil { - s.logger.Warn("persist refreshed token: ", err) - } - - return newCredentials.getAccessToken(), nil -} - -func (s *Service) getAccountID() string { - s.accessMutex.RLock() - defer s.accessMutex.RUnlock() - return s.credentials.getAccountID() -} - -func (s *Service) isAPIKeyMode() bool { - s.accessMutex.RLock() - defer s.accessMutex.RUnlock() - return s.credentials.isAPIKeyMode() -} - -func (s *Service) getBaseURL() string { - if s.isAPIKeyMode() { - return openaiAPIBaseURL - } - return chatGPTBackendURL + return provider, nil } func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -326,20 +343,8 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - var proxyPath string - if s.isAPIKeyMode() { - proxyPath = path - } else { - if path == "/v1/chat/completions" { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "chat completions endpoint is only available in API key mode") - return - } - proxyPath = strings.TrimPrefix(path, "/v1") - } - var username string - if len(s.users) > 0 { + if len(s.options.Users) > 0 { authHeader := r.Header.Get("Authorization") if authHeader == "" { s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") @@ -361,39 +366,91 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + sessionID := r.Header.Get("session_id") + + // Resolve credential provider + provider, err := s.resolveCredentialProvider(username) + if err != nil { + s.logger.Error("resolve credential: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + + provider.pollIfStale(s.ctx) + + credential, isNew, err := provider.selectCredential(sessionID) + if err != nil { + writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error())) + return + } + if isNew { + if username != "" { + s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID, " by user ", username) + } else { + s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID) + } + } + if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { - s.handleWebSocket(w, r, proxyPath, username) + s.handleWebSocket(w, r, path, username, sessionID, provider, credential) return } - var requestModel string + var proxyPath string + if credential.isAPIKeyMode() { + proxyPath = path + } else { + if path == "/v1/chat/completions" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "chat completions endpoint is only available in API key mode") + return + } + proxyPath = strings.TrimPrefix(path, "/v1") + } + + shouldTrackUsage := credential.usageTracker != nil && + (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) + canRetryRequest := len(provider.allDefaults()) > 1 - if s.usageTracker != nil && r.Body != nil { - bodyBytes, err := io.ReadAll(r.Body) - if err == nil { + // Read body for model extraction and retry buffer when JSON replay is useful. + var bodyBytes []byte + var requestModel string + if r.Body != nil && (shouldTrackUsage || canRetryRequest) { + mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type")) + isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")) + if isJSONRequest { + bodyBytes, err = io.ReadAll(r.Body) + if err != nil { + s.logger.Error("read request body: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") + return + } var request struct { Model string `json:"model"` } - err := json.Unmarshal(bodyBytes, &request) - if err == nil { + if json.Unmarshal(bodyBytes, &request) == nil { requestModel = request.Model } - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } } - accessToken, err := s.getAccessToken() + accessToken, err := credential.getAccessToken() if err != nil { s.logger.Error("get access token: ", err) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed") return } - proxyURL := s.getBaseURL() + proxyPath + proxyURL := credential.getBaseURL() + proxyPath if r.URL.RawQuery != "" { proxyURL += "?" + r.URL.RawQuery } - proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body) + requestContext := credential.wrapRequestContext(r.Context()) + defer func() { + requestContext.cancelRequest() + }() + proxyRequest, err := http.NewRequestWithContext(requestContext, r.Method, proxyURL, r.Body) if err != nil { s.logger.Error("create proxy request: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") @@ -413,17 +470,68 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) - if accountID := s.getAccountID(); accountID != "" { + if accountID := credential.getAccountID(); accountID != "" { proxyRequest.Header.Set("ChatGPT-Account-Id", accountID) } - response, err := s.httpClient.Do(proxyRequest) + response, err := credential.httpClient.Do(proxyRequest) if err != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, credential, "credential became unavailable while processing the request") + return + } writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) return } + requestContext.releaseCredentialInterrupt() + + // Transparent 429 retry + for response.StatusCode == http.StatusTooManyRequests { + resetAt := parseOCMRateLimitResetFromHeaders(response.Header) + nextCredential := provider.onRateLimited(sessionID, credential, resetAt) + needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete + credential.updateStateFromHeaders(response.Header) + if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil { + response.Body.Close() + writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited") + return + } + response.Body.Close() + s.logger.Info("retrying with credential ", nextCredential.tag, " after 429 from ", credential.tag) + requestContext.cancelRequest() + requestContext = nextCredential.wrapRequestContext(r.Context()) + retryResponse, retryErr := retryOCMRequestWithBody(requestContext, r, bodyBytes, nextCredential, s.httpHeaders) + if retryErr != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, nextCredential, "credential became unavailable while retrying the request") + return + } + s.logger.Error("retry request: ", retryErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) + return + } + requestContext.releaseCredentialInterrupt() + response = retryResponse + credential = nextCredential + } defer response.Body.Close() + credential.updateStateFromHeaders(response.Header) + + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + s.logger.Error("upstream error from ", credential.tag, ": status ", response.StatusCode, " ", string(body)) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", + "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) + return + } + for key, values := range response.Header { if !isHopByHopHeader(key) { w.Header()[key] = values @@ -431,10 +539,10 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(response.StatusCode) - trackUsage := s.usageTracker != nil && response.StatusCode == http.StatusOK && - (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) - if trackUsage { - s.handleResponseWithTracking(w, response, path, requestModel, username) + hasUsageTracker := credential.usageTracker != nil + if hasUsageTracker && response.StatusCode == http.StatusOK && + (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) { + s.handleResponseWithTracking(w, response, credential.usageTracker, path, requestModel, username) } else { mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) if err == nil && mediaType != "text/event-stream" { @@ -464,7 +572,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, path string, requestModel string, username string) { +func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { isChatCompletions := path == "/v1/chat/completions" weeklyCycleHint := extractWeeklyCycleHint(response.Header) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) @@ -508,7 +616,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons } if responseModel != "" { contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - s.usageTracker.AddUsageWithCycleHint( + usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, inputTokens, @@ -619,7 +727,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if inputTokens > 0 || outputTokens > 0 { if responseModel != "" { contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - s.usageTracker.AddUsageWithCycleHint( + usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, inputTokens, @@ -650,12 +758,8 @@ func (s *Service) Close() error { } s.webSocketGroup.Wait() - if s.usageTracker != nil { - s.usageTracker.cancelPendingSave() - saveErr := s.usageTracker.Save() - if saveErr != nil { - s.logger.Error("save usage statistics: ", saveErr) - } + for _, credential := range s.allDefaults { + credential.close() } return err @@ -693,6 +797,20 @@ func (s *Service) isShuttingDown() bool { return s.shuttingDown } +func (s *Service) interruptWebSocketSessionsForCredential(tag string) { + s.webSocketMutex.Lock() + var toClose []*webSocketSession + for session := range s.webSocketConns { + if session.credentialTag == tag { + toClose = append(toClose, session) + } + } + s.webSocketMutex.Unlock() + for _, session := range toClose { + session.Close() + } +} + func (s *Service) startWebSocketShutdown() []*webSocketSession { s.webSocketMutex.Lock() defer s.webSocketMutex.Unlock() diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index d19f2df81b..eafd37aaed 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -1,12 +1,14 @@ package ocm import ( + "bufio" "context" stdTLS "crypto/tls" "encoding/json" "io" "net" "net/http" + "net/textproto" "strings" "sync" "time" @@ -22,9 +24,10 @@ import ( ) type webSocketSession struct { - clientConn net.Conn - upstreamConn net.Conn - closeOnce sync.Once + clientConn net.Conn + upstreamConn net.Conn + credentialTag string + closeOnce sync.Once } func (s *webSocketSession) Close() { @@ -76,57 +79,113 @@ func isForwardableWebSocketRequestHeader(key string) bool { } } -func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyPath string, username string) { - accessToken, err := s.getAccessToken() - if err != nil { - s.logger.Error("get access token for websocket: ", err) - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed") - return - } +func (s *Service) handleWebSocket( + w http.ResponseWriter, + r *http.Request, + path string, + username string, + sessionID string, + provider credentialProvider, + credential *defaultCredential, +) { + var ( + err error + upstreamConn net.Conn + upstreamBufferedReader *bufio.Reader + upstreamResponseHeaders http.Header + statusCode int + ) - upstreamURL := buildUpstreamWebSocketURL(s.getBaseURL(), proxyPath) - if r.URL.RawQuery != "" { - upstreamURL += "?" + r.URL.RawQuery - } + for { + accessToken, accessErr := credential.getAccessToken() + if accessErr != nil { + s.logger.Error("get access token for websocket: ", accessErr) + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed") + return + } - upstreamHeaders := make(http.Header) - for key, values := range r.Header { - if isForwardableWebSocketRequestHeader(key) { + var proxyPath string + if credential.isAPIKeyMode() { + proxyPath = path + } else { + proxyPath = strings.TrimPrefix(path, "/v1") + } + + upstreamURL := buildUpstreamWebSocketURL(credential.getBaseURL(), proxyPath) + if r.URL.RawQuery != "" { + upstreamURL += "?" + r.URL.RawQuery + } + + upstreamHeaders := make(http.Header) + for key, values := range r.Header { + if isForwardableWebSocketRequestHeader(key) { + upstreamHeaders[key] = values + } + } + for key, values := range s.httpHeaders { + upstreamHeaders.Del(key) upstreamHeaders[key] = values } - } - for key, values := range s.httpHeaders { - upstreamHeaders.Del(key) - upstreamHeaders[key] = values - } - upstreamHeaders.Set("Authorization", "Bearer "+accessToken) - if accountID := s.getAccountID(); accountID != "" { - upstreamHeaders.Set("ChatGPT-Account-Id", accountID) - } + upstreamHeaders.Set("Authorization", "Bearer "+accessToken) + if accountID := credential.getAccountID(); accountID != "" { + upstreamHeaders.Set("ChatGPT-Account-Id", accountID) + } - upstreamResponseHeaders := make(http.Header) - upstreamDialer := ws.Dialer{ - NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { - return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - TLSConfig: &stdTLS.Config{ - RootCAs: adapter.RootPoolFromContext(s.ctx), - Time: ntp.TimeFuncFromContext(s.ctx), - }, - Header: ws.HandshakeHeaderHTTP(upstreamHeaders), - OnHeader: func(key, value []byte) error { - upstreamResponseHeaders.Add(string(key), string(value)) - return nil - }, - } + upstreamResponseHeaders = make(http.Header) + statusCode = 0 + upstreamDialer := ws.Dialer{ + NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { + return credential.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + TLSConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(s.ctx), + Time: ntp.TimeFuncFromContext(s.ctx), + }, + Header: ws.HandshakeHeaderHTTP(upstreamHeaders), + // gobwas/ws@v1.4.0: the response io.Reader is + // MultiReader(statusLine_without_CRLF, "\r\n", bufferedConn). + // ReadString('\n') consumes the status line, then ReadMIMEHeader + // parses the remaining headers. + OnStatusError: func(status int, reason []byte, response io.Reader) { + statusCode = status + bufferedResponse := bufio.NewReader(response) + _, readErr := bufferedResponse.ReadString('\n') + if readErr != nil { + return + } + mimeHeader, readErr := textproto.NewReader(bufferedResponse).ReadMIMEHeader() + if readErr == nil { + upstreamResponseHeaders = http.Header(mimeHeader) + } + }, + OnHeader: func(key, value []byte) error { + upstreamResponseHeaders.Add(string(key), string(value)) + return nil + }, + } - upstreamConn, upstreamBufferedReader, _, err := upstreamDialer.Dial(r.Context(), upstreamURL) - if err != nil { + upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(s.ctx, upstreamURL) + if err == nil { + break + } + if statusCode == http.StatusTooManyRequests { + resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders) + nextCredential := provider.onRateLimited(sessionID, credential, resetAt) + if nextCredential == nil { + credential.updateStateFromHeaders(upstreamResponseHeaders) + writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited") + return + } + s.logger.Info("retrying websocket with credential ", nextCredential.tag, " after 429 from ", credential.tag) + credential = nextCredential + continue + } s.logger.Error("dial upstream websocket: ", err) writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed") return } + credential.updateStateFromHeaders(upstreamResponseHeaders) weeklyCycleHint := extractWeeklyCycleHint(upstreamResponseHeaders) clientResponseHeaders := make(http.Header) @@ -151,8 +210,9 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP return } session := &webSocketSession{ - clientConn: clientConn, - upstreamConn: upstreamConn, + clientConn: clientConn, + upstreamConn: upstreamConn, + credentialTag: credential.tag, } if !s.registerWebSocketSession(session) { session.Close() @@ -177,17 +237,17 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, modelChannel) + s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, credential, modelChannel) }() go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, modelChannel, username, weeklyCycleHint) + s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, credential, modelChannel, username, weeklyCycleHint) }() waitGroup.Wait() } -func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, modelChannel chan<- string) { +func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, credential *defaultCredential, modelChannel chan<- string) { for { data, opCode, err := wsutil.ReadClientData(clientConn) if err != nil { @@ -197,7 +257,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo return } - if opCode == ws.OpText && s.usageTracker != nil { + if opCode == ws.OpText && credential.usageTracker != nil { var request struct { Type string `json:"type"` Model string `json:"model"` @@ -220,7 +280,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo } } -func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, credential *defaultCredential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { var requestModel string for { data, opCode, err := wsutil.ReadServerData(upstreamReadWriter) @@ -231,7 +291,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite return } - if opCode == ws.OpText && s.usageTracker != nil { + if opCode == ws.OpText && credential.usageTracker != nil { select { case model := <-modelChannel: requestModel = model @@ -257,7 +317,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite } if responseModel != "" { contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - s.usageTracker.AddUsageWithCycleHint( + credential.usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, inputTokens, From da8ff6f5788cb9831d9f9cb2d35b54eb00702925 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 12 Mar 2026 23:17:47 +0800 Subject: [PATCH 08/96] ccm/ocm: Add external credential support for cross-instance usage sharing Extract credential interface from *defaultCredential to support both default (OAuth) and external (remote proxy) credential types. External credentials proxy requests to a remote ccm/ocm instance with bearer token auth, poll a /status endpoint for utilization, and parse aggregated rate limit headers from responses. Add allow_external_usage user flag to control whether balancer/fallback providers may select external credentials. Add status endpoint (/ccm/v1/status, /ocm/v1/status) returning averaged utilization across eligible credentials. Rewrite response rate limit headers for external users with aggregated values. --- option/ccm.go | 22 +- option/ocm.go | 22 +- service/ccm/credential_external.go | 428 ++++++++++++++++++++++++++ service/ccm/credential_state.go | 434 ++++++++++++++++----------- service/ccm/service.go | 252 ++++++++++------ service/ocm/credential_external.go | 463 +++++++++++++++++++++++++++++ service/ocm/credential_state.go | 415 ++++++++++++++++---------- service/ocm/service.go | 270 +++++++++++------ service/ocm/service_websocket.go | 49 +-- 9 files changed, 1829 insertions(+), 526 deletions(-) create mode 100644 service/ccm/credential_external.go create mode 100644 service/ocm/credential_external.go diff --git a/option/ccm.go b/option/ccm.go index edfe2e417b..6846dfccb6 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -19,15 +19,18 @@ type CCMServiceOptions struct { } type CCMUser struct { - Name string `json:"name,omitempty"` - Token string `json:"token,omitempty"` - Credential string `json:"credential,omitempty"` + Name string `json:"name,omitempty"` + Token string `json:"token,omitempty"` + Credential string `json:"credential,omitempty"` + ExternalCredential string `json:"external_credential,omitempty"` + AllowExternalUsage bool `json:"allow_external_usage,omitempty"` } type _CCMCredential struct { Type string `json:"type,omitempty"` Tag string `json:"tag"` DefaultOptions CCMDefaultCredentialOptions `json:"-"` + ExternalOptions CCMExternalCredentialOptions `json:"-"` BalancerOptions CCMBalancerCredentialOptions `json:"-"` FallbackOptions CCMFallbackCredentialOptions `json:"-"` } @@ -40,6 +43,8 @@ func (c CCMCredential) MarshalJSON() ([]byte, error) { case "", "default": c.Type = "" v = c.DefaultOptions + case "external": + v = c.ExternalOptions case "balancer": v = c.BalancerOptions case "fallback": @@ -63,6 +68,8 @@ func (c *CCMCredential) UnmarshalJSON(bytes []byte) error { case "", "default": c.Type = "default" v = &c.DefaultOptions + case "external": + v = &c.ExternalOptions case "balancer": v = &c.BalancerOptions case "fallback": @@ -87,6 +94,15 @@ type CCMBalancerCredentialOptions struct { PollInterval badoption.Duration `json:"poll_interval,omitempty"` } +type CCMExternalCredentialOptions struct { + URL string `json:"url"` + ServerOptions + Token string `json:"token"` + Detour string `json:"detour,omitempty"` + UsagesPath string `json:"usages_path,omitempty"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` +} + type CCMFallbackCredentialOptions struct { Credentials badoption.Listable[string] `json:"credentials"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` diff --git a/option/ocm.go b/option/ocm.go index 832b455288..4d495ff27a 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -19,15 +19,18 @@ type OCMServiceOptions struct { } type OCMUser struct { - Name string `json:"name,omitempty"` - Token string `json:"token,omitempty"` - Credential string `json:"credential,omitempty"` + Name string `json:"name,omitempty"` + Token string `json:"token,omitempty"` + Credential string `json:"credential,omitempty"` + ExternalCredential string `json:"external_credential,omitempty"` + AllowExternalUsage bool `json:"allow_external_usage,omitempty"` } type _OCMCredential struct { Type string `json:"type,omitempty"` Tag string `json:"tag"` DefaultOptions OCMDefaultCredentialOptions `json:"-"` + ExternalOptions OCMExternalCredentialOptions `json:"-"` BalancerOptions OCMBalancerCredentialOptions `json:"-"` FallbackOptions OCMFallbackCredentialOptions `json:"-"` } @@ -40,6 +43,8 @@ func (c OCMCredential) MarshalJSON() ([]byte, error) { case "", "default": c.Type = "" v = c.DefaultOptions + case "external": + v = c.ExternalOptions case "balancer": v = c.BalancerOptions case "fallback": @@ -63,6 +68,8 @@ func (c *OCMCredential) UnmarshalJSON(bytes []byte) error { case "", "default": c.Type = "default" v = &c.DefaultOptions + case "external": + v = &c.ExternalOptions case "balancer": v = &c.BalancerOptions case "fallback": @@ -87,6 +94,15 @@ type OCMBalancerCredentialOptions struct { PollInterval badoption.Duration `json:"poll_interval,omitempty"` } +type OCMExternalCredentialOptions struct { + URL string `json:"url"` + ServerOptions + Token string `json:"token"` + Detour string `json:"detour,omitempty"` + UsagesPath string `json:"usages_path,omitempty"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` +} + type OCMFallbackCredentialOptions struct { Credentials badoption.Listable[string] `json:"credentials"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go new file mode 100644 index 0000000000..c4d2d340f7 --- /dev/null +++ b/service/ccm/credential_external.go @@ -0,0 +1,428 @@ +package ccm + +import ( + "bytes" + "context" + stdTLS "crypto/tls" + "encoding/json" + "io" + "net" + "net/http" + "net/url" + "strconv" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/ntp" +) + +type externalCredential struct { + tag string + baseURL string + token string + httpClient *http.Client + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + usageTracker *AggregatedUsage + logger log.ContextLogger + + onBecameUnusable func() + interrupted bool + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex +} + +func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { + parsedURL, err := url.Parse(options.URL) + if err != nil { + return nil, E.Cause(err, "parse url for credential ", tag) + } + + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + + transport := &http.Transport{ + ForceAttemptHTTP2: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if options.Server != "" { + serverPort := options.ServerPort + if serverPort == 0 { + portStr := parsedURL.Port() + if portStr != "" { + port, parseErr := strconv.ParseUint(portStr, 10, 16) + if parseErr == nil { + serverPort = uint16(port) + } + } + if serverPort == 0 { + if parsedURL.Scheme == "https" { + serverPort = 443 + } else { + serverPort = 80 + } + } + } + destination := M.ParseSocksaddrHostPort(options.Server, serverPort) + return credentialDialer.DialContext(ctx, network, destination) + } + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + } + + if parsedURL.Scheme == "https" { + transport.TLSClientConfig = &stdTLS.Config{ + ServerName: parsedURL.Hostname(), + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + } + } + + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + if parsedURL.Path != "" && parsedURL.Path != "/" { + baseURL += parsedURL.Path + } + // Strip trailing slash + if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { + baseURL = baseURL[:len(baseURL)-1] + } + + pollInterval := time.Duration(options.PollInterval) + if pollInterval <= 0 { + pollInterval = 30 * time.Minute + } + + requestContext, cancelRequests := context.WithCancel(context.Background()) + + cred := &externalCredential{ + tag: tag, + baseURL: baseURL, + token: options.Token, + httpClient: &http.Client{Transport: transport}, + pollInterval: pollInterval, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + } + + if options.UsagesPath != "" { + cred.usageTracker = &AggregatedUsage{ + LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + + return cred, nil +} + +func (c *externalCredential) start() error { + if c.usageTracker != nil { + err := c.usageTracker.Load() + if err != nil { + c.logger.Warn("load usage statistics for ", c.tag, ": ", err) + } + } + return nil +} + +func (c *externalCredential) tagName() string { + return c.tag +} + +func (c *externalCredential) isExternal() bool { + return true +} + +func (c *externalCredential) isUsable() bool { + c.stateMutex.RLock() + if c.state.hardRateLimited { + if time.Now().Before(c.state.rateLimitResetAt) { + c.stateMutex.RUnlock() + return false + } + c.stateMutex.RUnlock() + c.stateMutex.Lock() + if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + // No reserve for external: only 100% is unusable + usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 + c.stateMutex.Unlock() + return usable + } + usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 + c.stateMutex.RUnlock() + return usable +} + +func (c *externalCredential) fiveHourUtilization() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.fiveHourUtilization +} + +func (c *externalCredential) weeklyUtilization() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyUtilization +} + +func (c *externalCredential) markRateLimited(resetAt time.Time) { + c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) + c.stateMutex.Lock() + c.state.hardRateLimited = true + c.state.rateLimitResetAt = resetAt + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *externalCredential) earliestReset() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + if c.state.hardRateLimited { + return c.state.rateLimitResetAt + } + earliest := c.state.fiveHourReset + if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) { + earliest = c.state.weeklyReset + } + return earliest +} + +func (c *externalCredential) getAccessToken() (string, error) { + return c.token, nil +} + +func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) { + proxyURL := c.baseURL + original.URL.RequestURI() + var body io.Reader + if bodyBytes != nil { + body = bytes.NewReader(bodyBytes) + } else { + body = original.Body + } + proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body) + if err != nil { + return nil, err + } + + for key, values := range original.Header { + if !isHopByHopHeader(key) && key != "Authorization" { + proxyRequest.Header[key] = values + } + } + + proxyRequest.Header.Set("Authorization", "Bearer "+c.token) + + return proxyRequest, nil +} + +func (c *externalCredential) updateStateFromHeaders(headers http.Header) { + c.stateMutex.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + + if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { + value, err := strconv.ParseFloat(utilization, 64) + if err == nil { + // Remote CCM writes aggregated utilization as 0.0-1.0; convert to percentage + c.state.fiveHourUtilization = value * 100 + } + } + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { + c.state.fiveHourReset = value + } + if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { + value, err := strconv.ParseFloat(utilization, 64) + if err == nil { + c.state.weeklyUtilization = value * 100 + } + } + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { + c.state.weeklyReset = value + } + c.state.lastUpdated = time.Now() + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + } + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *externalCredential) checkTransitionLocked() bool { + unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 + if unusable && !c.interrupted { + c.interrupted = true + return true + } + if !unusable && c.interrupted { + c.interrupted = false + } + return false +} + +func (c *externalCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { + c.requestAccess.Lock() + credentialContext := c.requestContext + c.requestAccess.Unlock() + derived, cancel := context.WithCancel(parent) + stop := context.AfterFunc(credentialContext, func() { + cancel() + }) + return &credentialRequestContext{ + Context: derived, + releaseFunc: stop, + cancelFunc: cancel, + } +} + +func (c *externalCredential) interruptConnections() { + c.logger.Warn("interrupting connections for ", c.tag) + c.requestAccess.Lock() + c.cancelRequests() + c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) + c.requestAccess.Unlock() + if c.onBecameUnusable != nil { + c.onBecameUnusable() + } +} + +func (c *externalCredential) pollUsage(ctx context.Context) { + if !c.pollAccess.TryLock() { + return + } + defer c.pollAccess.Unlock() + defer c.markUsagePollAttempted() + + statusURL := c.baseURL + "/ccm/v1/status" + httpClient := &http.Client{ + Transport: c.httpClient.Transport, + Timeout: 5 * time.Second, + } + + request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) + if err != nil { + c.logger.Error("poll usage for ", c.tag, ": create request: ", err) + return + } + request.Header.Set("Authorization", "Bearer "+c.token) + + response, err := httpClient.Do(request) + if err != nil { + c.logger.Error("poll usage for ", c.tag, ": ", err) + c.stateMutex.Lock() + c.state.consecutivePollFailures++ + c.stateMutex.Unlock() + return + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + c.stateMutex.Lock() + c.state.consecutivePollFailures++ + c.stateMutex.Unlock() + c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + return + } + + var statusResponse struct { + FiveHourUtilization float64 `json:"five_hour_utilization"` + WeeklyUtilization float64 `json:"weekly_utilization"` + } + err = json.NewDecoder(response.Body).Decode(&statusResponse) + if err != nil { + c.stateMutex.Lock() + c.state.consecutivePollFailures++ + c.stateMutex.Unlock() + c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + return + } + + c.stateMutex.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + c.state.consecutivePollFailures = 0 + c.state.fiveHourUtilization = statusResponse.FiveHourUtilization + c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + } + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *externalCredential) lastUpdatedTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.lastUpdated +} + +func (c *externalCredential) markUsagePollAttempted() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.state.lastUpdated = time.Now() +} + +func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration { + c.stateMutex.RLock() + failures := c.state.consecutivePollFailures + c.stateMutex.RUnlock() + if failures <= 0 { + return baseInterval + } + if failures > 4 { + failures = 4 + } + return baseInterval * time.Duration(1< 0 + if c.usageTracker != nil && !serviceOverridesAcceptEncoding { + proxyRequest.Header.Del("Accept-Encoding") + } + + anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta") + if anthropicBetaHeader != "" { + proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) + } else { + proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue) + } + + for key, values := range serviceHeaders { + proxyRequest.Header.Del(key) + proxyRequest.Header[key] = values + } + proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) + + return proxyRequest, nil +} + // credentialProvider is the interface for all credential types. type credentialProvider interface { - selectCredential(sessionID string) (*defaultCredential, bool, error) - onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential + selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) + onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential pollIfStale(ctx context.Context) - allDefaults() []*defaultCredential + allCredentials() []credential close() } -// singleCredentialProvider wraps a single default credential (legacy or single default). +// singleCredentialProvider wraps a single credential (legacy or single default). type singleCredentialProvider struct { - credential *defaultCredential + cred credential } -func (p *singleCredentialProvider) selectCredential(_ string) (*defaultCredential, bool, error) { - if !p.credential.isUsable() { - return nil, false, E.New("credential ", p.credential.tag, " is rate-limited") +func (p *singleCredentialProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) { + if filter != nil && !filter(p.cred) { + return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") } - return p.credential, false, nil + if !p.cred.isUsable() { + return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") + } + return p.cred, false, nil } -func (p *singleCredentialProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential { - credential.markRateLimited(resetAt) +func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential { + cred.markRateLimited(resetAt) return nil } func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { - if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) { - p.credential.pollUsage(ctx) + if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { + p.cred.pollUsage(ctx) } } -func (p *singleCredentialProvider) allDefaults() []*defaultCredential { - return []*defaultCredential{p.credential} +func (p *singleCredentialProvider) allCredentials() []credential { + return []credential{p.cred} } func (p *singleCredentialProvider) close() {} @@ -546,7 +654,7 @@ type sessionEntry struct { // balancerProvider assigns sessions to credentials based on a configurable strategy. type balancerProvider struct { - credentials []*defaultCredential + credentials []credential strategy string roundRobinIndex atomic.Uint64 pollInterval time.Duration @@ -555,7 +663,7 @@ type balancerProvider struct { logger log.ContextLogger } -func newBalancerProvider(credentials []*defaultCredential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { +func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { if pollInterval <= 0 { pollInterval = defaultPollInterval } @@ -568,15 +676,15 @@ func newBalancerProvider(credentials []*defaultCredential, strategy string, poll } } -func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredential, bool, error) { +func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { if sessionID != "" { p.sessionMutex.RLock() entry, exists := p.sessions[sessionID] p.sessionMutex.RUnlock() if exists { - for _, credential := range p.credentials { - if credential.tag == entry.tag && credential.isUsable() { - return credential, false, nil + for _, cred := range p.credentials { + if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() { + return cred, false, nil } } p.sessionMutex.Lock() @@ -585,7 +693,7 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia } } - best := p.pickCredential() + best := p.pickCredential(filter) if best == nil { return nil, false, allCredentialsUnavailableError(p.credentials) } @@ -593,61 +701,67 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia isNew := sessionID != "" if isNew { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} p.sessionMutex.Unlock() } return best, isNew, nil } -func (p *balancerProvider) onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential { - credential.markRateLimited(resetAt) +func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential { + cred.markRateLimited(resetAt) if sessionID != "" { p.sessionMutex.Lock() delete(p.sessions, sessionID) p.sessionMutex.Unlock() } - best := p.pickCredential() + best := p.pickCredential(filter) if best != nil && sessionID != "" { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} p.sessionMutex.Unlock() } return best } -func (p *balancerProvider) pickCredential() *defaultCredential { +func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { switch p.strategy { case "round_robin": - return p.pickRoundRobin() + return p.pickRoundRobin(filter) case "random": - return p.pickRandom() + return p.pickRandom(filter) default: - return p.pickLeastUsed() + return p.pickLeastUsed(filter) } } -func (p *balancerProvider) pickLeastUsed() *defaultCredential { - var best *defaultCredential +func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { + var best credential bestUtilization := float64(101) - for _, credential := range p.credentials { - if !credential.isUsable() { + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if !cred.isUsable() { continue } - utilization := credential.weeklyUtilization() + utilization := cred.weeklyUtilization() if utilization < bestUtilization { bestUtilization = utilization - best = credential + best = cred } } return best } -func (p *balancerProvider) pickRoundRobin() *defaultCredential { +func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { start := int(p.roundRobinIndex.Add(1) - 1) count := len(p.credentials) for offset := range count { candidate := p.credentials[(start+offset)%count] + if filter != nil && !filter(candidate) { + continue + } if candidate.isUsable() { return candidate } @@ -655,9 +769,12 @@ func (p *balancerProvider) pickRoundRobin() *defaultCredential { return nil } -func (p *balancerProvider) pickRandom() *defaultCredential { - var usable []*defaultCredential +func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { + var usable []credential for _, candidate := range p.credentials { + if filter != nil && !filter(candidate) { + continue + } if candidate.isUsable() { usable = append(usable, candidate) } @@ -678,14 +795,14 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) { } p.sessionMutex.Unlock() - for _, credential := range p.credentials { - if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { - credential.pollUsage(ctx) + for _, cred := range p.credentials { + if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { + cred.pollUsage(ctx) } } } -func (p *balancerProvider) allDefaults() []*defaultCredential { +func (p *balancerProvider) allCredentials() []credential { return p.credentials } @@ -693,12 +810,12 @@ func (p *balancerProvider) close() {} // fallbackProvider tries credentials in order. type fallbackProvider struct { - credentials []*defaultCredential + credentials []credential pollInterval time.Duration logger log.ContextLogger } -func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider { +func newFallbackProvider(credentials []credential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider { if pollInterval <= 0 { pollInterval = defaultPollInterval } @@ -709,18 +826,24 @@ func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Dur } } -func (p *fallbackProvider) selectCredential(_ string) (*defaultCredential, bool, error) { - for _, credential := range p.credentials { - if credential.isUsable() { - return credential, false, nil +func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) { + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if cred.isUsable() { + return cred, false, nil } } return nil, false, allCredentialsUnavailableError(p.credentials) } -func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential { - credential.markRateLimited(resetAt) +func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential { + cred.markRateLimited(resetAt) for _, candidate := range p.credentials { + if filter != nil && !filter(candidate) { + continue + } if candidate.isUsable() { return candidate } @@ -729,23 +852,23 @@ func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential } func (p *fallbackProvider) pollIfStale(ctx context.Context) { - for _, credential := range p.credentials { - if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { - credential.pollUsage(ctx) + for _, cred := range p.credentials { + if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { + cred.pollUsage(ctx) } } } -func (p *fallbackProvider) allDefaults() []*defaultCredential { +func (p *fallbackProvider) allCredentials() []credential { return p.credentials } func (p *fallbackProvider) close() {} -func allCredentialsUnavailableError(credentials []*defaultCredential) error { +func allCredentialsUnavailableError(credentials []credential) error { var earliest time.Time - for _, credential := range credentials { - resetAt := credential.earliestReset() + for _, cred := range credentials { + resetAt := cred.earliestReset() if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { earliest = resetAt } @@ -778,34 +901,44 @@ func buildCredentialProviders( ctx context.Context, options option.CCMServiceOptions, logger log.ContextLogger, -) (map[string]credentialProvider, []*defaultCredential, error) { - defaultCredentials := make(map[string]*defaultCredential) - var allDefaults []*defaultCredential +) (map[string]credentialProvider, []credential, error) { + allCredentialMap := make(map[string]credential) + var allCreds []credential providers := make(map[string]credentialProvider) + // Pass 1: create default and external credentials for _, credOpt := range options.Credentials { switch credOpt.Type { case "default": - credential, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) if err != nil { return nil, nil, err } - defaultCredentials[credOpt.Tag] = credential - allDefaults = append(allDefaults, credential) - providers[credOpt.Tag] = &singleCredentialProvider{credential: credential} + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + case "external": + cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) + if err != nil { + return nil, nil, err + } + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} } } + // Pass 2: create balancer and fallback providers for _, credOpt := range options.Credentials { switch credOpt.Type { case "balancer": - subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, defaultCredentials, credOpt.Tag) + subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) if err != nil { return nil, nil, err } providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger) case "fallback": - subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, defaultCredentials, credOpt.Tag) + subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, allCredentialMap, credOpt.Tag) if err != nil { return nil, nil, err } @@ -813,17 +946,17 @@ func buildCredentialProviders( } } - return providers, allDefaults, nil + return providers, allCreds, nil } -func resolveCredentialTags(tags []string, defaults map[string]*defaultCredential, parentTag string) ([]*defaultCredential, error) { - credentials := make([]*defaultCredential, 0, len(tags)) +func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { + credentials := make([]credential, 0, len(tags)) for _, tag := range tags { - credential, exists := defaults[tag] + cred, exists := allCredentials[tag] if !exists { - return nil, E.New("credential ", parentTag, " references unknown default credential: ", tag) + return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) } - credentials = append(credentials, credential) + credentials = append(credentials, cred) } if len(credentials) == 0 { return nil, E.New("credential ", parentTag, " has no sub-credentials") @@ -835,27 +968,12 @@ func parseRateLimitResetFromHeaders(headers http.Header) time.Time { claim := headers.Get("anthropic-ratelimit-unified-representative-claim") switch claim { case "5h": - if resetStr := headers.Get("anthropic-ratelimit-unified-5h-reset"); resetStr != "" { - value, err := strconv.ParseInt(resetStr, 10, 64) - if err == nil { - return time.Unix(value, 0) - } - } + return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset") case "7d": - if resetStr := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetStr != "" { - value, err := strconv.ParseInt(resetStr, 10, 64) - if err == nil { - return time.Unix(value, 0) - } - } - } - if retryAfter := headers.Get("Retry-After"); retryAfter != "" { - seconds, err := strconv.ParseInt(retryAfter, 10, 64) - if err == nil { - return time.Now().Add(time.Duration(seconds) * time.Second) - } + return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset") + default: + panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim)) } - return time.Now().Add(5 * time.Minute) } func validateCCMOptions(options option.CCMServiceOptions) error { @@ -876,24 +994,34 @@ func validateCCMOptions(options option.CCMServiceOptions) error { if hasCredentials { tags := make(map[string]bool) - for _, credential := range options.Credentials { - if tags[credential.Tag] { - return E.New("duplicate credential tag: ", credential.Tag) + credentialTypes := make(map[string]string) + for _, cred := range options.Credentials { + if tags[cred.Tag] { + return E.New("duplicate credential tag: ", cred.Tag) + } + tags[cred.Tag] = true + credentialTypes[cred.Tag] = cred.Type + if cred.Type == "default" || cred.Type == "" { + if cred.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") + } + if cred.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") + } } - tags[credential.Tag] = true - if credential.Type == "default" || credential.Type == "" { - if credential.DefaultOptions.Reserve5h > 99 { - return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") + if cred.Type == "external" { + if cred.ExternalOptions.URL == "" { + return E.New("credential ", cred.Tag, ": external credential requires url") } - if credential.DefaultOptions.ReserveWeekly > 99 { - return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") + if cred.ExternalOptions.Token == "" { + return E.New("credential ", cred.Tag, ": external credential requires token") } } - if credential.Type == "balancer" { - switch credential.BalancerOptions.Strategy { + if cred.Type == "balancer" { + switch cred.BalancerOptions.Strategy { case "", "least_used", "round_robin", "random": default: - return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) + return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) } } } @@ -905,63 +1033,25 @@ func validateCCMOptions(options option.CCMServiceOptions) error { if !tags[user.Credential] { return E.New("user ", user.Name, " references unknown credential: ", user.Credential) } + if user.ExternalCredential != "" { + if !tags[user.ExternalCredential] { + return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) + } + if credentialTypes[user.ExternalCredential] != "external" { + return E.New("user ", user.Name, ": external_credential must reference an external type credential") + } + } } } return nil } -// retryRequestWithBody re-sends a buffered request body using a different credential. -func retryRequestWithBody( - ctx context.Context, - originalRequest *http.Request, - bodyBytes []byte, - credential *defaultCredential, - httpHeaders http.Header, -) (*http.Response, error) { - accessToken, err := credential.getAccessToken() - if err != nil { - return nil, E.Cause(err, "get access token for ", credential.tag) - } - - proxyURL := claudeAPIBaseURL + originalRequest.URL.RequestURI() - retryRequest, err := http.NewRequestWithContext(ctx, originalRequest.Method, proxyURL, bytes.NewReader(bodyBytes)) - if err != nil { - return nil, err - } - - for key, values := range originalRequest.Header { - if !isHopByHopHeader(key) && key != "Authorization" { - retryRequest.Header[key] = values - } - } - - serviceOverridesAcceptEncoding := len(httpHeaders.Values("Accept-Encoding")) > 0 - if credential.usageTracker != nil && !serviceOverridesAcceptEncoding { - retryRequest.Header.Del("Accept-Encoding") - } - - anthropicBetaHeader := retryRequest.Header.Get("anthropic-beta") - if anthropicBetaHeader != "" { - retryRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) - } else { - retryRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue) - } - - for key, values := range httpHeaders { - retryRequest.Header.Del(key) - retryRequest.Header[key] = values - } - retryRequest.Header.Set("Authorization", "Bearer "+accessToken) - - return credential.httpClient.Do(retryRequest) -} - // credentialForUser finds the credential provider for a user. // In legacy mode, returns the single provider. // In multi-credential mode, returns the provider mapped to the user's credential tag. func credentialForUser( - userCredentialMap map[string]string, + userConfigMap map[string]*option.CCMUser, providers map[string]credentialProvider, legacyProvider credentialProvider, username string, @@ -969,13 +1059,13 @@ func credentialForUser( if legacyProvider != nil { return legacyProvider, nil } - tag, exists := userCredentialMap[username] + userConfig, exists := userConfigMap[username] if !exists { return nil, E.New("no credential mapping for user: ", username) } - provider, exists := providers[tag] + provider, exists := providers[userConfig.Credential] if !exists { - return nil, E.New("unknown credential: ", tag) + return nil, E.New("unknown credential: ", userConfig.Credential) } return provider, nil } diff --git a/service/ccm/service.go b/service/ccm/service.go index ea81b1b762..1238b63041 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -66,15 +66,18 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro }) } -func hasAlternativeCredential(provider credentialProvider, currentCredential *defaultCredential) bool { +func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool { if provider == nil || currentCredential == nil { return false } - for _, credential := range provider.allDefaults() { - if credential == currentCredential { + for _, cred := range provider.allCredentials() { + if cred == currentCredential { continue } - if credential.isUsable() { + if filter != nil && !filter(cred) { + continue + } + if cred.isUsable() { return true } } @@ -85,7 +88,7 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string) if provider == nil { return fallback } - return allCredentialsUnavailableError(provider.allDefaults()).Error() + return allCredentialsUnavailableError(provider.allCredentials()).Error() } func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) { @@ -100,10 +103,11 @@ func writeCredentialUnavailableError( w http.ResponseWriter, r *http.Request, provider credentialProvider, - currentCredential *defaultCredential, + currentCredential credential, + filter func(credential) bool, fallback string, ) { - if hasAlternativeCredential(provider, currentCredential) { + if hasAlternativeCredential(provider, currentCredential, filter) { writeRetryableUsageError(w, r) return } @@ -124,27 +128,15 @@ const ( weeklyWindowMinutes = weeklyWindowSeconds / 60 ) -func parseInt64Header(headers http.Header, headerName string) (int64, bool) { - headerValue := strings.TrimSpace(headers.Get(headerName)) - if headerValue == "" { - return 0, false - } - parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64) - if parseError != nil { - return 0, false - } - return parsedValue, true -} - func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { - resetAtUnix, hasResetAt := parseInt64Header(headers, "anthropic-ratelimit-unified-7d-reset") - if !hasResetAt || resetAtUnix <= 0 { + resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset") + if !exists { return nil } return &WeeklyCycleHint{ WindowMinutes: weeklyWindowMinutes, - ResetAt: time.Unix(resetAtUnix, 0).UTC(), + ResetAt: resetAt.UTC(), } } @@ -166,9 +158,9 @@ type Service struct { legacyProvider credentialProvider // Multi-credential mode - providers map[string]credentialProvider - allDefaults []*defaultCredential - userCredentialMap map[string]string + providers map[string]credentialProvider + allCredentials []credential + userConfigMap map[string]*option.CCMUser } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) { @@ -199,20 +191,20 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio } if len(options.Credentials) > 0 { - providers, allDefaults, err := buildCredentialProviders(ctx, options, logger) + providers, allCredentials, err := buildCredentialProviders(ctx, options, logger) if err != nil { return nil, E.Cause(err, "build credential providers") } service.providers = providers - service.allDefaults = allDefaults + service.allCredentials = allCredentials - userCredentialMap := make(map[string]string) - for _, user := range options.Users { - userCredentialMap[user.Name] = user.Credential + userConfigMap := make(map[string]*option.CCMUser) + for i := range options.Users { + userConfigMap[options.Users[i].Name] = &options.Users[i] } - service.userCredentialMap = userCredentialMap + service.userConfigMap = userConfigMap } else { - credential, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{ + cred, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{ CredentialPath: options.CredentialPath, UsagesPath: options.UsagesPath, Detour: options.Detour, @@ -220,9 +212,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio if err != nil { return nil, err } - service.legacyCredential = credential - service.legacyProvider = &singleCredentialProvider{credential: credential} - service.allDefaults = []*defaultCredential{credential} + service.legacyCredential = cred + service.legacyProvider = &singleCredentialProvider{cred: cred} + service.allCredentials = []credential{cred} } if options.TLS != nil { @@ -243,8 +235,8 @@ func (s *Service) Start(stage adapter.StartStage) error { s.userManager.UpdateUsers(s.options.Users) - for _, credential := range s.allDefaults { - err := credential.start() + for _, cred := range s.allCredentials { + err := cred.start() if err != nil { return err } @@ -303,6 +295,11 @@ func detectContextWindow(betaHeader string, totalInputTokens int64) int { } func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ccm/v1/status" { + s.handleStatusEndpoint(w, r) + return + } + if !strings.HasPrefix(r.URL.Path, "/v1/") { writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found") return @@ -360,11 +357,13 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } - // Resolve credential provider + // Resolve credential provider and user config var provider credentialProvider + var userConfig *option.CCMUser if len(s.options.Users) > 0 { + userConfig = s.userConfigMap[username] var err error - provider, err = credentialForUser(s.userCredentialMap, s.providers, s.legacyProvider, username) + provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) if err != nil { s.logger.Error("resolve credential: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) @@ -389,70 +388,48 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - credential, isNew, err := provider.selectCredential(sessionID) + var credentialFilter func(credential) bool + if userConfig != nil && !userConfig.AllowExternalUsage { + credentialFilter = func(c credential) bool { return !c.isExternal() } + } + + selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter) if err != nil { writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) return } if isNew { if username != "" { - s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID, " by user ", username) + s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID, " by user ", username) } else { - s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID) + s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID) } } - accessToken, err := credential.getAccessToken() - if err != nil { - s.logger.Error("get access token: ", err) - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed") + if isExtendedContextRequest(anthropicBetaHeader) && selectedCredential.isExternal() { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "extended context (1m) requests cannot be proxied through external credentials") return } - proxyURL := claudeAPIBaseURL + r.URL.RequestURI() - requestContext := credential.wrapRequestContext(r.Context()) + requestContext := selectedCredential.wrapRequestContext(r.Context()) defer func() { requestContext.cancelRequest() }() - proxyRequest, err := http.NewRequestWithContext(requestContext, r.Method, proxyURL, r.Body) + proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if err != nil { s.logger.Error("create proxy request: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") return } - for key, values := range r.Header { - if !isHopByHopHeader(key) && key != "Authorization" { - proxyRequest.Header[key] = values - } - } - - hasUsageTracker := credential.usageTracker != nil - serviceOverridesAcceptEncoding := len(s.httpHeaders.Values("Accept-Encoding")) > 0 - if hasUsageTracker && !serviceOverridesAcceptEncoding { - proxyRequest.Header.Del("Accept-Encoding") - } - - if anthropicBetaHeader != "" { - proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) - } else { - proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue) - } - - for key, values := range s.httpHeaders { - proxyRequest.Header.Del(key) - proxyRequest.Header[key] = values - } - - proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) - - response, err := credential.httpClient.Do(proxyRequest) + response, err := selectedCredential.httpTransport().Do(proxyRequest) if err != nil { if r.Context().Err() != nil { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, credential, "credential became unavailable while processing the request") + writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request") return } writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) @@ -463,24 +440,30 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Transparent 429 retry for response.StatusCode == http.StatusTooManyRequests { resetAt := parseRateLimitResetFromHeaders(response.Header) - nextCredential := provider.onRateLimited(sessionID, credential, resetAt) - credential.updateStateFromHeaders(response.Header) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) + selectedCredential.updateStateFromHeaders(response.Header) if bodyBytes == nil || nextCredential == nil { response.Body.Close() - writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited") + writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") return } response.Body.Close() - s.logger.Info("retrying with credential ", nextCredential.tag, " after 429 from ", credential.tag) + s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(r.Context()) - retryResponse, retryErr := retryRequestWithBody(requestContext, r, bodyBytes, nextCredential, s.httpHeaders) + retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if buildErr != nil { + s.logger.Error("retry request: ", buildErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) + return + } + retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest) if retryErr != nil { if r.Context().Err() != nil { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, nextCredential, "credential became unavailable while retrying the request") + writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") return } s.logger.Error("retry request: ", retryErr) @@ -489,21 +472,24 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } requestContext.releaseCredentialInterrupt() response = retryResponse - credential = nextCredential + selectedCredential = nextCredential } defer response.Body.Close() - credential.updateStateFromHeaders(response.Header) + selectedCredential.updateStateFromHeaders(response.Header) if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) - s.logger.Error("upstream error from ", credential.tag, ": status ", response.StatusCode, " ", string(body)) + s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) return } - hasUsageTracker = credential.usageTracker != nil + // Rewrite response headers for external users + if userConfig != nil && userConfig.ExternalCredential != "" { + s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) + } for key, values := range response.Header { if !isHopByHopHeader(key) { @@ -512,8 +498,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(response.StatusCode) - if hasUsageTracker && response.StatusCode == http.StatusOK { - s.handleResponseWithTracking(w, response, credential.usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) + usageTracker := selectedCredential.usageTrackerOrNil() + if usageTracker != nil && response.StatusCode == http.StatusOK { + s.handleResponseWithTracking(w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) } else { mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) if err == nil && mediaType != "text/event-stream" { @@ -693,6 +680,91 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons } } +func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") + return + } + + if len(s.options.Users) == 0 { + writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + username, ok := s.userManager.Authenticate(clientToken) + if !ok { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + + userConfig := s.userConfigMap[username] + if userConfig == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") + return + } + + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + + provider.pollIfStale(r.Context()) + avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]float64{ + "five_hour_utilization": avgFiveHour, + "weekly_utilization": avgWeekly, + }) +} + +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64) { + var totalFiveHour, totalWeekly float64 + var count int + for _, cred := range provider.allCredentials() { + // Exclude the user's own external_credential (their contribution to us) + if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { + continue + } + // If user doesn't allow external usage, exclude all external credentials + if !userConfig.AllowExternalUsage && cred.isExternal() { + continue + } + totalFiveHour += cred.fiveHourUtilization() + totalWeekly += cred.weeklyUtilization() + count++ + } + if count == 0 { + return 100, 100 + } + return totalFiveHour / float64(count), totalWeekly / float64(count) +} + +func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) { + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) + if err != nil { + return + } + + avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + + // Rewrite utilization headers to aggregated average (convert back to 0.0-1.0 range) + headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64)) + headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) +} + func (s *Service) Close() error { err := common.Close( common.PtrOrNil(s.httpServer), @@ -700,8 +772,8 @@ func (s *Service) Close() error { s.tlsConfig, ) - for _, credential := range s.allDefaults { - credential.close() + for _, cred := range s.allCredentials { + cred.close() } return err diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go new file mode 100644 index 0000000000..337fc6da3c --- /dev/null +++ b/service/ocm/credential_external.go @@ -0,0 +1,463 @@ +package ocm + +import ( + "bytes" + "context" + stdTLS "crypto/tls" + "encoding/json" + "io" + "net" + "net/http" + "net/url" + "strconv" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" +) + +type externalCredential struct { + tag string + baseURL string + token string + credDialer N.Dialer + httpClient *http.Client + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + usageTracker *AggregatedUsage + logger log.ContextLogger + + onBecameUnusable func() + interrupted bool + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex +} + +func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { + parsedURL, err := url.Parse(options.URL) + if err != nil { + return nil, E.Cause(err, "parse url for credential ", tag) + } + + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + + transport := &http.Transport{ + ForceAttemptHTTP2: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if options.Server != "" { + serverPort := options.ServerPort + if serverPort == 0 { + portStr := parsedURL.Port() + if portStr != "" { + port, parseErr := strconv.ParseUint(portStr, 10, 16) + if parseErr == nil { + serverPort = uint16(port) + } + } + if serverPort == 0 { + if parsedURL.Scheme == "https" { + serverPort = 443 + } else { + serverPort = 80 + } + } + } + destination := M.ParseSocksaddrHostPort(options.Server, serverPort) + return credentialDialer.DialContext(ctx, network, destination) + } + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + } + + if parsedURL.Scheme == "https" { + transport.TLSClientConfig = &stdTLS.Config{ + ServerName: parsedURL.Hostname(), + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + } + } + + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + if parsedURL.Path != "" && parsedURL.Path != "/" { + baseURL += parsedURL.Path + } + if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { + baseURL = baseURL[:len(baseURL)-1] + } + + pollInterval := time.Duration(options.PollInterval) + if pollInterval <= 0 { + pollInterval = 30 * time.Minute + } + + requestContext, cancelRequests := context.WithCancel(context.Background()) + + cred := &externalCredential{ + tag: tag, + baseURL: baseURL, + token: options.Token, + credDialer: credentialDialer, + httpClient: &http.Client{Transport: transport}, + pollInterval: pollInterval, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + } + + if options.UsagesPath != "" { + cred.usageTracker = &AggregatedUsage{ + LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + + return cred, nil +} + +func (c *externalCredential) start() error { + if c.usageTracker != nil { + err := c.usageTracker.Load() + if err != nil { + c.logger.Warn("load usage statistics for ", c.tag, ": ", err) + } + } + return nil +} + +func (c *externalCredential) setOnBecameUnusable(fn func()) { + c.onBecameUnusable = fn +} + +func (c *externalCredential) tagName() string { + return c.tag +} + +func (c *externalCredential) isExternal() bool { + return true +} + +func (c *externalCredential) isUsable() bool { + c.stateMutex.RLock() + if c.state.hardRateLimited { + if time.Now().Before(c.state.rateLimitResetAt) { + c.stateMutex.RUnlock() + return false + } + c.stateMutex.RUnlock() + c.stateMutex.Lock() + if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 + c.stateMutex.Unlock() + return usable + } + usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 + c.stateMutex.RUnlock() + return usable +} + +func (c *externalCredential) fiveHourUtilization() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.fiveHourUtilization +} + +func (c *externalCredential) weeklyUtilization() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyUtilization +} + +func (c *externalCredential) markRateLimited(resetAt time.Time) { + c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) + c.stateMutex.Lock() + c.state.hardRateLimited = true + c.state.rateLimitResetAt = resetAt + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *externalCredential) earliestReset() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + if c.state.hardRateLimited { + return c.state.rateLimitResetAt + } + earliest := c.state.fiveHourReset + if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) { + earliest = c.state.weeklyReset + } + return earliest +} + +func (c *externalCredential) getAccessToken() (string, error) { + return c.token, nil +} + +func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) { + proxyURL := c.baseURL + original.URL.RequestURI() + var body io.Reader + if bodyBytes != nil { + body = bytes.NewReader(bodyBytes) + } else { + body = original.Body + } + proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body) + if err != nil { + return nil, err + } + + for key, values := range original.Header { + if !isHopByHopHeader(key) && key != "Authorization" { + proxyRequest.Header[key] = values + } + } + + proxyRequest.Header.Set("Authorization", "Bearer "+c.token) + + return proxyRequest, nil +} + +func (c *externalCredential) updateStateFromHeaders(headers http.Header) { + c.stateMutex.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier == "" { + activeLimitIdentifier = "codex" + } + + fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") + if fiveHourPercent != "" { + value, err := strconv.ParseFloat(fiveHourPercent, 64) + if err == nil { + c.state.fiveHourUtilization = value + } + } + fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at") + if fiveHourResetAt != "" { + value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) + if err == nil { + c.state.fiveHourReset = time.Unix(value, 0) + } + } + weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") + if weeklyPercent != "" { + value, err := strconv.ParseFloat(weeklyPercent, 64) + if err == nil { + c.state.weeklyUtilization = value + } + } + weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at") + if weeklyResetAt != "" { + value, err := strconv.ParseInt(weeklyResetAt, 10, 64) + if err == nil { + c.state.weeklyReset = time.Unix(value, 0) + } + } + c.state.lastUpdated = time.Now() + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + } + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *externalCredential) checkTransitionLocked() bool { + unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 + if unusable && !c.interrupted { + c.interrupted = true + return true + } + if !unusable && c.interrupted { + c.interrupted = false + } + return false +} + +func (c *externalCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { + c.requestAccess.Lock() + credentialContext := c.requestContext + c.requestAccess.Unlock() + derived, cancel := context.WithCancel(parent) + stop := context.AfterFunc(credentialContext, func() { + cancel() + }) + return &credentialRequestContext{ + Context: derived, + releaseFunc: stop, + cancelFunc: cancel, + } +} + +func (c *externalCredential) interruptConnections() { + c.logger.Warn("interrupting connections for ", c.tag) + c.requestAccess.Lock() + c.cancelRequests() + c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) + c.requestAccess.Unlock() + if c.onBecameUnusable != nil { + c.onBecameUnusable() + } +} + +func (c *externalCredential) pollUsage(ctx context.Context) { + if !c.pollAccess.TryLock() { + return + } + defer c.pollAccess.Unlock() + defer c.markUsagePollAttempted() + + statusURL := c.baseURL + "/ocm/v1/status" + httpClient := &http.Client{ + Transport: c.httpClient.Transport, + Timeout: 5 * time.Second, + } + + request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) + if err != nil { + c.logger.Error("poll usage for ", c.tag, ": create request: ", err) + return + } + request.Header.Set("Authorization", "Bearer "+c.token) + + response, err := httpClient.Do(request) + if err != nil { + c.logger.Error("poll usage for ", c.tag, ": ", err) + c.stateMutex.Lock() + c.state.consecutivePollFailures++ + c.stateMutex.Unlock() + return + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + c.stateMutex.Lock() + c.state.consecutivePollFailures++ + c.stateMutex.Unlock() + c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + return + } + + var statusResponse struct { + FiveHourUtilization float64 `json:"five_hour_utilization"` + WeeklyUtilization float64 `json:"weekly_utilization"` + } + err = json.NewDecoder(response.Body).Decode(&statusResponse) + if err != nil { + c.stateMutex.Lock() + c.state.consecutivePollFailures++ + c.stateMutex.Unlock() + c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + return + } + + c.stateMutex.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + c.state.consecutivePollFailures = 0 + c.state.fiveHourUtilization = statusResponse.FiveHourUtilization + c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + } + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *externalCredential) lastUpdatedTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.lastUpdated +} + +func (c *externalCredential) markUsagePollAttempted() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.state.lastUpdated = time.Now() +} + +func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration { + c.stateMutex.RLock() + failures := c.state.consecutivePollFailures + c.stateMutex.RUnlock() + if failures <= 0 { + return baseInterval + } + if failures > 4 { + failures = 4 + } + return baseInterval * time.Duration(1< p.credential.pollBackoff(defaultPollInterval) { - p.credential.pollUsage(ctx) + if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { + p.cred.pollUsage(ctx) } } -func (p *singleCredentialProvider) allDefaults() []*defaultCredential { - return []*defaultCredential{p.credential} +func (p *singleCredentialProvider) allCredentials() []credential { + return []credential{p.cred} } func (p *singleCredentialProvider) close() {} @@ -567,7 +693,7 @@ type sessionEntry struct { } type balancerProvider struct { - credentials []*defaultCredential + credentials []credential strategy string roundRobinIndex atomic.Uint64 pollInterval time.Duration @@ -576,7 +702,7 @@ type balancerProvider struct { logger log.ContextLogger } -func newBalancerProvider(credentials []*defaultCredential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { +func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { if pollInterval <= 0 { pollInterval = defaultPollInterval } @@ -589,15 +715,15 @@ func newBalancerProvider(credentials []*defaultCredential, strategy string, poll } } -func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredential, bool, error) { +func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { if sessionID != "" { p.sessionMutex.RLock() entry, exists := p.sessions[sessionID] p.sessionMutex.RUnlock() if exists { - for _, credential := range p.credentials { - if credential.tag == entry.tag && credential.isUsable() { - return credential, false, nil + for _, cred := range p.credentials { + if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() { + return cred, false, nil } } p.sessionMutex.Lock() @@ -606,7 +732,7 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia } } - best := p.pickCredential() + best := p.pickCredential(filter) if best == nil { return nil, false, allRateLimitedError(p.credentials) } @@ -614,61 +740,67 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia isNew := sessionID != "" if isNew { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} p.sessionMutex.Unlock() } return best, isNew, nil } -func (p *balancerProvider) onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential { - credential.markRateLimited(resetAt) +func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential { + cred.markRateLimited(resetAt) if sessionID != "" { p.sessionMutex.Lock() delete(p.sessions, sessionID) p.sessionMutex.Unlock() } - best := p.pickCredential() + best := p.pickCredential(filter) if best != nil && sessionID != "" { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} p.sessionMutex.Unlock() } return best } -func (p *balancerProvider) pickCredential() *defaultCredential { +func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { switch p.strategy { case "round_robin": - return p.pickRoundRobin() + return p.pickRoundRobin(filter) case "random": - return p.pickRandom() + return p.pickRandom(filter) default: - return p.pickLeastUsed() + return p.pickLeastUsed(filter) } } -func (p *balancerProvider) pickLeastUsed() *defaultCredential { - var best *defaultCredential +func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { + var best credential bestUtilization := float64(101) - for _, credential := range p.credentials { - if !credential.isUsable() { + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if !cred.isUsable() { continue } - utilization := credential.weeklyUtilization() + utilization := cred.weeklyUtilization() if utilization < bestUtilization { bestUtilization = utilization - best = credential + best = cred } } return best } -func (p *balancerProvider) pickRoundRobin() *defaultCredential { +func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { start := int(p.roundRobinIndex.Add(1) - 1) count := len(p.credentials) for offset := range count { candidate := p.credentials[(start+offset)%count] + if filter != nil && !filter(candidate) { + continue + } if candidate.isUsable() { return candidate } @@ -676,9 +808,12 @@ func (p *balancerProvider) pickRoundRobin() *defaultCredential { return nil } -func (p *balancerProvider) pickRandom() *defaultCredential { - var usable []*defaultCredential +func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { + var usable []credential for _, candidate := range p.credentials { + if filter != nil && !filter(candidate) { + continue + } if candidate.isUsable() { usable = append(usable, candidate) } @@ -699,26 +834,26 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) { } p.sessionMutex.Unlock() - for _, credential := range p.credentials { - if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { - credential.pollUsage(ctx) + for _, cred := range p.credentials { + if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { + cred.pollUsage(ctx) } } } -func (p *balancerProvider) allDefaults() []*defaultCredential { +func (p *balancerProvider) allCredentials() []credential { return p.credentials } func (p *balancerProvider) close() {} type fallbackProvider struct { - credentials []*defaultCredential + credentials []credential pollInterval time.Duration logger log.ContextLogger } -func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider { +func newFallbackProvider(credentials []credential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider { if pollInterval <= 0 { pollInterval = defaultPollInterval } @@ -729,18 +864,24 @@ func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Dur } } -func (p *fallbackProvider) selectCredential(_ string) (*defaultCredential, bool, error) { - for _, credential := range p.credentials { - if credential.isUsable() { - return credential, false, nil +func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) { + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if cred.isUsable() { + return cred, false, nil } } return nil, false, allRateLimitedError(p.credentials) } -func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential { - credential.markRateLimited(resetAt) +func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential { + cred.markRateLimited(resetAt) for _, candidate := range p.credentials { + if filter != nil && !filter(candidate) { + continue + } if candidate.isUsable() { return candidate } @@ -749,23 +890,23 @@ func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential } func (p *fallbackProvider) pollIfStale(ctx context.Context) { - for _, credential := range p.credentials { - if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { - credential.pollUsage(ctx) + for _, cred := range p.credentials { + if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { + cred.pollUsage(ctx) } } } -func (p *fallbackProvider) allDefaults() []*defaultCredential { +func (p *fallbackProvider) allCredentials() []credential { return p.credentials } func (p *fallbackProvider) close() {} -func allRateLimitedError(credentials []*defaultCredential) error { +func allRateLimitedError(credentials []credential) error { var earliest time.Time - for _, credential := range credentials { - resetAt := credential.earliestReset() + for _, cred := range credentials { + resetAt := cred.earliestReset() if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { earliest = resetAt } @@ -780,34 +921,44 @@ func buildOCMCredentialProviders( ctx context.Context, options option.OCMServiceOptions, logger log.ContextLogger, -) (map[string]credentialProvider, []*defaultCredential, error) { - defaultCredentials := make(map[string]*defaultCredential) - var allDefaults []*defaultCredential +) (map[string]credentialProvider, []credential, error) { + allCredentialMap := make(map[string]credential) + var allCreds []credential providers := make(map[string]credentialProvider) + // Pass 1: create default and external credentials for _, credOpt := range options.Credentials { switch credOpt.Type { case "default": - credential, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) if err != nil { return nil, nil, err } - defaultCredentials[credOpt.Tag] = credential - allDefaults = append(allDefaults, credential) - providers[credOpt.Tag] = &singleCredentialProvider{credential: credential} + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + case "external": + cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) + if err != nil { + return nil, nil, err + } + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} } } + // Pass 2: create balancer and fallback providers for _, credOpt := range options.Credentials { switch credOpt.Type { case "balancer": - subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, defaultCredentials, credOpt.Tag) + subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) if err != nil { return nil, nil, err } providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger) case "fallback": - subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, defaultCredentials, credOpt.Tag) + subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, allCredentialMap, credOpt.Tag) if err != nil { return nil, nil, err } @@ -815,17 +966,17 @@ func buildOCMCredentialProviders( } } - return providers, allDefaults, nil + return providers, allCreds, nil } -func resolveCredentialTags(tags []string, defaults map[string]*defaultCredential, parentTag string) ([]*defaultCredential, error) { - credentials := make([]*defaultCredential, 0, len(tags)) +func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { + credentials := make([]credential, 0, len(tags)) for _, tag := range tags { - credential, exists := defaults[tag] + cred, exists := allCredentials[tag] if !exists { - return nil, E.New("credential ", parentTag, " references unknown default credential: ", tag) + return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) } - credentials = append(credentials, credential) + credentials = append(credentials, cred) } if len(credentials) == 0 { return nil, E.New("credential ", parentTag, " has no sub-credentials") @@ -871,24 +1022,34 @@ func validateOCMOptions(options option.OCMServiceOptions) error { if hasCredentials { tags := make(map[string]bool) - for _, credential := range options.Credentials { - if tags[credential.Tag] { - return E.New("duplicate credential tag: ", credential.Tag) + credentialTypes := make(map[string]string) + for _, cred := range options.Credentials { + if tags[cred.Tag] { + return E.New("duplicate credential tag: ", cred.Tag) } - tags[credential.Tag] = true - if credential.Type == "default" || credential.Type == "" { - if credential.DefaultOptions.Reserve5h > 99 { - return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") + tags[cred.Tag] = true + credentialTypes[cred.Tag] = cred.Type + if cred.Type == "default" || cred.Type == "" { + if cred.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") } - if credential.DefaultOptions.ReserveWeekly > 99 { - return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") + if cred.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") } } - if credential.Type == "balancer" { - switch credential.BalancerOptions.Strategy { + if cred.Type == "external" { + if cred.ExternalOptions.URL == "" { + return E.New("credential ", cred.Tag, ": external credential requires url") + } + if cred.ExternalOptions.Token == "" { + return E.New("credential ", cred.Tag, ": external credential requires token") + } + } + if cred.Type == "balancer" { + switch cred.BalancerOptions.Strategy { case "", "least_used", "round_robin", "random": default: - return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) + return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) } } } @@ -900,6 +1061,14 @@ func validateOCMOptions(options option.OCMServiceOptions) error { if !tags[user.Credential] { return E.New("user ", user.Name, " references unknown credential: ", user.Credential) } + if user.ExternalCredential != "" { + if !tags[user.ExternalCredential] { + return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) + } + if credentialTypes[user.ExternalCredential] != "external" { + return E.New("user ", user.Name, ": external_credential must reference an external type credential") + } + } } } @@ -910,21 +1079,21 @@ func validateOCMCompositeCredentialModes( options option.OCMServiceOptions, providers map[string]credentialProvider, ) error { - for _, credential := range options.Credentials { - if credential.Type != "balancer" && credential.Type != "fallback" { + for _, credOpt := range options.Credentials { + if credOpt.Type != "balancer" && credOpt.Type != "fallback" { continue } - provider, exists := providers[credential.Tag] + provider, exists := providers[credOpt.Tag] if !exists { - return E.New("unknown credential: ", credential.Tag) + return E.New("unknown credential: ", credOpt.Tag) } - for _, subCredential := range provider.allDefaults() { - if subCredential.isAPIKeyMode() { + for _, subCred := range provider.allCredentials() { + if subCred.ocmIsAPIKeyMode() { return E.New( - "credential ", credential.Tag, - " references API key default credential ", subCredential.tag, + "credential ", credOpt.Tag, + " references API key default credential ", subCred.tagName(), "; balancer and fallback only support OAuth default credentials", ) } @@ -934,60 +1103,8 @@ func validateOCMCompositeCredentialModes( return nil } -func retryOCMRequestWithBody( - ctx context.Context, - originalRequest *http.Request, - bodyBytes []byte, - credential *defaultCredential, - httpHeaders http.Header, -) (*http.Response, error) { - accessToken, err := credential.getAccessToken() - if err != nil { - return nil, E.Cause(err, "get access token for ", credential.tag) - } - - baseURL := credential.getBaseURL() - path := originalRequest.URL.Path - var proxyPath string - if credential.isAPIKeyMode() { - proxyPath = path - } else { - proxyPath = strings.TrimPrefix(path, "/v1") - } - - proxyURL := baseURL + proxyPath - if originalRequest.URL.RawQuery != "" { - proxyURL += "?" + originalRequest.URL.RawQuery - } - - var body io.Reader - if bodyBytes != nil { - body = bytes.NewReader(bodyBytes) - } - retryRequest, err := http.NewRequestWithContext(ctx, originalRequest.Method, proxyURL, body) - if err != nil { - return nil, err - } - - for key, values := range originalRequest.Header { - if !isHopByHopHeader(key) && key != "Authorization" { - retryRequest.Header[key] = values - } - } - for key, values := range httpHeaders { - retryRequest.Header.Del(key) - retryRequest.Header[key] = values - } - retryRequest.Header.Set("Authorization", "Bearer "+accessToken) - if accountID := credential.getAccountID(); accountID != "" { - retryRequest.Header.Set("ChatGPT-Account-Id", accountID) - } - - return credential.httpClient.Do(retryRequest) -} - func credentialForUser( - userCredentialMap map[string]string, + userConfigMap map[string]*option.OCMUser, providers map[string]credentialProvider, legacyProvider credentialProvider, username string, @@ -995,13 +1112,13 @@ func credentialForUser( if legacyProvider != nil { return legacyProvider, nil } - tag, exists := userCredentialMap[username] + userConfig, exists := userConfigMap[username] if !exists { return nil, E.New("no credential mapping for user: ", username) } - provider, exists := providers[tag] + provider, exists := providers[userConfig.Credential] if !exists { - return nil, E.New("unknown credential: ", tag) + return nil, E.New("unknown credential: ", userConfig.Credential) } return provider, nil } diff --git a/service/ocm/service.go b/service/ocm/service.go index 75f28f2c1a..38527abde8 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -74,15 +74,18 @@ const ( retryableUsageCode = "credential_usage_exhausted" ) -func hasAlternativeCredential(provider credentialProvider, currentCredential *defaultCredential) bool { +func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool { if provider == nil || currentCredential == nil { return false } - for _, credential := range provider.allDefaults() { - if credential == currentCredential { + for _, cred := range provider.allCredentials() { + if cred == currentCredential { continue } - if credential.isUsable() { + if filter != nil && !filter(cred) { + continue + } + if cred.isUsable() { return true } } @@ -93,7 +96,7 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string) if provider == nil { return fallback } - return allRateLimitedError(provider.allDefaults()).Error() + return allRateLimitedError(provider.allCredentials()).Error() } func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) { @@ -108,10 +111,11 @@ func writeCredentialUnavailableError( w http.ResponseWriter, r *http.Request, provider credentialProvider, - currentCredential *defaultCredential, + currentCredential credential, + filter func(credential) bool, fallback string, ) { - if hasAlternativeCredential(provider, currentCredential) { + if hasAlternativeCredential(provider, currentCredential, filter) { writeRetryableUsageError(w, r) return } @@ -198,9 +202,9 @@ type Service struct { legacyProvider credentialProvider // Multi-credential mode - providers map[string]credentialProvider - allDefaults []*defaultCredential - userCredentialMap map[string]string + providers map[string]credentialProvider + allCredentials []credential + userConfigMap map[string]*option.OCMUser } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) { @@ -230,20 +234,20 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio } if len(options.Credentials) > 0 { - providers, allDefaults, err := buildOCMCredentialProviders(ctx, options, logger) + providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger) if err != nil { return nil, E.Cause(err, "build credential providers") } service.providers = providers - service.allDefaults = allDefaults + service.allCredentials = allCredentials - userCredentialMap := make(map[string]string) - for _, user := range options.Users { - userCredentialMap[user.Name] = user.Credential + userConfigMap := make(map[string]*option.OCMUser) + for i := range options.Users { + userConfigMap[options.Users[i].Name] = &options.Users[i] } - service.userCredentialMap = userCredentialMap + service.userConfigMap = userConfigMap } else { - credential, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{ + cred, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{ CredentialPath: options.CredentialPath, UsagesPath: options.UsagesPath, Detour: options.Detour, @@ -251,9 +255,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio if err != nil { return nil, err } - service.legacyCredential = credential - service.legacyProvider = &singleCredentialProvider{credential: credential} - service.allDefaults = []*defaultCredential{credential} + service.legacyCredential = cred + service.legacyProvider = &singleCredentialProvider{cred: cred} + service.allCredentials = []credential{cred} } if options.TLS != nil { @@ -274,15 +278,15 @@ func (s *Service) Start(stage adapter.StartStage) error { s.userManager.UpdateUsers(s.options.Users) - for _, credential := range s.allDefaults { - err := credential.start() + for _, cred := range s.allCredentials { + err := cred.start() if err != nil { return err } - tag := credential.tag - credential.onBecameUnusable = func() { + tag := cred.tagName() + cred.setOnBecameUnusable(func() { s.interruptWebSocketSessionsForCredential(tag) - } + }) } if len(s.options.Credentials) > 0 { err := validateOCMCompositeCredentialModes(s.options, s.providers) @@ -327,7 +331,7 @@ func (s *Service) Start(stage adapter.StartStage) error { func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) { if len(s.options.Users) > 0 { - return credentialForUser(s.userCredentialMap, s.providers, s.legacyProvider, username) + return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) } provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options) if provider == nil { @@ -337,6 +341,11 @@ func (s *Service) resolveCredentialProvider(username string) (credentialProvider } func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ocm/v1/status" { + s.handleStatusEndpoint(w, r) + return + } + path := r.URL.Path if !strings.HasPrefix(path, "/v1/") { writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/") @@ -368,49 +377,64 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { sessionID := r.Header.Get("session_id") - // Resolve credential provider - provider, err := s.resolveCredentialProvider(username) - if err != nil { - s.logger.Error("resolve credential: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + // Resolve credential provider and user config + var provider credentialProvider + var userConfig *option.OCMUser + if len(s.options.Users) > 0 { + userConfig = s.userConfigMap[username] + var err error + provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + s.logger.Error("resolve credential: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + } else { + provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + } + if provider == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") return } provider.pollIfStale(s.ctx) - credential, isNew, err := provider.selectCredential(sessionID) + var credentialFilter func(credential) bool + if userConfig != nil && !userConfig.AllowExternalUsage { + credentialFilter = func(c credential) bool { return !c.isExternal() } + } + + selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter) if err != nil { writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error())) return } if isNew { if username != "" { - s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID, " by user ", username) + s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID, " by user ", username) } else { - s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID) + s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID) } } if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { - s.handleWebSocket(w, r, path, username, sessionID, provider, credential) + s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter) return } - var proxyPath string - if credential.isAPIKeyMode() { - proxyPath = path - } else { + if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() { + // API key mode path handling + } else if !selectedCredential.isExternal() { if path == "/v1/chat/completions" { writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "chat completions endpoint is only available in API key mode") return } - proxyPath = strings.TrimPrefix(path, "/v1") } - shouldTrackUsage := credential.usageTracker != nil && + shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil && (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) - canRetryRequest := len(provider.allDefaults()) > 1 + canRetryRequest := len(provider.allCredentials()) > 1 // Read body for model extraction and retry buffer when JSON replay is useful. var bodyBytes []byte @@ -435,52 +459,24 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - accessToken, err := credential.getAccessToken() - if err != nil { - s.logger.Error("get access token: ", err) - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed") - return - } - - proxyURL := credential.getBaseURL() + proxyPath - if r.URL.RawQuery != "" { - proxyURL += "?" + r.URL.RawQuery - } - requestContext := credential.wrapRequestContext(r.Context()) + requestContext := selectedCredential.wrapRequestContext(r.Context()) defer func() { requestContext.cancelRequest() }() - proxyRequest, err := http.NewRequestWithContext(requestContext, r.Method, proxyURL, r.Body) + proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if err != nil { s.logger.Error("create proxy request: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") return } - for key, values := range r.Header { - if !isHopByHopHeader(key) && key != "Authorization" { - proxyRequest.Header[key] = values - } - } - - for key, values := range s.httpHeaders { - proxyRequest.Header.Del(key) - proxyRequest.Header[key] = values - } - - proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) - - if accountID := credential.getAccountID(); accountID != "" { - proxyRequest.Header.Set("ChatGPT-Account-Id", accountID) - } - - response, err := credential.httpClient.Do(proxyRequest) + response, err := selectedCredential.httpTransport().Do(proxyRequest) if err != nil { if r.Context().Err() != nil { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, credential, "credential became unavailable while processing the request") + writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request") return } writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) @@ -491,25 +487,31 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Transparent 429 retry for response.StatusCode == http.StatusTooManyRequests { resetAt := parseOCMRateLimitResetFromHeaders(response.Header) - nextCredential := provider.onRateLimited(sessionID, credential, resetAt) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete - credential.updateStateFromHeaders(response.Header) + selectedCredential.updateStateFromHeaders(response.Header) if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil { response.Body.Close() - writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited") + writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") return } response.Body.Close() - s.logger.Info("retrying with credential ", nextCredential.tag, " after 429 from ", credential.tag) + s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(r.Context()) - retryResponse, retryErr := retryOCMRequestWithBody(requestContext, r, bodyBytes, nextCredential, s.httpHeaders) + retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if buildErr != nil { + s.logger.Error("retry request: ", buildErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) + return + } + retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest) if retryErr != nil { if r.Context().Err() != nil { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, nextCredential, "credential became unavailable while retrying the request") + writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") return } s.logger.Error("retry request: ", retryErr) @@ -518,20 +520,25 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } requestContext.releaseCredentialInterrupt() response = retryResponse - credential = nextCredential + selectedCredential = nextCredential } defer response.Body.Close() - credential.updateStateFromHeaders(response.Header) + selectedCredential.updateStateFromHeaders(response.Header) if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) - s.logger.Error("upstream error from ", credential.tag, ": status ", response.StatusCode, " ", string(body)) + s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) return } + // Rewrite response headers for external users + if userConfig != nil && userConfig.ExternalCredential != "" { + s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) + } + for key, values := range response.Header { if !isHopByHopHeader(key) { w.Header()[key] = values @@ -539,10 +546,10 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(response.StatusCode) - hasUsageTracker := credential.usageTracker != nil - if hasUsageTracker && response.StatusCode == http.StatusOK && + usageTracker := selectedCredential.usageTrackerOrNil() + if usageTracker != nil && response.StatusCode == http.StatusOK && (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) { - s.handleResponseWithTracking(w, response, credential.usageTracker, path, requestModel, username) + s.handleResponseWithTracking(w, response, usageTracker, path, requestModel, username) } else { mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) if err == nil && mediaType != "text/event-stream" { @@ -745,6 +752,93 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons } } +func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") + return + } + + if len(s.options.Users) == 0 { + writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + username, ok := s.userManager.Authenticate(clientToken) + if !ok { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + + userConfig := s.userConfigMap[username] + if userConfig == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") + return + } + + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + + provider.pollIfStale(r.Context()) + avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]float64{ + "five_hour_utilization": avgFiveHour, + "weekly_utilization": avgWeekly, + }) +} + +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64) { + var totalFiveHour, totalWeekly float64 + var count int + for _, cred := range provider.allCredentials() { + if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { + continue + } + if !userConfig.AllowExternalUsage && cred.isExternal() { + continue + } + totalFiveHour += cred.fiveHourUtilization() + totalWeekly += cred.weeklyUtilization() + count++ + } + if count == 0 { + return 100, 100 + } + return totalFiveHour / float64(count), totalWeekly / float64(count) +} + +func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) { + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) + if err != nil { + return + } + + avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier == "" { + activeLimitIdentifier = "codex" + } + + headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64)) + headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) +} + func (s *Service) Close() error { webSocketSessions := s.startWebSocketShutdown() @@ -758,8 +852,8 @@ func (s *Service) Close() error { } s.webSocketGroup.Wait() - for _, credential := range s.allDefaults { - credential.close() + for _, cred := range s.allCredentials { + cred.close() } return err diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index eafd37aaed..eeb0380560 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -14,6 +14,7 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/ntp" @@ -85,8 +86,10 @@ func (s *Service) handleWebSocket( path string, username string, sessionID string, + userConfig *option.OCMUser, provider credentialProvider, - credential *defaultCredential, + selectedCredential credential, + credentialFilter func(credential) bool, ) { var ( err error @@ -97,7 +100,7 @@ func (s *Service) handleWebSocket( ) for { - accessToken, accessErr := credential.getAccessToken() + accessToken, accessErr := selectedCredential.getAccessToken() if accessErr != nil { s.logger.Error("get access token for websocket: ", accessErr) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed") @@ -105,13 +108,13 @@ func (s *Service) handleWebSocket( } var proxyPath string - if credential.isAPIKeyMode() { + if selectedCredential.ocmIsAPIKeyMode() || selectedCredential.isExternal() { proxyPath = path } else { proxyPath = strings.TrimPrefix(path, "/v1") } - upstreamURL := buildUpstreamWebSocketURL(credential.getBaseURL(), proxyPath) + upstreamURL := buildUpstreamWebSocketURL(selectedCredential.ocmGetBaseURL(), proxyPath) if r.URL.RawQuery != "" { upstreamURL += "?" + r.URL.RawQuery } @@ -127,7 +130,7 @@ func (s *Service) handleWebSocket( upstreamHeaders[key] = values } upstreamHeaders.Set("Authorization", "Bearer "+accessToken) - if accountID := credential.getAccountID(); accountID != "" { + if accountID := selectedCredential.ocmGetAccountID(); accountID != "" { upstreamHeaders.Set("ChatGPT-Account-Id", accountID) } @@ -135,7 +138,7 @@ func (s *Service) handleWebSocket( statusCode = 0 upstreamDialer := ws.Dialer{ NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { - return credential.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + return selectedCredential.ocmDialer().DialContext(ctx, network, M.ParseSocksaddr(addr)) }, TLSConfig: &stdTLS.Config{ RootCAs: adapter.RootPoolFromContext(s.ctx), @@ -170,14 +173,14 @@ func (s *Service) handleWebSocket( } if statusCode == http.StatusTooManyRequests { resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders) - nextCredential := provider.onRateLimited(sessionID, credential, resetAt) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) if nextCredential == nil { - credential.updateStateFromHeaders(upstreamResponseHeaders) - writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited") + selectedCredential.updateStateFromHeaders(upstreamResponseHeaders) + writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") return } - s.logger.Info("retrying websocket with credential ", nextCredential.tag, " after 429 from ", credential.tag) - credential = nextCredential + s.logger.Info("retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + selectedCredential = nextCredential continue } s.logger.Error("dial upstream websocket: ", err) @@ -185,15 +188,18 @@ func (s *Service) handleWebSocket( return } - credential.updateStateFromHeaders(upstreamResponseHeaders) + selectedCredential.updateStateFromHeaders(upstreamResponseHeaders) weeklyCycleHint := extractWeeklyCycleHint(upstreamResponseHeaders) clientResponseHeaders := make(http.Header) for key, values := range upstreamResponseHeaders { if isForwardableResponseHeader(key) { - clientResponseHeaders[key] = values + clientResponseHeaders[key] = append([]string(nil), values...) } } + if userConfig != nil && userConfig.ExternalCredential != "" { + s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, userConfig) + } clientUpgrader := ws.HTTPUpgrader{ Header: clientResponseHeaders, @@ -212,7 +218,7 @@ func (s *Service) handleWebSocket( session := &webSocketSession{ clientConn: clientConn, upstreamConn: upstreamConn, - credentialTag: credential.tag, + credentialTag: selectedCredential.tagName(), } if !s.registerWebSocketSession(session) { session.Close() @@ -237,17 +243,17 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, credential, modelChannel) + s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel) }() go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, credential, modelChannel, username, weeklyCycleHint) + s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, modelChannel, username, weeklyCycleHint) }() waitGroup.Wait() } -func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, credential *defaultCredential, modelChannel chan<- string) { +func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string) { for { data, opCode, err := wsutil.ReadClientData(clientConn) if err != nil { @@ -257,7 +263,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo return } - if opCode == ws.OpText && credential.usageTracker != nil { + if opCode == ws.OpText && selectedCredential.usageTrackerOrNil() != nil { var request struct { Type string `json:"type"` Model string `json:"model"` @@ -280,7 +286,8 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo } } -func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, credential *defaultCredential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { + usageTracker := selectedCredential.usageTrackerOrNil() var requestModel string for { data, opCode, err := wsutil.ReadServerData(upstreamReadWriter) @@ -291,7 +298,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite return } - if opCode == ws.OpText && credential.usageTracker != nil { + if opCode == ws.OpText && usageTracker != nil { select { case model := <-modelChannel: requestModel = model @@ -317,7 +324,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite } if responseModel != "" { contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - credential.usageTracker.AddUsageWithCycleHint( + usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, inputTokens, From 8e5811a8c79c3b47b0a2363c8b808d0c967c5d4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 02:21:34 +0800 Subject: [PATCH 09/96] ccm,ocm: watch credential_path and allow delayed credentials --- docs/configuration/service/ccm.md | 6 +- docs/configuration/service/ccm.zh.md | 6 +- docs/configuration/service/ocm.md | 4 +- docs/configuration/service/ocm.zh.md | 4 +- service/ccm/credential.go | 22 +++ service/ccm/credential_external.go | 8 ++ service/ccm/credential_file.go | 141 +++++++++++++++++++ service/ccm/credential_state.go | 164 +++++++++++++++++----- service/ccm/service.go | 9 +- service/ocm/credential.go | 38 +++++ service/ocm/credential_external.go | 8 ++ service/ocm/credential_file.go | 139 +++++++++++++++++++ service/ocm/credential_state.go | 198 ++++++++++++++++++++++----- service/ocm/service.go | 9 +- 14 files changed, 688 insertions(+), 68 deletions(-) create mode 100644 service/ccm/credential_file.go create mode 100644 service/ocm/credential_file.go diff --git a/docs/configuration/service/ccm.md b/docs/configuration/service/ccm.md index 59ef5f7c0d..691782968a 100644 --- a/docs/configuration/service/ccm.md +++ b/docs/configuration/service/ccm.md @@ -51,6 +51,10 @@ On macOS, credentials are read from the system keychain first, then fall back to Refreshed tokens are automatically written back to the same location. +When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid. + +On macOS without an explicit `credential_path`, keychain changes are not watched. Automatic reload only applies to the credential file path. + Conflict with `credentials`. #### credentials @@ -76,7 +80,7 @@ Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a } ``` -A single OAuth credential file. The `type` field can be omitted (defaults to `default`). +A single OAuth credential file. The `type` field can be omitted (defaults to `default`). The service can start before the file exists, and reloads file updates automatically. - `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`. - `usages_path`: Optional usage tracking file for this credential. diff --git a/docs/configuration/service/ccm.zh.md b/docs/configuration/service/ccm.zh.md index d9496986a7..f555fc4d2d 100644 --- a/docs/configuration/service/ccm.zh.md +++ b/docs/configuration/service/ccm.zh.md @@ -51,6 +51,10 @@ Claude Code OAuth 凭据文件的路径。 刷新的令牌会自动写回相同位置。 +当 `credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。 + +在 macOS 上如果未显式设置 `credential_path`,不会监听钥匙串变化。自动重载只作用于凭据文件路径。 + 与 `credentials` 冲突。 #### credentials @@ -76,7 +80,7 @@ Claude Code OAuth 凭据文件的路径。 } ``` -单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。 +单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。即使文件尚不存在,服务也可以启动,并会自动重载文件更新。 - `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。 - `usages_path`:此凭据的可选使用跟踪文件。 diff --git a/docs/configuration/service/ocm.md b/docs/configuration/service/ocm.md index 8dfd0e99ed..4c63de0f78 100644 --- a/docs/configuration/service/ocm.md +++ b/docs/configuration/service/ocm.md @@ -49,6 +49,8 @@ If not specified, defaults to: Refreshed tokens are automatically written back to the same location. +When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid. + Conflict with `credentials`. #### credentials @@ -74,7 +76,7 @@ Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a } ``` -A single OAuth credential file. The `type` field can be omitted (defaults to `default`). +A single OAuth credential file. The `type` field can be omitted (defaults to `default`). The service can start before the file exists, and reloads file updates automatically. - `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`. - `usages_path`: Optional usage tracking file for this credential. diff --git a/docs/configuration/service/ocm.zh.md b/docs/configuration/service/ocm.zh.md index ee4ffa633c..81b222ef9f 100644 --- a/docs/configuration/service/ocm.zh.md +++ b/docs/configuration/service/ocm.zh.md @@ -49,6 +49,8 @@ OpenAI OAuth 凭据文件的路径。 刷新的令牌会自动写回相同位置。 +当 `credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。 + 与 `credentials` 冲突。 #### credentials @@ -74,7 +76,7 @@ OpenAI OAuth 凭据文件的路径。 } ``` -单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。 +单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。即使文件尚不存在,服务也可以启动,并会自动重载文件更新。 - `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。 - `usages_path`:此凭据的可选使用跟踪文件。 diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 0fe5e2b970..f14d4d2bc0 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -9,6 +9,7 @@ import ( "os/user" "path/filepath" "runtime" + "slices" "sync" "time" @@ -189,3 +190,24 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut return &newCredentials, nil } + +func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { + if credentials == nil { + return nil + } + cloned := *credentials + cloned.Scopes = append([]string(nil), credentials.Scopes...) + return &cloned +} + +func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { + if left == nil || right == nil { + return left == right + } + return left.AccessToken == right.AccessToken && + left.RefreshToken == right.RefreshToken && + left.ExpiresAt == right.ExpiresAt && + slices.Equal(left.Scopes, right.Scopes) && + left.SubscriptionType == right.SubscriptionType && + left.IsMax == right.IsMax +} diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index c4d2d340f7..456ba4fcbe 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -151,6 +151,10 @@ func (c *externalCredential) isExternal() bool { return true } +func (c *externalCredential) isAvailable() bool { + return true +} + func (c *externalCredential) isUsable() bool { c.stateMutex.RLock() if c.state.hardRateLimited { @@ -210,6 +214,10 @@ func (c *externalCredential) earliestReset() time.Time { return earliest } +func (c *externalCredential) unavailableError() error { + return nil +} + func (c *externalCredential) getAccessToken() (string, error) { return c.token, nil } diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go new file mode 100644 index 0000000000..da13fae10e --- /dev/null +++ b/service/ccm/credential_file.go @@ -0,0 +1,141 @@ +package ccm + +import ( + "path/filepath" + "time" + + "github.com/sagernet/fswatch" + E "github.com/sagernet/sing/common/exceptions" +) + +const credentialReloadRetryInterval = 2 * time.Second + +func resolveCredentialFilePath(customPath string) (string, error) { + if customPath == "" { + var err error + customPath, err = getDefaultCredentialsPath() + if err != nil { + return "", err + } + } + if filepath.IsAbs(customPath) { + return customPath, nil + } + return filepath.Abs(customPath) +} + +func (c *defaultCredential) ensureCredentialWatcher() error { + c.watcherAccess.Lock() + defer c.watcherAccess.Unlock() + + if c.watcher != nil || c.credentialFilePath == "" { + return nil + } + if !c.watcherRetryAt.IsZero() && time.Now().Before(c.watcherRetryAt) { + return nil + } + + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: []string{c.credentialFilePath}, + Logger: c.logger, + Callback: func(string) { + err := c.reloadCredentials(true) + if err != nil { + c.logger.Warn("reload credentials for ", c.tag, ": ", err) + } + }, + }) + if err != nil { + c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval) + return err + } + + err = watcher.Start() + if err != nil { + c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval) + return err + } + + c.watcher = watcher + c.watcherRetryAt = time.Time{} + return nil +} + +func (c *defaultCredential) retryCredentialReloadIfNeeded() { + c.stateMutex.RLock() + unavailable := c.state.unavailable + lastAttempt := c.state.lastCredentialLoadAttempt + c.stateMutex.RUnlock() + if !unavailable { + return + } + if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval { + return + } + + err := c.ensureCredentialWatcher() + if err != nil { + c.logger.Debug("start credential watcher for ", c.tag, ": ", err) + } + _ = c.reloadCredentials(false) +} + +func (c *defaultCredential) reloadCredentials(force bool) error { + c.reloadAccess.Lock() + defer c.reloadAccess.Unlock() + + c.stateMutex.RLock() + unavailable := c.state.unavailable + lastAttempt := c.state.lastCredentialLoadAttempt + c.stateMutex.RUnlock() + if !force { + if !unavailable { + return nil + } + if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval { + return c.unavailableError() + } + } + + c.stateMutex.Lock() + c.state.lastCredentialLoadAttempt = time.Now() + c.stateMutex.Unlock() + + credentials, err := platformReadCredentials(c.credentialPath) + if err != nil { + return c.markCredentialsUnavailable(E.Cause(err, "read credentials")) + } + + c.accessMutex.Lock() + c.credentials = credentials + c.accessMutex.Unlock() + + c.stateMutex.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadError = "" + c.state.accountType = credentials.SubscriptionType + c.checkTransitionLocked() + c.stateMutex.Unlock() + + return nil +} + +func (c *defaultCredential) markCredentialsUnavailable(err error) error { + c.accessMutex.Lock() + hadCredentials := c.credentials != nil + c.credentials = nil + c.accessMutex.Unlock() + + c.stateMutex.Lock() + c.state.unavailable = true + c.state.lastCredentialLoadError = err.Error() + c.state.accountType = "" + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + + if shouldInterrupt && hadCredentials { + c.interruptConnections() + } + + return err +} diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index f17fe98f96..9371295972 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -17,6 +17,7 @@ import ( "sync/atomic" "time" + "github.com/sagernet/fswatch" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/log" @@ -29,30 +30,38 @@ import ( const defaultPollInterval = 60 * time.Minute type credentialState struct { - fiveHourUtilization float64 - fiveHourReset time.Time - weeklyUtilization float64 - weeklyReset time.Time - hardRateLimited bool - rateLimitResetAt time.Time - accountType string - lastUpdated time.Time - consecutivePollFailures int + fiveHourUtilization float64 + fiveHourReset time.Time + weeklyUtilization float64 + weeklyReset time.Time + hardRateLimited bool + rateLimitResetAt time.Time + accountType string + lastUpdated time.Time + consecutivePollFailures int + unavailable bool + lastCredentialLoadAttempt time.Time + lastCredentialLoadError string } type defaultCredential struct { - tag string - credentialPath string - credentials *oauthCredentials - accessMutex sync.RWMutex - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - reserve5h uint8 - reserveWeekly uint8 - usageTracker *AggregatedUsage - httpClient *http.Client - logger log.ContextLogger + tag string + credentialPath string + credentialFilePath string + credentials *oauthCredentials + accessMutex sync.RWMutex + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + reloadAccess sync.Mutex + watcherAccess sync.Mutex + reserve5h uint8 + reserveWeekly uint8 + usageTracker *AggregatedUsage + httpClient *http.Client + logger log.ContextLogger + watcher *fswatch.Watcher + watcherRetryAt time.Time // Connection interruption onBecameUnusable func() @@ -83,12 +92,14 @@ func (c *credentialRequestContext) cancelRequest() { type credential interface { tagName() string + isAvailable() bool isUsable() bool isExternal() bool fiveHourUtilization() float64 weeklyUtilization() float64 markRateLimited(resetAt time.Time) earliestReset() time.Time + unavailableError() error getAccessToken() (string, error) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) @@ -160,13 +171,18 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef } func (c *defaultCredential) start() error { - credentials, err := platformReadCredentials(c.credentialPath) + credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) if err != nil { - return E.Cause(err, "read credentials for ", c.tag) + return E.Cause(err, "resolve credential path for ", c.tag) } - c.credentials = credentials - if credentials.SubscriptionType != "" { - c.state.accountType = credentials.SubscriptionType + c.credentialFilePath = credentialFilePath + err = c.ensureCredentialWatcher() + if err != nil { + c.logger.Debug("start credential watcher for ", c.tag, ": ", err) + } + err = c.reloadCredentials(true) + if err != nil { + c.logger.Warn("initial credential load for ", c.tag, ": ", err) } if c.usageTracker != nil { err = c.usageTracker.Load() @@ -178,33 +194,68 @@ func (c *defaultCredential) start() error { } func (c *defaultCredential) getAccessToken() (string, error) { + c.retryCredentialReloadIfNeeded() + c.accessMutex.RLock() - if !c.credentials.needsRefresh() { + if c.credentials != nil && !c.credentials.needsRefresh() { token := c.credentials.AccessToken c.accessMutex.RUnlock() return token, nil } c.accessMutex.RUnlock() + err := c.reloadCredentials(true) + if err == nil { + c.accessMutex.RLock() + if c.credentials != nil && !c.credentials.needsRefresh() { + token := c.credentials.AccessToken + c.accessMutex.RUnlock() + return token, nil + } + c.accessMutex.RUnlock() + } + c.accessMutex.Lock() defer c.accessMutex.Unlock() + if c.credentials == nil { + return "", c.unavailableError() + } if !c.credentials.needsRefresh() { return c.credentials.AccessToken, nil } + baseCredentials := cloneCredentials(c.credentials) newCredentials, err := refreshToken(c.httpClient, c.credentials) if err != nil { return "", err } - c.credentials = newCredentials - if newCredentials.SubscriptionType != "" { + latestCredentials, latestErr := platformReadCredentials(c.credentialPath) + if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { + c.credentials = latestCredentials c.stateMutex.Lock() - c.state.accountType = newCredentials.SubscriptionType + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.state.accountType = latestCredentials.SubscriptionType + c.checkTransitionLocked() c.stateMutex.Unlock() + if !latestCredentials.needsRefresh() { + return latestCredentials.AccessToken, nil + } + return "", E.New("credential ", c.tag, " changed while refreshing") } + c.credentials = newCredentials + c.stateMutex.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.state.accountType = newCredentials.SubscriptionType + c.checkTransitionLocked() + c.stateMutex.Unlock() + err = platformWriteCredentials(newCredentials, c.credentialPath) if err != nil { c.logger.Warn("persist refreshed token for ", c.tag, ": ", err) @@ -299,7 +350,13 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) { } func (c *defaultCredential) isUsable() bool { + c.retryCredentialReloadIfNeeded() + c.stateMutex.RLock() + if c.state.unavailable { + c.stateMutex.RUnlock() + return false + } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { c.stateMutex.RUnlock() @@ -332,7 +389,7 @@ func (c *defaultCredential) checkReservesLocked() bool { // checkTransitionLocked detects usable→unusable transition. // Must be called with stateMutex write lock held. func (c *defaultCredential) checkTransitionLocked() bool { - unusable := c.state.hardRateLimited || !c.checkReservesLocked() + unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() if unusable && !c.interrupted { c.interrupted = true return true @@ -375,6 +432,26 @@ func (c *defaultCredential) weeklyUtilization() float64 { return c.state.weeklyUtilization } +func (c *defaultCredential) isAvailable() bool { + c.retryCredentialReloadIfNeeded() + + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return !c.state.unavailable +} + +func (c *defaultCredential) unavailableError() error { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + if !c.state.unavailable { + return nil + } + if c.state.lastCredentialLoadError == "" { + return E.New("credential ", c.tag, " is unavailable") + } + return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError) +} + func (c *defaultCredential) lastUpdatedTime() time.Time { c.stateMutex.RLock() defer c.stateMutex.RUnlock() @@ -403,6 +480,9 @@ func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duratio func (c *defaultCredential) earliestReset() time.Time { c.stateMutex.RLock() defer c.stateMutex.RUnlock() + if c.state.unavailable { + return time.Time{} + } if c.state.hardRateLimited { return c.state.rateLimitResetAt } @@ -430,6 +510,11 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { defer c.pollAccess.Unlock() defer c.markUsagePollAttempted() + c.retryCredentialReloadIfNeeded() + if !c.isAvailable() { + return + } + accessToken, err := c.getAccessToken() if err != nil { c.logger.Error("poll usage for ", c.tag, ": get token: ", err) @@ -528,6 +613,12 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { } func (c *defaultCredential) close() { + if c.watcher != nil { + err := c.watcher.Close() + if err != nil { + c.logger.Error("close credential watcher for ", c.tag, ": ", err) + } + } if c.usageTracker != nil { c.usageTracker.cancelPendingSave() err := c.usageTracker.Save() @@ -622,6 +713,9 @@ func (p *singleCredentialProvider) selectCredential(_ string, filter func(creden if filter != nil && !filter(p.cred) { return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") } + if !p.cred.isAvailable() { + return nil, false, p.cred.unavailableError() + } if !p.cred.isUsable() { return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") } @@ -866,13 +960,21 @@ func (p *fallbackProvider) allCredentials() []credential { func (p *fallbackProvider) close() {} func allCredentialsUnavailableError(credentials []credential) error { + var hasUnavailable bool var earliest time.Time for _, cred := range credentials { + if cred.unavailableError() != nil { + hasUnavailable = true + continue + } resetAt := cred.earliestReset() if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { earliest = resetAt } } + if hasUnavailable { + return E.New("all credentials unavailable") + } if earliest.IsZero() { return E.New("all credentials rate-limited") } diff --git a/service/ccm/service.go b/service/ccm/service.go index 1238b63041..c9d40219a6 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -88,7 +88,11 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string) if provider == nil { return fallback } - return allCredentialsUnavailableError(provider.allCredentials()).Error() + message := allCredentialsUnavailableError(provider.allCredentials()).Error() + if message == "all credentials unavailable" && fallback != "" { + return fallback + } + return message } func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) { @@ -734,6 +738,9 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user var totalFiveHour, totalWeekly float64 var count int for _, cred := range provider.allCredentials() { + if !cred.isAvailable() { + continue + } // Exclude the user's own external_credential (their contribution to us) if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { continue diff --git a/service/ocm/credential.go b/service/ocm/credential.go index 0cdbd63790..f16beb9163 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -175,3 +175,41 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut return &newCredentials, nil } + +func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { + if credentials == nil { + return nil + } + cloned := *credentials + if credentials.Tokens != nil { + clonedTokens := *credentials.Tokens + cloned.Tokens = &clonedTokens + } + if credentials.LastRefresh != nil { + lastRefresh := *credentials.LastRefresh + cloned.LastRefresh = &lastRefresh + } + return &cloned +} + +func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { + if left == nil || right == nil { + return left == right + } + if left.APIKey != right.APIKey { + return false + } + if (left.Tokens == nil) != (right.Tokens == nil) { + return false + } + if left.Tokens != nil && *left.Tokens != *right.Tokens { + return false + } + if (left.LastRefresh == nil) != (right.LastRefresh == nil) { + return false + } + if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) { + return false + } + return true +} diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 337fc6da3c..1584590794 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -157,6 +157,10 @@ func (c *externalCredential) isExternal() bool { return true } +func (c *externalCredential) isAvailable() bool { + return true +} + func (c *externalCredential) isUsable() bool { c.stateMutex.RLock() if c.state.hardRateLimited { @@ -215,6 +219,10 @@ func (c *externalCredential) earliestReset() time.Time { return earliest } +func (c *externalCredential) unavailableError() error { + return nil +} + func (c *externalCredential) getAccessToken() (string, error) { return c.token, nil } diff --git a/service/ocm/credential_file.go b/service/ocm/credential_file.go new file mode 100644 index 0000000000..b8252904ea --- /dev/null +++ b/service/ocm/credential_file.go @@ -0,0 +1,139 @@ +package ocm + +import ( + "path/filepath" + "time" + + "github.com/sagernet/fswatch" + E "github.com/sagernet/sing/common/exceptions" +) + +const credentialReloadRetryInterval = 2 * time.Second + +func resolveCredentialFilePath(customPath string) (string, error) { + if customPath == "" { + var err error + customPath, err = getDefaultCredentialsPath() + if err != nil { + return "", err + } + } + if filepath.IsAbs(customPath) { + return customPath, nil + } + return filepath.Abs(customPath) +} + +func (c *defaultCredential) ensureCredentialWatcher() error { + c.watcherAccess.Lock() + defer c.watcherAccess.Unlock() + + if c.watcher != nil || c.credentialFilePath == "" { + return nil + } + if !c.watcherRetryAt.IsZero() && time.Now().Before(c.watcherRetryAt) { + return nil + } + + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: []string{c.credentialFilePath}, + Logger: c.logger, + Callback: func(string) { + err := c.reloadCredentials(true) + if err != nil { + c.logger.Warn("reload credentials for ", c.tag, ": ", err) + } + }, + }) + if err != nil { + c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval) + return err + } + + err = watcher.Start() + if err != nil { + c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval) + return err + } + + c.watcher = watcher + c.watcherRetryAt = time.Time{} + return nil +} + +func (c *defaultCredential) retryCredentialReloadIfNeeded() { + c.stateMutex.RLock() + unavailable := c.state.unavailable + lastAttempt := c.state.lastCredentialLoadAttempt + c.stateMutex.RUnlock() + if !unavailable { + return + } + if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval { + return + } + + err := c.ensureCredentialWatcher() + if err != nil { + c.logger.Debug("start credential watcher for ", c.tag, ": ", err) + } + _ = c.reloadCredentials(false) +} + +func (c *defaultCredential) reloadCredentials(force bool) error { + c.reloadAccess.Lock() + defer c.reloadAccess.Unlock() + + c.stateMutex.RLock() + unavailable := c.state.unavailable + lastAttempt := c.state.lastCredentialLoadAttempt + c.stateMutex.RUnlock() + if !force { + if !unavailable { + return nil + } + if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval { + return c.unavailableError() + } + } + + c.stateMutex.Lock() + c.state.lastCredentialLoadAttempt = time.Now() + c.stateMutex.Unlock() + + credentials, err := platformReadCredentials(c.credentialPath) + if err != nil { + return c.markCredentialsUnavailable(E.Cause(err, "read credentials")) + } + + c.accessMutex.Lock() + c.credentials = credentials + c.accessMutex.Unlock() + + c.stateMutex.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadError = "" + c.checkTransitionLocked() + c.stateMutex.Unlock() + + return nil +} + +func (c *defaultCredential) markCredentialsUnavailable(err error) error { + c.accessMutex.Lock() + hadCredentials := c.credentials != nil + c.credentials = nil + c.accessMutex.Unlock() + + c.stateMutex.Lock() + c.state.unavailable = true + c.state.lastCredentialLoadError = err.Error() + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + + if shouldInterrupt && hadCredentials { + c.interruptConnections() + } + + return err +} diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 0cab4fc2a1..e6c0642e48 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "time" + "github.com/sagernet/fswatch" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/log" @@ -29,31 +30,39 @@ import ( const defaultPollInterval = 60 * time.Minute type credentialState struct { - fiveHourUtilization float64 - fiveHourReset time.Time - weeklyUtilization float64 - weeklyReset time.Time - hardRateLimited bool - rateLimitResetAt time.Time - accountType string - lastUpdated time.Time - consecutivePollFailures int + fiveHourUtilization float64 + fiveHourReset time.Time + weeklyUtilization float64 + weeklyReset time.Time + hardRateLimited bool + rateLimitResetAt time.Time + accountType string + lastUpdated time.Time + consecutivePollFailures int + unavailable bool + lastCredentialLoadAttempt time.Time + lastCredentialLoadError string } type defaultCredential struct { - tag string - credentialPath string - credentials *oauthCredentials - accessMutex sync.RWMutex - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - reserve5h uint8 - reserveWeekly uint8 - usageTracker *AggregatedUsage - dialer N.Dialer - httpClient *http.Client - logger log.ContextLogger + tag string + credentialPath string + credentialFilePath string + credentials *oauthCredentials + accessMutex sync.RWMutex + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + reloadAccess sync.Mutex + watcherAccess sync.Mutex + reserve5h uint8 + reserveWeekly uint8 + usageTracker *AggregatedUsage + dialer N.Dialer + httpClient *http.Client + logger log.ContextLogger + watcher *fswatch.Watcher + watcherRetryAt time.Time // Connection interruption onBecameUnusable func() @@ -84,12 +93,14 @@ func (c *credentialRequestContext) cancelRequest() { type credential interface { tagName() string + isAvailable() bool isUsable() bool isExternal() bool fiveHourUtilization() float64 weeklyUtilization() float64 markRateLimited(resetAt time.Time) earliestReset() time.Time + unavailableError() error getAccessToken() (string, error) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) @@ -169,11 +180,19 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef } func (c *defaultCredential) start() error { - credentials, err := platformReadCredentials(c.credentialPath) + credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) if err != nil { - return E.Cause(err, "read credentials for ", c.tag) + return E.Cause(err, "resolve credential path for ", c.tag) + } + c.credentialFilePath = credentialFilePath + err = c.ensureCredentialWatcher() + if err != nil { + c.logger.Debug("start credential watcher for ", c.tag, ": ", err) + } + err = c.reloadCredentials(true) + if err != nil { + c.logger.Warn("initial credential load for ", c.tag, ": ", err) } - c.credentials = credentials if c.usageTracker != nil { err = c.usageTracker.Load() if err != nil { @@ -184,27 +203,65 @@ func (c *defaultCredential) start() error { } func (c *defaultCredential) getAccessToken() (string, error) { + c.retryCredentialReloadIfNeeded() + c.accessMutex.RLock() - if !c.credentials.needsRefresh() { + if c.credentials != nil && !c.credentials.needsRefresh() { token := c.credentials.getAccessToken() c.accessMutex.RUnlock() return token, nil } c.accessMutex.RUnlock() + err := c.reloadCredentials(true) + if err == nil { + c.accessMutex.RLock() + if c.credentials != nil && !c.credentials.needsRefresh() { + token := c.credentials.getAccessToken() + c.accessMutex.RUnlock() + return token, nil + } + c.accessMutex.RUnlock() + } + c.accessMutex.Lock() defer c.accessMutex.Unlock() + if c.credentials == nil { + return "", c.unavailableError() + } if !c.credentials.needsRefresh() { return c.credentials.getAccessToken(), nil } + baseCredentials := cloneCredentials(c.credentials) newCredentials, err := refreshToken(c.httpClient, c.credentials) if err != nil { return "", err } + latestCredentials, latestErr := platformReadCredentials(c.credentialPath) + if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { + c.credentials = latestCredentials + c.stateMutex.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.checkTransitionLocked() + c.stateMutex.Unlock() + if !latestCredentials.needsRefresh() { + return latestCredentials.getAccessToken(), nil + } + return "", E.New("credential ", c.tag, " changed while refreshing") + } + c.credentials = newCredentials + c.stateMutex.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.checkTransitionLocked() + c.stateMutex.Unlock() err = platformWriteCredentials(newCredentials, c.credentialPath) if err != nil { @@ -217,12 +274,18 @@ func (c *defaultCredential) getAccessToken() (string, error) { func (c *defaultCredential) getAccountID() string { c.accessMutex.RLock() defer c.accessMutex.RUnlock() + if c.credentials == nil { + return "" + } return c.credentials.getAccountID() } func (c *defaultCredential) isAPIKeyMode() bool { c.accessMutex.RLock() defer c.accessMutex.RUnlock() + if c.credentials == nil { + return false + } return c.credentials.isAPIKeyMode() } @@ -296,7 +359,13 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) { } func (c *defaultCredential) isUsable() bool { + c.retryCredentialReloadIfNeeded() + c.stateMutex.RLock() + if c.state.unavailable { + c.stateMutex.RUnlock() + return false + } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { c.stateMutex.RUnlock() @@ -329,7 +398,7 @@ func (c *defaultCredential) checkReservesLocked() bool { // checkTransitionLocked detects usable→unusable transition. // Must be called with stateMutex write lock held. func (c *defaultCredential) checkTransitionLocked() bool { - unusable := c.state.hardRateLimited || !c.checkReservesLocked() + unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() if unusable && !c.interrupted { c.interrupted = true return true @@ -372,6 +441,26 @@ func (c *defaultCredential) weeklyUtilization() float64 { return c.state.weeklyUtilization } +func (c *defaultCredential) isAvailable() bool { + c.retryCredentialReloadIfNeeded() + + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return !c.state.unavailable +} + +func (c *defaultCredential) unavailableError() error { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + if !c.state.unavailable { + return nil + } + if c.state.lastCredentialLoadError == "" { + return E.New("credential ", c.tag, " is unavailable") + } + return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError) +} + func (c *defaultCredential) lastUpdatedTime() time.Time { c.stateMutex.RLock() defer c.stateMutex.RUnlock() @@ -400,6 +489,9 @@ func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duratio func (c *defaultCredential) earliestReset() time.Time { c.stateMutex.RLock() defer c.stateMutex.RUnlock() + if c.state.unavailable { + return time.Time{} + } if c.state.hardRateLimited { return c.state.rateLimitResetAt } @@ -421,15 +513,20 @@ func isTimeoutError(err error) bool { } func (c *defaultCredential) pollUsage(ctx context.Context) { - if c.isAPIKeyMode() { - return - } if !c.pollAccess.TryLock() { return } defer c.pollAccess.Unlock() defer c.markUsagePollAttempted() + c.retryCredentialReloadIfNeeded() + if !c.isAvailable() { + return + } + if c.isAPIKeyMode() { + return + } + accessToken, err := c.getAccessToken() if err != nil { c.logger.Error("poll usage for ", c.tag, ": get token: ", err) @@ -546,6 +643,12 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { } func (c *defaultCredential) close() { + if c.watcher != nil { + err := c.watcher.Close() + if err != nil { + c.logger.Error("close credential watcher for ", c.tag, ": ", err) + } + } if c.usageTracker != nil { c.usageTracker.cancelPendingSave() err := c.usageTracker.Save() @@ -662,6 +765,9 @@ func (p *singleCredentialProvider) selectCredential(_ string, filter func(creden if filter != nil && !filter(p.cred) { return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") } + if !p.cred.isAvailable() { + return nil, false, p.cred.unavailableError() + } if !p.cred.isUsable() { return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") } @@ -702,6 +808,10 @@ type balancerProvider struct { logger log.ContextLogger } +func compositeCredentialSelectable(cred credential) bool { + return !cred.ocmIsAPIKeyMode() +} + func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { if pollInterval <= 0 { pollInterval = defaultPollInterval @@ -722,7 +832,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden p.sessionMutex.RUnlock() if exists { for _, cred := range p.credentials { - if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() { + if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && (filter == nil || filter(cred)) && cred.isUsable() { return cred, false, nil } } @@ -781,6 +891,9 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia if filter != nil && !filter(cred) { continue } + if !compositeCredentialSelectable(cred) { + continue + } if !cred.isUsable() { continue } @@ -801,6 +914,9 @@ func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credenti if filter != nil && !filter(candidate) { continue } + if !compositeCredentialSelectable(candidate) { + continue + } if candidate.isUsable() { return candidate } @@ -814,6 +930,9 @@ func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { if filter != nil && !filter(candidate) { continue } + if !compositeCredentialSelectable(candidate) { + continue + } if candidate.isUsable() { usable = append(usable, candidate) } @@ -869,6 +988,9 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo if filter != nil && !filter(cred) { continue } + if !compositeCredentialSelectable(cred) { + continue + } if cred.isUsable() { return cred, false, nil } @@ -882,6 +1004,9 @@ func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time if filter != nil && !filter(candidate) { continue } + if !compositeCredentialSelectable(candidate) { + continue + } if candidate.isUsable() { return candidate } @@ -904,13 +1029,21 @@ func (p *fallbackProvider) allCredentials() []credential { func (p *fallbackProvider) close() {} func allRateLimitedError(credentials []credential) error { + var hasUnavailable bool var earliest time.Time for _, cred := range credentials { + if cred.unavailableError() != nil { + hasUnavailable = true + continue + } resetAt := cred.earliestReset() if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { earliest = resetAt } } + if hasUnavailable { + return E.New("all credentials unavailable") + } if earliest.IsZero() { return E.New("all credentials rate-limited") } @@ -1090,6 +1223,9 @@ func validateOCMCompositeCredentialModes( } for _, subCred := range provider.allCredentials() { + if !subCred.isAvailable() { + continue + } if subCred.ocmIsAPIKeyMode() { return E.New( "credential ", credOpt.Tag, diff --git a/service/ocm/service.go b/service/ocm/service.go index 38527abde8..5712c13c4a 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -96,7 +96,11 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string) if provider == nil { return fallback } - return allRateLimitedError(provider.allCredentials()).Error() + message := allRateLimitedError(provider.allCredentials()).Error() + if message == "all credentials unavailable" && fallback != "" { + return fallback + } + return message } func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) { @@ -806,6 +810,9 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user var totalFiveHour, totalWeekly float64 var count int for _, cred := range provider.allCredentials() { + if !cred.isAvailable() { + continue + } if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { continue } From 6829f91a062fce6571bf612bf37017578452dd1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 03:27:42 +0800 Subject: [PATCH 10/96] ccm,ocm: check credential file writability before token refresh Refuse to refresh tokens when the credential file is not writable, preventing server-side invalidation of the old refresh token that would make the credential permanently unusable after restart. --- service/ccm/credential.go | 8 ++++++++ service/ccm/credential_darwin.go | 7 +++++++ service/ccm/credential_other.go | 11 +++++++++++ service/ccm/credential_state.go | 7 ++++++- service/ocm/credential.go | 8 ++++++++ service/ocm/credential_darwin.go | 11 +++++++++++ service/ocm/credential_other.go | 11 +++++++++++ service/ocm/credential_state.go | 7 ++++++- 8 files changed, 68 insertions(+), 2 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index f14d4d2bc0..6b30008612 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -108,6 +108,14 @@ func readCredentialsFromFile(path string) (*oauthCredentials, error) { return credentialsContainer.ClaudeAIAuth, nil } +func checkCredentialFileWritable(path string) error { + file, err := os.OpenFile(path, os.O_WRONLY, 0) + if err != nil { + return err + } + return file.Close() +} + func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error { data, err := json.MarshalIndent(map[string]any{ "claudeAiOauth": oauthCredentials, diff --git a/service/ccm/credential_darwin.go b/service/ccm/credential_darwin.go index 24047b8585..aef10c8748 100644 --- a/service/ccm/credential_darwin.go +++ b/service/ccm/credential_darwin.go @@ -69,6 +69,13 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) { return readCredentialsFromFile(defaultPath) } +func platformCanWriteCredentials(customPath string) error { + if customPath == "" { + return nil + } + return checkCredentialFileWritable(customPath) +} + func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error { if customPath != "" { return writeCredentialsToFile(oauthCredentials, customPath) diff --git a/service/ccm/credential_other.go b/service/ccm/credential_other.go index 11888b5082..02c52e71ef 100644 --- a/service/ccm/credential_other.go +++ b/service/ccm/credential_other.go @@ -13,6 +13,17 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) { return readCredentialsFromFile(customPath) } +func platformCanWriteCredentials(customPath string) error { + if customPath == "" { + var err error + customPath, err = getDefaultCredentialsPath() + if err != nil { + return err + } + } + return checkCredentialFileWritable(customPath) +} + func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error { if customPath == "" { var err error diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 9371295972..ff64b24bc1 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -225,6 +225,11 @@ func (c *defaultCredential) getAccessToken() (string, error) { return c.credentials.AccessToken, nil } + err = platformCanWriteCredentials(c.credentialPath) + if err != nil { + return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") + } + baseCredentials := cloneCredentials(c.credentials) newCredentials, err := refreshToken(c.httpClient, c.credentials) if err != nil { @@ -258,7 +263,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { err = platformWriteCredentials(newCredentials, c.credentialPath) if err != nil { - c.logger.Warn("persist refreshed token for ", c.tag, ": ", err) + c.logger.Error("persist refreshed token for ", c.tag, ": ", err) } return newCredentials.AccessToken, nil diff --git a/service/ocm/credential.go b/service/ocm/credential.go index f16beb9163..c143f868ab 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -55,6 +55,14 @@ func readCredentialsFromFile(path string) (*oauthCredentials, error) { return &credentials, nil } +func checkCredentialFileWritable(path string) error { + file, err := os.OpenFile(path, os.O_WRONLY, 0) + if err != nil { + return err + } + return file.Close() +} + func writeCredentialsToFile(credentials *oauthCredentials, path string) error { data, err := json.MarshalIndent(credentials, "", " ") if err != nil { diff --git a/service/ocm/credential_darwin.go b/service/ocm/credential_darwin.go index f3da2a63ed..37e7c1c7a9 100644 --- a/service/ocm/credential_darwin.go +++ b/service/ocm/credential_darwin.go @@ -13,6 +13,17 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) { return readCredentialsFromFile(customPath) } +func platformCanWriteCredentials(customPath string) error { + if customPath == "" { + var err error + customPath, err = getDefaultCredentialsPath() + if err != nil { + return err + } + } + return checkCredentialFileWritable(customPath) +} + func platformWriteCredentials(credentials *oauthCredentials, customPath string) error { if customPath == "" { var err error diff --git a/service/ocm/credential_other.go b/service/ocm/credential_other.go index 22dfd0337a..9da2a569d0 100644 --- a/service/ocm/credential_other.go +++ b/service/ocm/credential_other.go @@ -13,6 +13,17 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) { return readCredentialsFromFile(customPath) } +func platformCanWriteCredentials(customPath string) error { + if customPath == "" { + var err error + customPath, err = getDefaultCredentialsPath() + if err != nil { + return err + } + } + return checkCredentialFileWritable(customPath) +} + func platformWriteCredentials(credentials *oauthCredentials, customPath string) error { if customPath == "" { var err error diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index e6c0642e48..547926b872 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -234,6 +234,11 @@ func (c *defaultCredential) getAccessToken() (string, error) { return c.credentials.getAccessToken(), nil } + err = platformCanWriteCredentials(c.credentialPath) + if err != nil { + return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") + } + baseCredentials := cloneCredentials(c.credentials) newCredentials, err := refreshToken(c.httpClient, c.credentials) if err != nil { @@ -265,7 +270,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { err = platformWriteCredentials(newCredentials, c.credentialPath) if err != nil { - c.logger.Warn("persist refreshed token for ", c.tag, ": ", err) + c.logger.Error("persist refreshed token for ", c.tag, ": ", err) } return newCredentials.getAccessToken(), nil From b96ab4fef908e5db566028412509e645393d3d33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 04:52:31 +0800 Subject: [PATCH 11/96] ccm,ocm,ssmapi: fix HTTP/2 over TLS with h2c handler aTLS.NewListener returns *LazyConn, not *tls.Conn, so Go's http.Server cannot detect TLS via type assertion and falls back to HTTP/1.x. When ALPN negotiates h2, the client sends HTTP/2 frames that the server fails to parse, causing HTTP 520 errors behind Cloudflare. Wrap HTTP handlers with h2c.NewHandler to intercept the HTTP/2 client preface and dispatch to http2.Server.ServeConn, consistent with DERP, v2rayhttp, naive, and v2raygrpclite services. --- service/ccm/service.go | 3 ++- service/ocm/service.go | 3 ++- service/ssmapi/server.go | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/service/ccm/service.go b/service/ccm/service.go index c9d40219a6..31b9de8e2c 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -29,6 +29,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/go-chi/chi/v5" "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) const ( @@ -249,7 +250,7 @@ func (s *Service) Start(stage adapter.StartStage) error { router := chi.NewRouter() router.Mount("/", s) - s.httpServer = &http.Server{Handler: router} + s.httpServer = &http.Server{Handler: h2c.NewHandler(router, &http2.Server{})} if s.tlsConfig != nil { err := s.tlsConfig.Start() diff --git a/service/ocm/service.go b/service/ocm/service.go index 5712c13c4a..1c393716ac 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -30,6 +30,7 @@ import ( "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/responses" "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) func RegisterService(registry *boxService.Registry) { @@ -302,7 +303,7 @@ func (s *Service) Start(stage adapter.StartStage) error { router := chi.NewRouter() router.Mount("/", s) - s.httpServer = &http.Server{Handler: router} + s.httpServer = &http.Server{Handler: h2c.NewHandler(router, &http2.Server{})} if s.tlsConfig != nil { err := s.tlsConfig.Start() diff --git a/service/ssmapi/server.go b/service/ssmapi/server.go index 157ea150b4..97ea6326fd 100644 --- a/service/ssmapi/server.go +++ b/service/ssmapi/server.go @@ -22,6 +22,7 @@ import ( "github.com/go-chi/chi/v5" "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) func RegisterService(registry *boxService.Registry) { @@ -59,7 +60,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio Listen: options.ListenOptions, }), httpServer: &http.Server{ - Handler: chiRouter, + Handler: h2c.NewHandler(chiRouter, &http2.Server{}), }, traffics: make(map[string]*TrafficManager), users: make(map[string]*UserManager), From 15f36199953d09fba60c4d6ed72916e71dbcda3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 04:54:11 +0800 Subject: [PATCH 12/96] ccm,ocm: strip reverse proxy headers before forwarding to upstream --- service/ccm/credential_external.go | 2 +- service/ccm/credential_state.go | 2 +- service/ccm/service.go | 13 +++++++++++++ service/ocm/credential_external.go | 2 +- service/ocm/credential_state.go | 2 +- service/ocm/service.go | 13 +++++++++++++ service/ocm/service_websocket.go | 2 +- 7 files changed, 31 insertions(+), 5 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 456ba4fcbe..f6560a2e61 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -236,7 +236,7 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht } for key, values := range original.Header { - if !isHopByHopHeader(key) && key != "Authorization" { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { proxyRequest.Header[key] = values } } diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index ff64b24bc1..fbde0e8aca 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -674,7 +674,7 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt } for key, values := range original.Header { - if !isHopByHopHeader(key) && key != "Authorization" { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { proxyRequest.Header[key] = values } } diff --git a/service/ccm/service.go b/service/ccm/service.go index 31b9de8e2c..4bd24b1762 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -128,6 +128,19 @@ func isHopByHopHeader(header string) bool { } } +func isReverseProxyHeader(header string) bool { + lowerHeader := strings.ToLower(header) + if strings.HasPrefix(lowerHeader, "cf-") { + return true + } + switch lowerHeader { + case "cdn-loop", "true-client-ip", "x-forwarded-for", "x-forwarded-proto", "x-real-ip": + return true + default: + return false + } +} + const ( weeklyWindowSeconds = 604800 weeklyWindowMinutes = weeklyWindowSeconds / 60 diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 1584590794..ca9664f1e1 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -241,7 +241,7 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht } for key, values := range original.Header { - if !isHopByHopHeader(key) && key != "Authorization" { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { proxyRequest.Header[key] = values } } diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 547926b872..92745492d4 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -736,7 +736,7 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt } for key, values := range original.Header { - if !isHopByHopHeader(key) && key != "Authorization" { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { proxyRequest.Header[key] = values } } diff --git a/service/ocm/service.go b/service/ocm/service.go index 1c393716ac..74fa776d8b 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -136,6 +136,19 @@ func isHopByHopHeader(header string) bool { } } +func isReverseProxyHeader(header string) bool { + lowerHeader := strings.ToLower(header) + if strings.HasPrefix(lowerHeader, "cf-") { + return true + } + switch lowerHeader { + case "cdn-loop", "true-client-ip", "x-forwarded-for", "x-forwarded-proto", "x-real-ip": + return true + default: + return false + } +} + func normalizeRateLimitIdentifier(limitIdentifier string) string { trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier)) if trimmedIdentifier == "" { diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index eeb0380560..7aa68499cb 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -65,7 +65,7 @@ func isForwardableResponseHeader(key string) bool { } func isForwardableWebSocketRequestHeader(key string) bool { - if isHopByHopHeader(key) { + if isHopByHopHeader(key) || isReverseProxyHeader(key) { return false } From 970951f36996df3e2b8d7f2f58d9f8d2b7f18860 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 18:51:02 +0800 Subject: [PATCH 13/96] ccm,ocm: add reverse proxy support for external credentials MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow two CCM/OCM instances to share credentials when only one has a public IP, using yamux-multiplexed reverse connections. Three credential modes: - Normal: URL set, reverse=false — standard HTTP proxy - Receiver: URL empty — waits for incoming reverse connection - Connector: URL set, reverse=true — dials out to establish connection Extend InterfaceUpdated to services so network changes trigger reverse connection reconnection. --- option/ccm.go | 3 +- option/ocm.go | 3 +- route/network.go | 11 ++ service/ccm/credential_external.go | 224 ++++++++++++++++++-------- service/ccm/credential_state.go | 6 +- service/ccm/reverse.go | 243 +++++++++++++++++++++++++++++ service/ccm/service.go | 22 +++ service/ocm/credential_external.go | 224 ++++++++++++++++++-------- service/ocm/credential_state.go | 6 +- service/ocm/reverse.go | 243 +++++++++++++++++++++++++++++ service/ocm/service.go | 22 +++ 11 files changed, 873 insertions(+), 134 deletions(-) create mode 100644 service/ccm/reverse.go create mode 100644 service/ocm/reverse.go diff --git a/option/ccm.go b/option/ccm.go index 6846dfccb6..ae80cc64b2 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -95,9 +95,10 @@ type CCMBalancerCredentialOptions struct { } type CCMExternalCredentialOptions struct { - URL string `json:"url"` + URL string `json:"url,omitempty"` ServerOptions Token string `json:"token"` + Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` diff --git a/option/ocm.go b/option/ocm.go index 4d495ff27a..20cafee123 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -95,9 +95,10 @@ type OCMBalancerCredentialOptions struct { } type OCMExternalCredentialOptions struct { - URL string `json:"url"` + URL string `json:"url,omitempty"` ServerOptions Token string `json:"token"` + Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` diff --git a/route/network.go b/route/network.go index b8eefdc069..3f0cf57cae 100644 --- a/route/network.go +++ b/route/network.go @@ -51,6 +51,7 @@ type NetworkManager struct { endpoint adapter.EndpointManager inbound adapter.InboundManager outbound adapter.OutboundManager + serviceManager adapter.ServiceManager needWIFIState bool wifiMonitor settings.WIFIMonitor wifiState adapter.WIFIState @@ -94,6 +95,7 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, options endpoint: service.FromContext[adapter.EndpointManager](ctx), inbound: service.FromContext[adapter.InboundManager](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx), + serviceManager: service.FromContext[adapter.ServiceManager](ctx), needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule), } if options.DefaultNetworkStrategy != nil { @@ -475,6 +477,15 @@ func (r *NetworkManager) ResetNetwork() { listener.InterfaceUpdated() } } + + if r.serviceManager != nil { + for _, svc := range r.serviceManager.Services() { + listener, isListener := svc.(adapter.InterfaceUpdateListener) + if isListener { + listener.InterfaceUpdated() + } + } + } } func (r *NetworkManager) notifyInterfaceUpdate(defaultInterface *control.Interface, flags int) { diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index f6560a2e61..e8e53c181c 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -19,9 +19,14 @@ import ( "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" + + "github.com/hashicorp/yamux" ) +const reverseProxyBaseURL = "http://reverse-proxy" + type externalCredential struct { tag string baseURL string @@ -39,86 +44,134 @@ type externalCredential struct { requestContext context.Context cancelRequests context.CancelFunc requestAccess sync.Mutex + + // Reverse proxy fields + reverse bool + reverseSession *yamux.Session + reverseAccess sync.RWMutex + reverseContext context.Context + reverseCancel context.CancelFunc + connectorDialer N.Dialer + connectorURL *url.URL + connectorTLS *stdTLS.Config + reverseService http.Handler } func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { - parsedURL, err := url.Parse(options.URL) - if err != nil { - return nil, E.Cause(err, "parse url for credential ", tag) - } - - credentialDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) - if err != nil { - return nil, E.Cause(err, "create dialer for credential ", tag) - } - - transport := &http.Transport{ - ForceAttemptHTTP2: true, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if options.Server != "" { - serverPort := options.ServerPort - if serverPort == 0 { - portStr := parsedURL.Port() - if portStr != "" { - port, parseErr := strconv.ParseUint(portStr, 10, 16) - if parseErr == nil { - serverPort = uint16(port) - } - } - if serverPort == 0 { - if parsedURL.Scheme == "https" { - serverPort = 443 - } else { - serverPort = 80 - } - } - } - destination := M.ParseSocksaddrHostPort(options.Server, serverPort) - return credentialDialer.DialContext(ctx, network, destination) - } - return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - } - - if parsedURL.Scheme == "https" { - transport.TLSClientConfig = &stdTLS.Config{ - ServerName: parsedURL.Hostname(), - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - } - } - - baseURL := parsedURL.Scheme + "://" + parsedURL.Host - if parsedURL.Path != "" && parsedURL.Path != "/" { - baseURL += parsedURL.Path - } - // Strip trailing slash - if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { - baseURL = baseURL[:len(baseURL)-1] - } - pollInterval := time.Duration(options.PollInterval) if pollInterval <= 0 { pollInterval = 30 * time.Minute } requestContext, cancelRequests := context.WithCancel(context.Background()) + reverseContext, reverseCancel := context.WithCancel(context.Background()) cred := &externalCredential{ tag: tag, - baseURL: baseURL, token: options.Token, - httpClient: &http.Client{Transport: transport}, pollInterval: pollInterval, logger: logger, requestContext: requestContext, cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, + } + + if options.URL == "" { + // Receiver mode: no URL, wait for reverse connection + cred.baseURL = reverseProxyBaseURL + cred.httpClient = &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: false, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + session := cred.getReverseSession() + if session == nil || session.IsClosed() { + return nil, E.New("reverse connection not established for ", cred.tag) + } + return session.Open() + }, + }, + } + } else { + // Normal or connector mode: has URL + parsedURL, err := url.Parse(options.URL) + if err != nil { + return nil, E.Cause(err, "parse url for credential ", tag) + } + + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + + transport := &http.Transport{ + ForceAttemptHTTP2: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if options.Server != "" { + serverPort := options.ServerPort + if serverPort == 0 { + portStr := parsedURL.Port() + if portStr != "" { + port, parseErr := strconv.ParseUint(portStr, 10, 16) + if parseErr == nil { + serverPort = uint16(port) + } + } + if serverPort == 0 { + if parsedURL.Scheme == "https" { + serverPort = 443 + } else { + serverPort = 80 + } + } + } + destination := M.ParseSocksaddrHostPort(options.Server, serverPort) + return credentialDialer.DialContext(ctx, network, destination) + } + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + } + + if parsedURL.Scheme == "https" { + transport.TLSClientConfig = &stdTLS.Config{ + ServerName: parsedURL.Hostname(), + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + } + } + + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + if parsedURL.Path != "" && parsedURL.Path != "/" { + baseURL += parsedURL.Path + } + if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { + baseURL = baseURL[:len(baseURL)-1] + } + + cred.baseURL = baseURL + + if options.Reverse { + // Connector mode: we dial out to serve, not to proxy + cred.connectorDialer = credentialDialer + cred.connectorURL = parsedURL + if parsedURL.Scheme == "https" { + cred.connectorTLS = &stdTLS.Config{ + ServerName: parsedURL.Hostname(), + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + } + } + } else { + // Normal mode: standard HTTP client for proxying + cred.httpClient = &http.Client{Transport: transport} + } } if options.UsagesPath != "" { @@ -140,6 +193,9 @@ func (c *externalCredential) start() error { c.logger.Warn("load usage statistics for ", c.tag, ": ", err) } } + if c.reverse && c.connectorURL != nil { + go c.connectorLoop() + } return nil } @@ -152,6 +208,14 @@ func (c *externalCredential) isExternal() bool { } func (c *externalCredential) isAvailable() bool { + if c.reverse && c.connectorURL != nil { + return false // connector mode: not for local proxying + } + if c.baseURL == reverseProxyBaseURL { + // receiver mode: only available when reverse connection active + session := c.getReverseSession() + return session != nil && !session.IsClosed() + } return true } @@ -426,6 +490,16 @@ func (c *externalCredential) httpTransport() *http.Client { } func (c *externalCredential) close() { + if c.reverseCancel != nil { + c.reverseCancel() + } + c.reverseAccess.Lock() + session := c.reverseSession + c.reverseSession = nil + c.reverseAccess.Unlock() + if session != nil { + session.Close() + } if c.usageTracker != nil { c.usageTracker.cancelPendingSave() err := c.usageTracker.Save() @@ -434,3 +508,27 @@ func (c *externalCredential) close() { } } } + +func (c *externalCredential) getReverseSession() *yamux.Session { + c.reverseAccess.RLock() + defer c.reverseAccess.RUnlock() + return c.reverseSession +} + +func (c *externalCredential) setReverseSession(session *yamux.Session) { + c.reverseAccess.Lock() + old := c.reverseSession + c.reverseSession = session + c.reverseAccess.Unlock() + if old != nil { + old.Close() + } +} + +func (c *externalCredential) clearReverseSession(session *yamux.Session) { + c.reverseAccess.Lock() + if c.reverseSession == session { + c.reverseSession = nil + } + c.reverseAccess.Unlock() +} diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index fbde0e8aca..673af5c2e6 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -1117,12 +1117,12 @@ func validateCCMOptions(options option.CCMServiceOptions) error { } } if cred.Type == "external" { - if cred.ExternalOptions.URL == "" { - return E.New("credential ", cred.Tag, ": external credential requires url") - } if cred.ExternalOptions.Token == "" { return E.New("credential ", cred.Tag, ": external credential requires token") } + if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { + return E.New("credential ", cred.Tag, ": reverse external credential requires url") + } } if cred.Type == "balancer" { switch cred.BalancerOptions.Strategy { diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go new file mode 100644 index 0000000000..571c8c55ae --- /dev/null +++ b/service/ccm/reverse.go @@ -0,0 +1,243 @@ +package ccm + +import ( + "bufio" + stdTLS "crypto/tls" + "errors" + "io" + "math/rand/v2" + "net" + "net/http" + "strings" + "time" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "github.com/hashicorp/yamux" +) + +func reverseYamuxConfig() *yamux.Config { + config := yamux.DefaultConfig() + config.KeepAliveInterval = 15 * time.Second + config.ConnectionWriteTimeout = 10 * time.Second + config.MaxStreamWindowSize = 512 * 1024 + config.LogOutput = io.Discard + return config +} + +type yamuxNetListener struct { + session *yamux.Session +} + +func (l *yamuxNetListener) Accept() (net.Conn, error) { + return l.session.Accept() +} + +func (l *yamuxNetListener) Close() error { + return l.session.Close() +} + +func (l *yamuxNetListener) Addr() net.Addr { + return l.session.Addr() +} + +func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") != "reverse-proxy" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header") + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + + receiverCredential := s.findReceiverCredential(clientToken) + if receiverCredential == nil { + s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token") + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + s.logger.Error("reverse connect: hijack not supported") + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported") + return + } + + conn, bufferedReadWriter, err := hijacker.Hijack() + if err != nil { + s.logger.Error("reverse connect: hijack: ", err) + return + } + + response := "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: reverse-proxy\r\n\r\n" + _, err = bufferedReadWriter.WriteString(response) + if err != nil { + conn.Close() + s.logger.Error("reverse connect: write upgrade response: ", err) + return + } + err = bufferedReadWriter.Flush() + if err != nil { + conn.Close() + s.logger.Error("reverse connect: flush upgrade response: ", err) + return + } + + session, err := yamux.Client(conn, reverseYamuxConfig()) + if err != nil { + conn.Close() + s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) + return + } + + receiverCredential.setReverseSession(session) + s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) + + go func() { + <-session.CloseChan() + receiverCredential.clearReverseSession(session) + s.logger.Warn("reverse connection lost for ", receiverCredential.tagName()) + }() +} + +func (s *Service) findReceiverCredential(token string) *externalCredential { + for _, cred := range s.allCredentials { + extCred, ok := cred.(*externalCredential) + if !ok { + continue + } + if extCred.baseURL == reverseProxyBaseURL && extCred.token == token { + return extCred + } + } + return nil +} + +func (c *externalCredential) connectorLoop() { + var consecutiveFailures int + for { + select { + case <-c.reverseContext.Done(): + return + default: + } + + err := c.connectorConnect() + if c.reverseContext.Err() != nil { + return + } + consecutiveFailures++ + backoff := connectorBackoff(consecutiveFailures) + c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) + select { + case <-time.After(backoff): + case <-c.reverseContext.Done(): + return + } + } +} + +func connectorBackoff(failures int) time.Duration { + if failures > 5 { + failures = 5 + } + base := time.Second * time.Duration(1< 30*time.Second { + base = 30 * time.Second + } + jitter := time.Duration(rand.Int64N(int64(base) / 2)) + return base + jitter +} + +func (c *externalCredential) connectorConnect() error { + destination := c.connectorResolveDestination() + conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) + if err != nil { + return E.Cause(err, "dial") + } + + if c.connectorTLS != nil { + tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone()) + err = tlsConn.HandshakeContext(c.reverseContext) + if err != nil { + conn.Close() + return E.Cause(err, "tls handshake") + } + conn = tlsConn + } + + upgradeRequest := "GET /ccm/v1/reverse HTTP/1.1\r\n" + + "Host: " + c.connectorURL.Host + "\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: reverse-proxy\r\n" + + "Authorization: Bearer " + c.token + "\r\n" + + "\r\n" + _, err = io.WriteString(conn, upgradeRequest) + if err != nil { + conn.Close() + return E.Cause(err, "write upgrade request") + } + + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + conn.Close() + return E.Cause(err, "read upgrade response") + } + if !strings.HasPrefix(statusLine, "HTTP/1.1 101") { + conn.Close() + return E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine)) + } + for { + line, readErr := reader.ReadString('\n') + if readErr != nil { + conn.Close() + return E.Cause(readErr, "read upgrade headers") + } + if strings.TrimSpace(line) == "" { + break + } + } + + session, err := yamux.Server(conn, reverseYamuxConfig()) + if err != nil { + conn.Close() + return E.Cause(err, "create yamux server") + } + defer session.Close() + + c.logger.Info("reverse connection established for ", c.tag) + + httpServer := &http.Server{ + Handler: c.reverseService, + ReadTimeout: 0, + IdleTimeout: 120 * time.Second, + } + err = httpServer.Serve(&yamuxNetListener{session: session}) + if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil { + return E.Cause(err, "serve") + } + return E.New("connection closed") +} + +func (c *externalCredential) connectorResolveDestination() M.Socksaddr { + port := c.connectorURL.Port() + if port == "" { + if c.connectorURL.Scheme == "https" { + port = "443" + } else { + port = "80" + } + } + return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port)) +} diff --git a/service/ccm/service.go b/service/ccm/service.go index 4bd24b1762..69697b5c05 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -258,6 +258,9 @@ func (s *Service) Start(stage adapter.StartStage) error { if err != nil { return err } + if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s + } } router := chi.NewRouter() @@ -318,6 +321,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if r.URL.Path == "/ccm/v1/reverse" { + s.handleReverseConnect(w, r) + return + } + if !strings.HasPrefix(r.URL.Path, "/v1/") { writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found") return @@ -786,6 +794,20 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) } +func (s *Service) InterfaceUpdated() { + for _, cred := range s.allCredentials { + extCred, ok := cred.(*externalCredential) + if !ok { + continue + } + if extCred.reverse && extCred.connectorURL != nil { + extCred.reverseCancel() + extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) + go extCred.connectorLoop() + } + } +} + func (s *Service) Close() error { err := common.Close( common.PtrOrNil(s.httpServer), diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index ca9664f1e1..8226d63666 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -21,8 +21,12 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" + + "github.com/hashicorp/yamux" ) +const reverseProxyBaseURL = "http://reverse-proxy" + type externalCredential struct { tag string baseURL string @@ -41,86 +45,135 @@ type externalCredential struct { requestContext context.Context cancelRequests context.CancelFunc requestAccess sync.Mutex + + // Reverse proxy fields + reverse bool + reverseSession *yamux.Session + reverseAccess sync.RWMutex + reverseContext context.Context + reverseCancel context.CancelFunc + connectorDialer N.Dialer + connectorURL *url.URL + connectorTLS *stdTLS.Config + reverseService http.Handler } func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { - parsedURL, err := url.Parse(options.URL) - if err != nil { - return nil, E.Cause(err, "parse url for credential ", tag) - } - - credentialDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) - if err != nil { - return nil, E.Cause(err, "create dialer for credential ", tag) - } - - transport := &http.Transport{ - ForceAttemptHTTP2: true, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if options.Server != "" { - serverPort := options.ServerPort - if serverPort == 0 { - portStr := parsedURL.Port() - if portStr != "" { - port, parseErr := strconv.ParseUint(portStr, 10, 16) - if parseErr == nil { - serverPort = uint16(port) - } - } - if serverPort == 0 { - if parsedURL.Scheme == "https" { - serverPort = 443 - } else { - serverPort = 80 - } - } - } - destination := M.ParseSocksaddrHostPort(options.Server, serverPort) - return credentialDialer.DialContext(ctx, network, destination) - } - return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - } - - if parsedURL.Scheme == "https" { - transport.TLSClientConfig = &stdTLS.Config{ - ServerName: parsedURL.Hostname(), - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - } - } - - baseURL := parsedURL.Scheme + "://" + parsedURL.Host - if parsedURL.Path != "" && parsedURL.Path != "/" { - baseURL += parsedURL.Path - } - if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { - baseURL = baseURL[:len(baseURL)-1] - } - pollInterval := time.Duration(options.PollInterval) if pollInterval <= 0 { pollInterval = 30 * time.Minute } requestContext, cancelRequests := context.WithCancel(context.Background()) + reverseContext, reverseCancel := context.WithCancel(context.Background()) cred := &externalCredential{ tag: tag, - baseURL: baseURL, token: options.Token, - credDialer: credentialDialer, - httpClient: &http.Client{Transport: transport}, pollInterval: pollInterval, logger: logger, requestContext: requestContext, cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, + } + + if options.URL == "" { + // Receiver mode: no URL, wait for reverse connection + cred.baseURL = reverseProxyBaseURL + cred.httpClient = &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: false, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + session := cred.getReverseSession() + if session == nil || session.IsClosed() { + return nil, E.New("reverse connection not established for ", cred.tag) + } + return session.Open() + }, + }, + } + } else { + // Normal or connector mode: has URL + parsedURL, err := url.Parse(options.URL) + if err != nil { + return nil, E.Cause(err, "parse url for credential ", tag) + } + + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + + transport := &http.Transport{ + ForceAttemptHTTP2: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if options.Server != "" { + serverPort := options.ServerPort + if serverPort == 0 { + portStr := parsedURL.Port() + if portStr != "" { + port, parseErr := strconv.ParseUint(portStr, 10, 16) + if parseErr == nil { + serverPort = uint16(port) + } + } + if serverPort == 0 { + if parsedURL.Scheme == "https" { + serverPort = 443 + } else { + serverPort = 80 + } + } + } + destination := M.ParseSocksaddrHostPort(options.Server, serverPort) + return credentialDialer.DialContext(ctx, network, destination) + } + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + } + + if parsedURL.Scheme == "https" { + transport.TLSClientConfig = &stdTLS.Config{ + ServerName: parsedURL.Hostname(), + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + } + } + + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + if parsedURL.Path != "" && parsedURL.Path != "/" { + baseURL += parsedURL.Path + } + if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { + baseURL = baseURL[:len(baseURL)-1] + } + + cred.baseURL = baseURL + + if options.Reverse { + // Connector mode: we dial out to serve, not to proxy + cred.connectorDialer = credentialDialer + cred.connectorURL = parsedURL + if parsedURL.Scheme == "https" { + cred.connectorTLS = &stdTLS.Config{ + ServerName: parsedURL.Hostname(), + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + } + } + } else { + // Normal mode: standard HTTP client for proxying + cred.credDialer = credentialDialer + cred.httpClient = &http.Client{Transport: transport} + } } if options.UsagesPath != "" { @@ -142,6 +195,9 @@ func (c *externalCredential) start() error { c.logger.Warn("load usage statistics for ", c.tag, ": ", err) } } + if c.reverse && c.connectorURL != nil { + go c.connectorLoop() + } return nil } @@ -158,6 +214,14 @@ func (c *externalCredential) isExternal() bool { } func (c *externalCredential) isAvailable() bool { + if c.reverse && c.connectorURL != nil { + return false // connector mode: not for local proxying + } + if c.baseURL == reverseProxyBaseURL { + // receiver mode: only available when reverse connection active + session := c.getReverseSession() + return session != nil && !session.IsClosed() + } return true } @@ -461,6 +525,16 @@ func (c *externalCredential) ocmGetBaseURL() string { } func (c *externalCredential) close() { + if c.reverseCancel != nil { + c.reverseCancel() + } + c.reverseAccess.Lock() + session := c.reverseSession + c.reverseSession = nil + c.reverseAccess.Unlock() + if session != nil { + session.Close() + } if c.usageTracker != nil { c.usageTracker.cancelPendingSave() err := c.usageTracker.Save() @@ -469,3 +543,27 @@ func (c *externalCredential) close() { } } } + +func (c *externalCredential) getReverseSession() *yamux.Session { + c.reverseAccess.RLock() + defer c.reverseAccess.RUnlock() + return c.reverseSession +} + +func (c *externalCredential) setReverseSession(session *yamux.Session) { + c.reverseAccess.Lock() + old := c.reverseSession + c.reverseSession = session + c.reverseAccess.Unlock() + if old != nil { + old.Close() + } +} + +func (c *externalCredential) clearReverseSession(session *yamux.Session) { + c.reverseAccess.Lock() + if c.reverseSession == session { + c.reverseSession = nil + } + c.reverseAccess.Unlock() +} diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 92745492d4..b663632af0 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -1176,12 +1176,12 @@ func validateOCMOptions(options option.OCMServiceOptions) error { } } if cred.Type == "external" { - if cred.ExternalOptions.URL == "" { - return E.New("credential ", cred.Tag, ": external credential requires url") - } if cred.ExternalOptions.Token == "" { return E.New("credential ", cred.Tag, ": external credential requires token") } + if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { + return E.New("credential ", cred.Tag, ": reverse external credential requires url") + } } if cred.Type == "balancer" { switch cred.BalancerOptions.Strategy { diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go new file mode 100644 index 0000000000..b02a202220 --- /dev/null +++ b/service/ocm/reverse.go @@ -0,0 +1,243 @@ +package ocm + +import ( + "bufio" + stdTLS "crypto/tls" + "errors" + "io" + "math/rand/v2" + "net" + "net/http" + "strings" + "time" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "github.com/hashicorp/yamux" +) + +func reverseYamuxConfig() *yamux.Config { + config := yamux.DefaultConfig() + config.KeepAliveInterval = 15 * time.Second + config.ConnectionWriteTimeout = 10 * time.Second + config.MaxStreamWindowSize = 512 * 1024 + config.LogOutput = io.Discard + return config +} + +type yamuxNetListener struct { + session *yamux.Session +} + +func (l *yamuxNetListener) Accept() (net.Conn, error) { + return l.session.Accept() +} + +func (l *yamuxNetListener) Close() error { + return l.session.Close() +} + +func (l *yamuxNetListener) Addr() net.Addr { + return l.session.Addr() +} + +func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") != "reverse-proxy" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header") + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + + receiverCredential := s.findReceiverCredential(clientToken) + if receiverCredential == nil { + s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token") + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + s.logger.Error("reverse connect: hijack not supported") + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported") + return + } + + conn, bufferedReadWriter, err := hijacker.Hijack() + if err != nil { + s.logger.Error("reverse connect: hijack: ", err) + return + } + + response := "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: reverse-proxy\r\n\r\n" + _, err = bufferedReadWriter.WriteString(response) + if err != nil { + conn.Close() + s.logger.Error("reverse connect: write upgrade response: ", err) + return + } + err = bufferedReadWriter.Flush() + if err != nil { + conn.Close() + s.logger.Error("reverse connect: flush upgrade response: ", err) + return + } + + session, err := yamux.Client(conn, reverseYamuxConfig()) + if err != nil { + conn.Close() + s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) + return + } + + receiverCredential.setReverseSession(session) + s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) + + go func() { + <-session.CloseChan() + receiverCredential.clearReverseSession(session) + s.logger.Warn("reverse connection lost for ", receiverCredential.tagName()) + }() +} + +func (s *Service) findReceiverCredential(token string) *externalCredential { + for _, cred := range s.allCredentials { + extCred, ok := cred.(*externalCredential) + if !ok { + continue + } + if extCred.baseURL == reverseProxyBaseURL && extCred.token == token { + return extCred + } + } + return nil +} + +func (c *externalCredential) connectorLoop() { + var consecutiveFailures int + for { + select { + case <-c.reverseContext.Done(): + return + default: + } + + err := c.connectorConnect() + if c.reverseContext.Err() != nil { + return + } + consecutiveFailures++ + backoff := connectorBackoff(consecutiveFailures) + c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) + select { + case <-time.After(backoff): + case <-c.reverseContext.Done(): + return + } + } +} + +func connectorBackoff(failures int) time.Duration { + if failures > 5 { + failures = 5 + } + base := time.Second * time.Duration(1< 30*time.Second { + base = 30 * time.Second + } + jitter := time.Duration(rand.Int64N(int64(base) / 2)) + return base + jitter +} + +func (c *externalCredential) connectorConnect() error { + destination := c.connectorResolveDestination() + conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) + if err != nil { + return E.Cause(err, "dial") + } + + if c.connectorTLS != nil { + tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone()) + err = tlsConn.HandshakeContext(c.reverseContext) + if err != nil { + conn.Close() + return E.Cause(err, "tls handshake") + } + conn = tlsConn + } + + upgradeRequest := "GET /ocm/v1/reverse HTTP/1.1\r\n" + + "Host: " + c.connectorURL.Host + "\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: reverse-proxy\r\n" + + "Authorization: Bearer " + c.token + "\r\n" + + "\r\n" + _, err = io.WriteString(conn, upgradeRequest) + if err != nil { + conn.Close() + return E.Cause(err, "write upgrade request") + } + + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + conn.Close() + return E.Cause(err, "read upgrade response") + } + if !strings.HasPrefix(statusLine, "HTTP/1.1 101") { + conn.Close() + return E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine)) + } + for { + line, readErr := reader.ReadString('\n') + if readErr != nil { + conn.Close() + return E.Cause(readErr, "read upgrade headers") + } + if strings.TrimSpace(line) == "" { + break + } + } + + session, err := yamux.Server(conn, reverseYamuxConfig()) + if err != nil { + conn.Close() + return E.Cause(err, "create yamux server") + } + defer session.Close() + + c.logger.Info("reverse connection established for ", c.tag) + + httpServer := &http.Server{ + Handler: c.reverseService, + ReadTimeout: 0, + IdleTimeout: 120 * time.Second, + } + err = httpServer.Serve(&yamuxNetListener{session: session}) + if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil { + return E.Cause(err, "serve") + } + return E.New("connection closed") +} + +func (c *externalCredential) connectorResolveDestination() M.Socksaddr { + port := c.connectorURL.Port() + if port == "" { + if c.connectorURL.Scheme == "https" { + port = "443" + } else { + port = "80" + } + } + return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port)) +} diff --git a/service/ocm/service.go b/service/ocm/service.go index 74fa776d8b..50a44db896 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -305,6 +305,9 @@ func (s *Service) Start(stage adapter.StartStage) error { cred.setOnBecameUnusable(func() { s.interruptWebSocketSessionsForCredential(tag) }) + if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s + } } if len(s.options.Credentials) > 0 { err := validateOCMCompositeCredentialModes(s.options, s.providers) @@ -364,6 +367,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if r.URL.Path == "/ocm/v1/reverse" { + s.handleReverseConnect(w, r) + return + } + path := r.URL.Path if !strings.HasPrefix(path, "/v1/") { writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/") @@ -860,6 +868,20 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) } +func (s *Service) InterfaceUpdated() { + for _, cred := range s.allCredentials { + extCred, ok := cred.(*externalCredential) + if !ok { + continue + } + if extCred.reverse && extCred.connectorURL != nil { + extCred.reverseCancel() + extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) + go extCred.connectorLoop() + } + } +} + func (s *Service) Close() error { webSocketSessions := s.startWebSocketShutdown() From af94ea9089be3a449991a9248eee4be3653550a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 19:36:51 +0800 Subject: [PATCH 14/96] Fix reverse external credential handling --- service/ccm/credential_external.go | 152 +++++++++++++++++--------- service/ccm/reverse.go | 15 +-- service/ccm/service.go | 7 +- service/ocm/credential_external.go | 169 ++++++++++++++++++++--------- service/ocm/reverse.go | 15 +-- service/ocm/service.go | 7 +- 6 files changed, 239 insertions(+), 126 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index e8e53c181c..0bcf15a77b 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -10,6 +10,7 @@ import ( "net/http" "net/url" "strconv" + "strings" "sync" "time" @@ -46,15 +47,59 @@ type externalCredential struct { requestAccess sync.Mutex // Reverse proxy fields - reverse bool - reverseSession *yamux.Session - reverseAccess sync.RWMutex - reverseContext context.Context - reverseCancel context.CancelFunc - connectorDialer N.Dialer - connectorURL *url.URL - connectorTLS *stdTLS.Config - reverseService http.Handler + reverse bool + reverseSession *yamux.Session + reverseAccess sync.RWMutex + reverseContext context.Context + reverseCancel context.CancelFunc + connectorDialer N.Dialer + connectorDestination M.Socksaddr + connectorRequestPath string + connectorURL *url.URL + connectorTLS *stdTLS.Config + reverseService http.Handler +} + +func externalCredentialURLPort(parsedURL *url.URL) uint16 { + portStr := parsedURL.Port() + if portStr != "" { + port, err := strconv.ParseUint(portStr, 10, 16) + if err == nil { + return uint16(port) + } + } + if parsedURL.Scheme == "https" { + return 443 + } + return 80 +} + +func externalCredentialServerPort(parsedURL *url.URL, configuredPort uint16) uint16 { + if configuredPort != 0 { + return configuredPort + } + return externalCredentialURLPort(parsedURL) +} + +func externalCredentialBaseURL(parsedURL *url.URL) string { + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + if parsedURL.Path != "" && parsedURL.Path != "/" { + baseURL += parsedURL.Path + } + if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { + baseURL = baseURL[:len(baseURL)-1] + } + return baseURL +} + +func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) string { + pathPrefix := parsedURL.EscapedPath() + if pathPrefix == "/" { + pathPrefix = "" + } else { + pathPrefix = strings.TrimSuffix(pathPrefix, "/") + } + return pathPrefix + endpointPath } func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { @@ -85,11 +130,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - session := cred.getReverseSession() - if session == nil || session.IsClosed() { - return nil, E.New("reverse connection not established for ", cred.tag) - } - return session.Open() + return cred.openReverseConnection(ctx) }, }, } @@ -115,24 +156,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx ForceAttemptHTTP2: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if options.Server != "" { - serverPort := options.ServerPort - if serverPort == 0 { - portStr := parsedURL.Port() - if portStr != "" { - port, parseErr := strconv.ParseUint(portStr, 10, 16) - if parseErr == nil { - serverPort = uint16(port) - } - } - if serverPort == 0 { - if parsedURL.Scheme == "https" { - serverPort = 443 - } else { - serverPort = 80 - } - } - } - destination := M.ParseSocksaddrHostPort(options.Server, serverPort) + destination := M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) return credentialDialer.DialContext(ctx, network, destination) } return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) @@ -147,19 +171,17 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx } } - baseURL := parsedURL.Scheme + "://" + parsedURL.Host - if parsedURL.Path != "" && parsedURL.Path != "/" { - baseURL += parsedURL.Path - } - if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { - baseURL = baseURL[:len(baseURL)-1] - } - - cred.baseURL = baseURL + cred.baseURL = externalCredentialBaseURL(parsedURL) if options.Reverse { // Connector mode: we dial out to serve, not to proxy cred.connectorDialer = credentialDialer + if options.Server != "" { + cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) + } else { + cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL)) + } + cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ccm/v1/reverse") cred.connectorURL = parsedURL if parsedURL.Scheme == "https" { cred.connectorTLS = &stdTLS.Config{ @@ -208,18 +230,13 @@ func (c *externalCredential) isExternal() bool { } func (c *externalCredential) isAvailable() bool { - if c.reverse && c.connectorURL != nil { - return false // connector mode: not for local proxying - } - if c.baseURL == reverseProxyBaseURL { - // receiver mode: only available when reverse connection active - session := c.getReverseSession() - return session != nil && !session.IsClosed() - } - return true + return c.unavailableError() == nil } func (c *externalCredential) isUsable() bool { + if !c.isAvailable() { + return false + } c.stateMutex.RLock() if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { @@ -279,6 +296,15 @@ func (c *externalCredential) earliestReset() time.Time { } func (c *externalCredential) unavailableError() error { + if c.reverse && c.connectorURL != nil { + return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests") + } + if c.baseURL == reverseProxyBaseURL { + session := c.getReverseSession() + if session == nil || session.IsClosed() { + return E.New("credential ", c.tag, " is unavailable: reverse connection not established") + } + } return nil } @@ -310,6 +336,32 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht return proxyRequest, nil } +func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Conn, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + session := c.getReverseSession() + if session == nil || session.IsClosed() { + return nil, E.New("reverse connection not established for ", c.tag) + } + conn, err := session.Open() + if err != nil { + return nil, err + } + select { + case <-ctx.Done(): + conn.Close() + return nil, ctx.Err() + default: + } + return conn, nil +} + func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.stateMutex.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index 571c8c55ae..7e38c9ceda 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -160,6 +160,9 @@ func connectorBackoff(failures int) time.Duration { } func (c *externalCredential) connectorConnect() error { + if c.reverseService == nil { + return E.New("reverse service not initialized") + } destination := c.connectorResolveDestination() conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) if err != nil { @@ -176,7 +179,7 @@ func (c *externalCredential) connectorConnect() error { conn = tlsConn } - upgradeRequest := "GET /ccm/v1/reverse HTTP/1.1\r\n" + + upgradeRequest := "GET " + c.connectorRequestPath + " HTTP/1.1\r\n" + "Host: " + c.connectorURL.Host + "\r\n" + "Connection: Upgrade\r\n" + "Upgrade: reverse-proxy\r\n" + @@ -231,13 +234,5 @@ func (c *externalCredential) connectorConnect() error { } func (c *externalCredential) connectorResolveDestination() M.Socksaddr { - port := c.connectorURL.Port() - if port == "" { - if c.connectorURL.Scheme == "https" { - port = "443" - } else { - port = "80" - } - } - return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port)) + return c.connectorDestination } diff --git a/service/ccm/service.go b/service/ccm/service.go index 69697b5c05..5d3415ea2c 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -254,13 +254,13 @@ func (s *Service) Start(stage adapter.StartStage) error { s.userManager.UpdateUsers(s.options.Users) for _, cred := range s.allCredentials { + if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s + } err := cred.start() if err != nil { return err } - if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { - extCred.reverseService = s - } } router := chi.NewRouter() @@ -801,6 +801,7 @@ func (s *Service) InterfaceUpdated() { continue } if extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s extCred.reverseCancel() extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) go extCred.connectorLoop() diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 8226d63666..83d37f3855 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -9,7 +9,9 @@ import ( "net" "net/http" "net/url" + "os" "strconv" + "strings" "sync" "time" @@ -47,15 +49,74 @@ type externalCredential struct { requestAccess sync.Mutex // Reverse proxy fields - reverse bool - reverseSession *yamux.Session - reverseAccess sync.RWMutex - reverseContext context.Context - reverseCancel context.CancelFunc - connectorDialer N.Dialer - connectorURL *url.URL - connectorTLS *stdTLS.Config - reverseService http.Handler + reverse bool + reverseSession *yamux.Session + reverseAccess sync.RWMutex + reverseContext context.Context + reverseCancel context.CancelFunc + connectorDialer N.Dialer + connectorDestination M.Socksaddr + connectorRequestPath string + connectorURL *url.URL + connectorTLS *stdTLS.Config + reverseService http.Handler +} + +type reverseSessionDialer struct { + credential *externalCredential +} + +func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if N.NetworkName(network) != N.NetworkTCP { + return nil, os.ErrInvalid + } + return d.credential.openReverseConnection(ctx) +} + +func (d reverseSessionDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return nil, os.ErrInvalid +} + +func externalCredentialURLPort(parsedURL *url.URL) uint16 { + portStr := parsedURL.Port() + if portStr != "" { + port, err := strconv.ParseUint(portStr, 10, 16) + if err == nil { + return uint16(port) + } + } + if parsedURL.Scheme == "https" { + return 443 + } + return 80 +} + +func externalCredentialServerPort(parsedURL *url.URL, configuredPort uint16) uint16 { + if configuredPort != 0 { + return configuredPort + } + return externalCredentialURLPort(parsedURL) +} + +func externalCredentialBaseURL(parsedURL *url.URL) string { + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + if parsedURL.Path != "" && parsedURL.Path != "/" { + baseURL += parsedURL.Path + } + if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { + baseURL = baseURL[:len(baseURL)-1] + } + return baseURL +} + +func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) string { + pathPrefix := parsedURL.EscapedPath() + if pathPrefix == "/" { + pathPrefix = "" + } else { + pathPrefix = strings.TrimSuffix(pathPrefix, "/") + } + return pathPrefix + endpointPath } func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { @@ -82,15 +143,12 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx if options.URL == "" { // Receiver mode: no URL, wait for reverse connection cred.baseURL = reverseProxyBaseURL + cred.credDialer = reverseSessionDialer{credential: cred} cred.httpClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - session := cred.getReverseSession() - if session == nil || session.IsClosed() { - return nil, E.New("reverse connection not established for ", cred.tag) - } - return session.Open() + return cred.openReverseConnection(ctx) }, }, } @@ -116,24 +174,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx ForceAttemptHTTP2: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if options.Server != "" { - serverPort := options.ServerPort - if serverPort == 0 { - portStr := parsedURL.Port() - if portStr != "" { - port, parseErr := strconv.ParseUint(portStr, 10, 16) - if parseErr == nil { - serverPort = uint16(port) - } - } - if serverPort == 0 { - if parsedURL.Scheme == "https" { - serverPort = 443 - } else { - serverPort = 80 - } - } - } - destination := M.ParseSocksaddrHostPort(options.Server, serverPort) + destination := M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) return credentialDialer.DialContext(ctx, network, destination) } return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) @@ -148,19 +189,17 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx } } - baseURL := parsedURL.Scheme + "://" + parsedURL.Host - if parsedURL.Path != "" && parsedURL.Path != "/" { - baseURL += parsedURL.Path - } - if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { - baseURL = baseURL[:len(baseURL)-1] - } - - cred.baseURL = baseURL + cred.baseURL = externalCredentialBaseURL(parsedURL) if options.Reverse { // Connector mode: we dial out to serve, not to proxy cred.connectorDialer = credentialDialer + if options.Server != "" { + cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) + } else { + cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL)) + } + cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ocm/v1/reverse") cred.connectorURL = parsedURL if parsedURL.Scheme == "https" { cred.connectorTLS = &stdTLS.Config{ @@ -214,18 +253,13 @@ func (c *externalCredential) isExternal() bool { } func (c *externalCredential) isAvailable() bool { - if c.reverse && c.connectorURL != nil { - return false // connector mode: not for local proxying - } - if c.baseURL == reverseProxyBaseURL { - // receiver mode: only available when reverse connection active - session := c.getReverseSession() - return session != nil && !session.IsClosed() - } - return true + return c.unavailableError() == nil } func (c *externalCredential) isUsable() bool { + if !c.isAvailable() { + return false + } c.stateMutex.RLock() if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { @@ -284,6 +318,15 @@ func (c *externalCredential) earliestReset() time.Time { } func (c *externalCredential) unavailableError() error { + if c.reverse && c.connectorURL != nil { + return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests") + } + if c.baseURL == reverseProxyBaseURL { + session := c.getReverseSession() + if session == nil || session.IsClosed() { + return E.New("credential ", c.tag, " is unavailable: reverse connection not established") + } + } return nil } @@ -315,6 +358,32 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht return proxyRequest, nil } +func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Conn, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + session := c.getReverseSession() + if session == nil || session.IsClosed() { + return nil, E.New("reverse connection not established for ", c.tag) + } + conn, err := session.Open() + if err != nil { + return nil, err + } + select { + case <-ctx.Done(): + conn.Close() + return nil, ctx.Err() + default: + } + return conn, nil +} + func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.stateMutex.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index b02a202220..23ca1cc47c 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -160,6 +160,9 @@ func connectorBackoff(failures int) time.Duration { } func (c *externalCredential) connectorConnect() error { + if c.reverseService == nil { + return E.New("reverse service not initialized") + } destination := c.connectorResolveDestination() conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) if err != nil { @@ -176,7 +179,7 @@ func (c *externalCredential) connectorConnect() error { conn = tlsConn } - upgradeRequest := "GET /ocm/v1/reverse HTTP/1.1\r\n" + + upgradeRequest := "GET " + c.connectorRequestPath + " HTTP/1.1\r\n" + "Host: " + c.connectorURL.Host + "\r\n" + "Connection: Upgrade\r\n" + "Upgrade: reverse-proxy\r\n" + @@ -231,13 +234,5 @@ func (c *externalCredential) connectorConnect() error { } func (c *externalCredential) connectorResolveDestination() M.Socksaddr { - port := c.connectorURL.Port() - if port == "" { - if c.connectorURL.Scheme == "https" { - port = "443" - } else { - port = "80" - } - } - return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port)) + return c.connectorDestination } diff --git a/service/ocm/service.go b/service/ocm/service.go index 50a44db896..245f2a444c 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -297,6 +297,9 @@ func (s *Service) Start(stage adapter.StartStage) error { s.userManager.UpdateUsers(s.options.Users) for _, cred := range s.allCredentials { + if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s + } err := cred.start() if err != nil { return err @@ -305,9 +308,6 @@ func (s *Service) Start(stage adapter.StartStage) error { cred.setOnBecameUnusable(func() { s.interruptWebSocketSessionsForCredential(tag) }) - if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { - extCred.reverseService = s - } } if len(s.options.Credentials) > 0 { err := validateOCMCompositeCredentialModes(s.options, s.providers) @@ -875,6 +875,7 @@ func (s *Service) InterfaceUpdated() { continue } if extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s extCred.reverseCancel() extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) go extCred.connectorLoop() From 02a1409e9addf4c29be152683325fd3e75bd78db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 20:05:54 +0800 Subject: [PATCH 15/96] ccm,ocm: unify HTTP request retry with fast retry and exponential backoff --- service/ccm/credential.go | 22 +++++----- service/ccm/credential_external.go | 16 +++---- service/ccm/credential_state.go | 69 +++++++++++++++++------------- service/ocm/credential.go | 20 +++++---- service/ocm/credential_external.go | 16 +++---- service/ocm/credential_state.go | 69 +++++++++++++++++------------- 6 files changed, 119 insertions(+), 93 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 6b30008612..75ae62f972 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -2,6 +2,7 @@ package ccm import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -142,7 +143,7 @@ func (c *oauthCredentials) needsRefresh() bool { return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs } -func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { if credentials.RefreshToken == "" { return nil, E.New("refresh token is empty") } @@ -156,15 +157,16 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut return nil, E.Cause(err, "marshal request") } - request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) - if err != nil { - return nil, err - } - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Accept", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - - response, err := httpClient.Do(request) + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + return request, nil + }) if err != nil { return nil, err } diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 0bcf15a77b..8a0ffda86e 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -449,14 +449,14 @@ func (c *externalCredential) pollUsage(ctx context.Context) { Timeout: 5 * time.Second, } - request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) - if err != nil { - c.logger.Error("poll usage for ", c.tag, ": create request: ", err) - return - } - request.Header.Set("Authorization", "Bearer "+c.token) - - response, err := httpClient.Do(request) + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+c.token) + return request, nil + }) if err != nil { c.logger.Error("poll usage for ", c.tag, ": ", err) c.stateMutex.Lock() diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 673af5c2e6..6ecdd50a83 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -5,7 +5,6 @@ import ( "context" stdTLS "crypto/tls" "encoding/json" - "errors" "io" "math" "math/rand/v2" @@ -29,6 +28,38 @@ import ( const defaultPollInterval = 60 * time.Minute +const ( + httpRetryMaxAttempts = 3 + httpRetryInitialDelay = 200 * time.Millisecond +) + +func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { + var lastError error + for attempt := range httpRetryMaxAttempts { + if attempt > 0 { + delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) + select { + case <-ctx.Done(): + return nil, lastError + case <-time.After(delay): + } + } + request, err := buildRequest() + if err != nil { + return nil, err + } + response, err := client.Do(request) + if err == nil { + return response, nil + } + lastError = err + if ctx.Err() != nil { + return nil, lastError + } + } + return nil, lastError +} + type credentialState struct { fiveHourUtilization float64 fiveHourReset time.Time @@ -46,6 +77,7 @@ type credentialState struct { type defaultCredential struct { tag string + serviceContext context.Context credentialPath string credentialFilePath string credentials *oauthCredentials @@ -151,6 +183,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef requestContext, cancelRequests := context.WithCancel(context.Background()) credential := &defaultCredential{ tag: tag, + serviceContext: ctx, credentialPath: options.CredentialPath, reserve5h: reserve5h, reserveWeekly: reserveWeekly, @@ -231,7 +264,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { } baseCredentials := cloneCredentials(c.credentials) - newCredentials, err := refreshToken(c.httpClient, c.credentials) + newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials) if err != nil { return "", err } @@ -498,16 +531,6 @@ func (c *defaultCredential) earliestReset() time.Time { return earliest } -const pollUsageMaxRetries = 3 - -func isTimeoutError(err error) bool { - var netErr net.Error - if errors.As(err, &netErr) { - return netErr.Timeout() - } - return false -} - func (c *defaultCredential) pollUsage(ctx context.Context) { if !c.pollAccess.TryLock() { return @@ -531,30 +554,18 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { Timeout: 5 * time.Second, } - var response *http.Response - for attempt := range pollUsageMaxRetries { + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil) if err != nil { - c.logger.Error("poll usage for ", c.tag, ": create request: ", err) - return + return nil, err } request.Header.Set("Authorization", "Bearer "+accessToken) request.Header.Set("Content-Type", "application/json") request.Header.Set("User-Agent", ccmUserAgentValue) request.Header.Set("anthropic-beta", anthropicBetaOAuthValue) - - response, err = httpClient.Do(request) - if err == nil { - break - } - if !isTimeoutError(err) { - c.logger.Error("poll usage for ", c.tag, ": ", err) - return - } - if attempt < pollUsageMaxRetries-1 { - c.logger.Warn("poll usage for ", c.tag, ": timeout, retrying (", attempt+1, "/", pollUsageMaxRetries, ")") - continue - } + return request, nil + }) + if err != nil { c.logger.Error("poll usage for ", c.tag, ": ", err) return } diff --git a/service/ocm/credential.go b/service/ocm/credential.go index c143f868ab..bb240b5aba 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -2,6 +2,7 @@ package ocm import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -118,7 +119,7 @@ func (c *oauthCredentials) needsRefresh() bool { return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour } -func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" { return nil, E.New("refresh token is empty") } @@ -133,14 +134,15 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut return nil, E.Cause(err, "marshal request") } - request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) - if err != nil { - return nil, err - } - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Accept", "application/json") - - response, err := httpClient.Do(request) + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") + return request, nil + }) if err != nil { return nil, err } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 83d37f3855..0d19ea557b 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -485,14 +485,14 @@ func (c *externalCredential) pollUsage(ctx context.Context) { Timeout: 5 * time.Second, } - request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) - if err != nil { - c.logger.Error("poll usage for ", c.tag, ": create request: ", err) - return - } - request.Header.Set("Authorization", "Bearer "+c.token) - - response, err := httpClient.Do(request) + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+c.token) + return request, nil + }) if err != nil { c.logger.Error("poll usage for ", c.tag, ": ", err) c.stateMutex.Lock() diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index b663632af0..821183da2c 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -5,7 +5,6 @@ import ( "context" stdTLS "crypto/tls" "encoding/json" - "errors" "io" "math/rand/v2" "net" @@ -29,6 +28,38 @@ import ( const defaultPollInterval = 60 * time.Minute +const ( + httpRetryMaxAttempts = 3 + httpRetryInitialDelay = 200 * time.Millisecond +) + +func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { + var lastError error + for attempt := range httpRetryMaxAttempts { + if attempt > 0 { + delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) + select { + case <-ctx.Done(): + return nil, lastError + case <-time.After(delay): + } + } + request, err := buildRequest() + if err != nil { + return nil, err + } + response, err := client.Do(request) + if err == nil { + return response, nil + } + lastError = err + if ctx.Err() != nil { + return nil, lastError + } + } + return nil, lastError +} + type credentialState struct { fiveHourUtilization float64 fiveHourReset time.Time @@ -46,6 +77,7 @@ type credentialState struct { type defaultCredential struct { tag string + serviceContext context.Context credentialPath string credentialFilePath string credentials *oauthCredentials @@ -159,6 +191,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef requestContext, cancelRequests := context.WithCancel(context.Background()) credential := &defaultCredential{ tag: tag, + serviceContext: ctx, credentialPath: options.CredentialPath, reserve5h: reserve5h, reserveWeekly: reserveWeekly, @@ -240,7 +273,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { } baseCredentials := cloneCredentials(c.credentials) - newCredentials, err := refreshToken(c.httpClient, c.credentials) + newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials) if err != nil { return "", err } @@ -507,16 +540,6 @@ func (c *defaultCredential) earliestReset() time.Time { return earliest } -const pollUsageMaxRetries = 3 - -func isTimeoutError(err error) bool { - var netErr net.Error - if errors.As(err, &netErr) { - return netErr.Timeout() - } - return false -} - func (c *defaultCredential) pollUsage(ctx context.Context) { if !c.pollAccess.TryLock() { return @@ -551,30 +574,18 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { Timeout: 5 * time.Second, } - var response *http.Response - for attempt := range pollUsageMaxRetries { + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil) if err != nil { - c.logger.Error("poll usage for ", c.tag, ": create request: ", err) - return + return nil, err } request.Header.Set("Authorization", "Bearer "+accessToken) if accountID != "" { request.Header.Set("ChatGPT-Account-Id", accountID) } - - response, err = httpClient.Do(request) - if err == nil { - break - } - if !isTimeoutError(err) { - c.logger.Error("poll usage for ", c.tag, ": ", err) - return - } - if attempt < pollUsageMaxRetries-1 { - c.logger.Warn("poll usage for ", c.tag, ": timeout, retrying (", attempt+1, "/", pollUsageMaxRetries, ")") - continue - } + return request, nil + }) + if err != nil { c.logger.Error("poll usage for ", c.tag, ": ", err) return } From 182488171971035985e19465f45dd0276e5336d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 20:07:18 +0800 Subject: [PATCH 16/96] ccm,ocm: reset connector backoff after successful connection The consecutiveFailures counter in connectorLoop never resets, causing backoff to permanently cap at 30-45s even after a connection that served successfully for hours. Reset the counter when connectorConnect ran for at least one minute, indicating a successful session rather than a transient dial/handshake failure. --- service/ccm/reverse.go | 31 +++++++++++++++++++------------ service/ocm/reverse.go | 31 +++++++++++++++++++------------ 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index 7e38c9ceda..97d71f0ade 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -132,10 +132,13 @@ func (c *externalCredential) connectorLoop() { default: } - err := c.connectorConnect() + sessionLifetime, err := c.connectorConnect() if c.reverseContext.Err() != nil { return } + if sessionLifetime >= connectorBackoffResetThreshold { + consecutiveFailures = 0 + } consecutiveFailures++ backoff := connectorBackoff(consecutiveFailures) c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) @@ -147,6 +150,8 @@ func (c *externalCredential) connectorLoop() { } } +const connectorBackoffResetThreshold = time.Minute + func connectorBackoff(failures int) time.Duration { if failures > 5 { failures = 5 @@ -159,14 +164,14 @@ func connectorBackoff(failures int) time.Duration { return base + jitter } -func (c *externalCredential) connectorConnect() error { +func (c *externalCredential) connectorConnect() (time.Duration, error) { if c.reverseService == nil { - return E.New("reverse service not initialized") + return 0, E.New("reverse service not initialized") } destination := c.connectorResolveDestination() conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) if err != nil { - return E.Cause(err, "dial") + return 0, E.Cause(err, "dial") } if c.connectorTLS != nil { @@ -174,7 +179,7 @@ func (c *externalCredential) connectorConnect() error { err = tlsConn.HandshakeContext(c.reverseContext) if err != nil { conn.Close() - return E.Cause(err, "tls handshake") + return 0, E.Cause(err, "tls handshake") } conn = tlsConn } @@ -188,24 +193,24 @@ func (c *externalCredential) connectorConnect() error { _, err = io.WriteString(conn, upgradeRequest) if err != nil { conn.Close() - return E.Cause(err, "write upgrade request") + return 0, E.Cause(err, "write upgrade request") } reader := bufio.NewReader(conn) statusLine, err := reader.ReadString('\n') if err != nil { conn.Close() - return E.Cause(err, "read upgrade response") + return 0, E.Cause(err, "read upgrade response") } if !strings.HasPrefix(statusLine, "HTTP/1.1 101") { conn.Close() - return E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine)) + return 0, E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine)) } for { line, readErr := reader.ReadString('\n') if readErr != nil { conn.Close() - return E.Cause(readErr, "read upgrade headers") + return 0, E.Cause(readErr, "read upgrade headers") } if strings.TrimSpace(line) == "" { break @@ -215,22 +220,24 @@ func (c *externalCredential) connectorConnect() error { session, err := yamux.Server(conn, reverseYamuxConfig()) if err != nil { conn.Close() - return E.Cause(err, "create yamux server") + return 0, E.Cause(err, "create yamux server") } defer session.Close() c.logger.Info("reverse connection established for ", c.tag) + serveStart := time.Now() httpServer := &http.Server{ Handler: c.reverseService, ReadTimeout: 0, IdleTimeout: 120 * time.Second, } err = httpServer.Serve(&yamuxNetListener{session: session}) + sessionLifetime := time.Since(serveStart) if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil { - return E.Cause(err, "serve") + return sessionLifetime, E.Cause(err, "serve") } - return E.New("connection closed") + return sessionLifetime, E.New("connection closed") } func (c *externalCredential) connectorResolveDestination() M.Socksaddr { diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index 23ca1cc47c..b3c17f45fc 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -132,10 +132,13 @@ func (c *externalCredential) connectorLoop() { default: } - err := c.connectorConnect() + sessionLifetime, err := c.connectorConnect() if c.reverseContext.Err() != nil { return } + if sessionLifetime >= connectorBackoffResetThreshold { + consecutiveFailures = 0 + } consecutiveFailures++ backoff := connectorBackoff(consecutiveFailures) c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) @@ -147,6 +150,8 @@ func (c *externalCredential) connectorLoop() { } } +const connectorBackoffResetThreshold = time.Minute + func connectorBackoff(failures int) time.Duration { if failures > 5 { failures = 5 @@ -159,14 +164,14 @@ func connectorBackoff(failures int) time.Duration { return base + jitter } -func (c *externalCredential) connectorConnect() error { +func (c *externalCredential) connectorConnect() (time.Duration, error) { if c.reverseService == nil { - return E.New("reverse service not initialized") + return 0, E.New("reverse service not initialized") } destination := c.connectorResolveDestination() conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) if err != nil { - return E.Cause(err, "dial") + return 0, E.Cause(err, "dial") } if c.connectorTLS != nil { @@ -174,7 +179,7 @@ func (c *externalCredential) connectorConnect() error { err = tlsConn.HandshakeContext(c.reverseContext) if err != nil { conn.Close() - return E.Cause(err, "tls handshake") + return 0, E.Cause(err, "tls handshake") } conn = tlsConn } @@ -188,24 +193,24 @@ func (c *externalCredential) connectorConnect() error { _, err = io.WriteString(conn, upgradeRequest) if err != nil { conn.Close() - return E.Cause(err, "write upgrade request") + return 0, E.Cause(err, "write upgrade request") } reader := bufio.NewReader(conn) statusLine, err := reader.ReadString('\n') if err != nil { conn.Close() - return E.Cause(err, "read upgrade response") + return 0, E.Cause(err, "read upgrade response") } if !strings.HasPrefix(statusLine, "HTTP/1.1 101") { conn.Close() - return E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine)) + return 0, E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine)) } for { line, readErr := reader.ReadString('\n') if readErr != nil { conn.Close() - return E.Cause(readErr, "read upgrade headers") + return 0, E.Cause(readErr, "read upgrade headers") } if strings.TrimSpace(line) == "" { break @@ -215,22 +220,24 @@ func (c *externalCredential) connectorConnect() error { session, err := yamux.Server(conn, reverseYamuxConfig()) if err != nil { conn.Close() - return E.Cause(err, "create yamux server") + return 0, E.Cause(err, "create yamux server") } defer session.Close() c.logger.Info("reverse connection established for ", c.tag) + serveStart := time.Now() httpServer := &http.Server{ Handler: c.reverseService, ReadTimeout: 0, IdleTimeout: 120 * time.Second, } err = httpServer.Serve(&yamuxNetListener{session: session}) + sessionLifetime := time.Since(serveStart) if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil { - return E.Cause(err, "serve") + return sessionLifetime, E.Cause(err, "serve") } - return E.New("connection closed") + return sessionLifetime, E.New("connection closed") } func (c *externalCredential) connectorResolveDestination() M.Socksaddr { From 3b177df05e3c5b44b6547463b80286d4f665d4c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 20:10:31 +0800 Subject: [PATCH 17/96] ccm,ocm: fix data race on reverseContext/reverseCancel InterfaceUpdated() writes reverseContext and reverseCancel without synchronization while connectorLoop/connectorConnect goroutines read them concurrently. close() also accesses reverseCancel without a lock. Fix by extending reverseAccess mutex to protect these fields: - Add getReverseContext()/resetReverseContext() methods - Pass context as parameter to connectorConnect - Merge close() into a single lock acquisition - Use resetReverseContext() in InterfaceUpdated() --- service/ccm/credential_external.go | 15 ++++++++++++++- service/ccm/reverse.go | 18 ++++++++++-------- service/ccm/service.go | 3 +-- service/ocm/credential_external.go | 15 ++++++++++++++- service/ocm/reverse.go | 18 ++++++++++-------- service/ocm/service.go | 3 +-- 6 files changed, 50 insertions(+), 22 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 8a0ffda86e..7459a8891e 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -542,10 +542,10 @@ func (c *externalCredential) httpTransport() *http.Client { } func (c *externalCredential) close() { + c.reverseAccess.Lock() if c.reverseCancel != nil { c.reverseCancel() } - c.reverseAccess.Lock() session := c.reverseSession c.reverseSession = nil c.reverseAccess.Unlock() @@ -584,3 +584,16 @@ func (c *externalCredential) clearReverseSession(session *yamux.Session) { } c.reverseAccess.Unlock() } + +func (c *externalCredential) getReverseContext() context.Context { + c.reverseAccess.RLock() + defer c.reverseAccess.RUnlock() + return c.reverseContext +} + +func (c *externalCredential) resetReverseContext() { + c.reverseAccess.Lock() + c.reverseCancel() + c.reverseContext, c.reverseCancel = context.WithCancel(context.Background()) + c.reverseAccess.Unlock() +} diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index 97d71f0ade..e07480a0b4 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -2,6 +2,7 @@ package ccm import ( "bufio" + "context" stdTLS "crypto/tls" "errors" "io" @@ -125,15 +126,16 @@ func (s *Service) findReceiverCredential(token string) *externalCredential { func (c *externalCredential) connectorLoop() { var consecutiveFailures int + ctx := c.getReverseContext() for { select { - case <-c.reverseContext.Done(): + case <-ctx.Done(): return default: } - sessionLifetime, err := c.connectorConnect() - if c.reverseContext.Err() != nil { + sessionLifetime, err := c.connectorConnect(ctx) + if ctx.Err() != nil { return } if sessionLifetime >= connectorBackoffResetThreshold { @@ -144,7 +146,7 @@ func (c *externalCredential) connectorLoop() { c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) select { case <-time.After(backoff): - case <-c.reverseContext.Done(): + case <-ctx.Done(): return } } @@ -164,19 +166,19 @@ func connectorBackoff(failures int) time.Duration { return base + jitter } -func (c *externalCredential) connectorConnect() (time.Duration, error) { +func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) { if c.reverseService == nil { return 0, E.New("reverse service not initialized") } destination := c.connectorResolveDestination() - conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) + conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination) if err != nil { return 0, E.Cause(err, "dial") } if c.connectorTLS != nil { tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone()) - err = tlsConn.HandshakeContext(c.reverseContext) + err = tlsConn.HandshakeContext(ctx) if err != nil { conn.Close() return 0, E.Cause(err, "tls handshake") @@ -234,7 +236,7 @@ func (c *externalCredential) connectorConnect() (time.Duration, error) { } err = httpServer.Serve(&yamuxNetListener{session: session}) sessionLifetime := time.Since(serveStart) - if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil { + if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil { return sessionLifetime, E.Cause(err, "serve") } return sessionLifetime, E.New("connection closed") diff --git a/service/ccm/service.go b/service/ccm/service.go index 5d3415ea2c..2e7685e71d 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -802,8 +802,7 @@ func (s *Service) InterfaceUpdated() { } if extCred.reverse && extCred.connectorURL != nil { extCred.reverseService = s - extCred.reverseCancel() - extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) + extCred.resetReverseContext() go extCred.connectorLoop() } } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 0d19ea557b..d396705f22 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -594,10 +594,10 @@ func (c *externalCredential) ocmGetBaseURL() string { } func (c *externalCredential) close() { + c.reverseAccess.Lock() if c.reverseCancel != nil { c.reverseCancel() } - c.reverseAccess.Lock() session := c.reverseSession c.reverseSession = nil c.reverseAccess.Unlock() @@ -636,3 +636,16 @@ func (c *externalCredential) clearReverseSession(session *yamux.Session) { } c.reverseAccess.Unlock() } + +func (c *externalCredential) getReverseContext() context.Context { + c.reverseAccess.RLock() + defer c.reverseAccess.RUnlock() + return c.reverseContext +} + +func (c *externalCredential) resetReverseContext() { + c.reverseAccess.Lock() + c.reverseCancel() + c.reverseContext, c.reverseCancel = context.WithCancel(context.Background()) + c.reverseAccess.Unlock() +} diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index b3c17f45fc..e88ccea0ae 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -2,6 +2,7 @@ package ocm import ( "bufio" + "context" stdTLS "crypto/tls" "errors" "io" @@ -125,15 +126,16 @@ func (s *Service) findReceiverCredential(token string) *externalCredential { func (c *externalCredential) connectorLoop() { var consecutiveFailures int + ctx := c.getReverseContext() for { select { - case <-c.reverseContext.Done(): + case <-ctx.Done(): return default: } - sessionLifetime, err := c.connectorConnect() - if c.reverseContext.Err() != nil { + sessionLifetime, err := c.connectorConnect(ctx) + if ctx.Err() != nil { return } if sessionLifetime >= connectorBackoffResetThreshold { @@ -144,7 +146,7 @@ func (c *externalCredential) connectorLoop() { c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) select { case <-time.After(backoff): - case <-c.reverseContext.Done(): + case <-ctx.Done(): return } } @@ -164,19 +166,19 @@ func connectorBackoff(failures int) time.Duration { return base + jitter } -func (c *externalCredential) connectorConnect() (time.Duration, error) { +func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) { if c.reverseService == nil { return 0, E.New("reverse service not initialized") } destination := c.connectorResolveDestination() - conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) + conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination) if err != nil { return 0, E.Cause(err, "dial") } if c.connectorTLS != nil { tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone()) - err = tlsConn.HandshakeContext(c.reverseContext) + err = tlsConn.HandshakeContext(ctx) if err != nil { conn.Close() return 0, E.Cause(err, "tls handshake") @@ -234,7 +236,7 @@ func (c *externalCredential) connectorConnect() (time.Duration, error) { } err = httpServer.Serve(&yamuxNetListener{session: session}) sessionLifetime := time.Since(serveStart) - if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil { + if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil { return sessionLifetime, E.Cause(err, "serve") } return sessionLifetime, E.New("connection closed") diff --git a/service/ocm/service.go b/service/ocm/service.go index 245f2a444c..3868725ff2 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -876,8 +876,7 @@ func (s *Service) InterfaceUpdated() { } if extCred.reverse && extCred.connectorURL != nil { extCred.reverseService = s - extCred.reverseCancel() - extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) + extCred.resetReverseContext() go extCred.connectorLoop() } } From 4d5108fe7f619069ea4debab68c59f3fa2331a33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 20:11:34 +0800 Subject: [PATCH 18/96] ccm,ocm: fix connector-side bufio data loss in reverse proxy connectorConnect() creates a bufio.NewReader to read the HTTP 101 upgrade response, but then passes the raw conn to yamux.Server(). If TCP coalesces the 101 response with initial yamux frames, the bufio reader over-reads into its buffer and those bytes are lost to yamux, causing session failure. Wrap the bufio.Reader and raw conn into a bufferedConn so yamux reads through the buffer first. --- service/ccm/reverse.go | 11 ++++++++++- service/ocm/reverse.go | 11 ++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index e07480a0b4..ae00df79fd 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -27,6 +27,15 @@ func reverseYamuxConfig() *yamux.Config { return config } +type bufferedConn struct { + reader *bufio.Reader + net.Conn +} + +func (c *bufferedConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + type yamuxNetListener struct { session *yamux.Session } @@ -219,7 +228,7 @@ func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duratio } } - session, err := yamux.Server(conn, reverseYamuxConfig()) + session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, reverseYamuxConfig()) if err != nil { conn.Close() return 0, E.Cause(err, "create yamux server") diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index e88ccea0ae..25cf017e3e 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -27,6 +27,15 @@ func reverseYamuxConfig() *yamux.Config { return config } +type bufferedConn struct { + reader *bufio.Reader + net.Conn +} + +func (c *bufferedConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + type yamuxNetListener struct { session *yamux.Session } @@ -219,7 +228,7 @@ func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duratio } } - session, err := yamux.Server(conn, reverseYamuxConfig()) + session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, reverseYamuxConfig()) if err != nil { conn.Close() return 0, E.Cause(err, "create yamux server") From ff8585f7c696579e183ae4290cca1d1eecba1e69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 20:14:27 +0800 Subject: [PATCH 19/96] ccm,ocm: block utilization decrease within same rate-limit window updateStateFromHeaders unconditionally applied header utilization values even when they were lower than the current state, causing poll-sourced values to be overwritten by stale header values. Parse reset timestamps before utilization and only allow decreases when the reset timestamp changes (indicating a new rate-limit window). Also add math.Ceil to CCM external credential for consistency with default credential. --- service/ccm/credential_external.go | 12 ++++----- service/ccm/credential_state.go | 29 +++++++++++++-------- service/ocm/credential_external.go | 23 +++++++++-------- service/ocm/credential_state.go | 41 ++++++++++++++++++++---------- 4 files changed, 64 insertions(+), 41 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 7459a8891e..141c893c34 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -368,15 +368,18 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { + c.state.fiveHourReset = value + } if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { value, err := strconv.ParseFloat(utilization, 64) if err == nil { - // Remote CCM writes aggregated utilization as 0.0-1.0; convert to percentage c.state.fiveHourUtilization = value * 100 } } - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { - c.state.fiveHourReset = value + + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { + c.state.weeklyReset = value } if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { value, err := strconv.ParseFloat(utilization, 64) @@ -384,9 +387,6 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.weeklyUtilization = value * 100 } } - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { - c.state.weeklyReset = value - } c.state.lastUpdated = time.Now() if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 6ecdd50a83..7880585479 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -338,32 +338,39 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization + fiveHourResetChanged := false + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { + if value.After(c.state.fiveHourReset) { + fiveHourResetChanged = true + c.state.fiveHourReset = value + } + } if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { value, err := strconv.ParseFloat(utilization, 64) if err == nil { newValue := math.Ceil(value * 100) - if newValue < c.state.fiveHourUtilization { - c.logger.Error("header 5h utilization for ", c.tag, " is lower than current: ", newValue, " < ", c.state.fiveHourUtilization) + if newValue >= c.state.fiveHourUtilization || fiveHourResetChanged { + c.state.fiveHourUtilization = newValue } - c.state.fiveHourUtilization = newValue } } - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { - c.state.fiveHourReset = value + + weeklyResetChanged := false + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { + if value.After(c.state.weeklyReset) { + weeklyResetChanged = true + c.state.weeklyReset = value + } } if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { value, err := strconv.ParseFloat(utilization, 64) if err == nil { newValue := math.Ceil(value * 100) - if newValue < c.state.weeklyUtilization { - c.logger.Error("header weekly utilization for ", c.tag, " is lower than current: ", newValue, " < ", c.state.weeklyUtilization) + if newValue >= c.state.weeklyUtilization || weeklyResetChanged { + c.state.weeklyUtilization = newValue } - c.state.weeklyUtilization = newValue } } - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { - c.state.weeklyReset = value - } c.state.lastUpdated = time.Now() if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index d396705f22..0d6e6b4b1b 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -395,13 +395,6 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { activeLimitIdentifier = "codex" } - fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") - if fiveHourPercent != "" { - value, err := strconv.ParseFloat(fiveHourPercent, 64) - if err == nil { - c.state.fiveHourUtilization = value - } - } fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at") if fiveHourResetAt != "" { value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) @@ -409,13 +402,14 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.fiveHourReset = time.Unix(value, 0) } } - weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") - if weeklyPercent != "" { - value, err := strconv.ParseFloat(weeklyPercent, 64) + fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") + if fiveHourPercent != "" { + value, err := strconv.ParseFloat(fiveHourPercent, 64) if err == nil { - c.state.weeklyUtilization = value + c.state.fiveHourUtilization = value } } + weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at") if weeklyResetAt != "" { value, err := strconv.ParseInt(weeklyResetAt, 10, 64) @@ -423,6 +417,13 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.weeklyReset = time.Unix(value, 0) } } + weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") + if weeklyPercent != "" { + value, err := strconv.ParseFloat(weeklyPercent, 64) + if err == nil { + c.state.weeklyUtilization = value + } + } c.state.lastUpdated = time.Now() if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 821183da2c..81019336b5 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -345,32 +345,47 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { activeLimitIdentifier = "codex" } - fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") - if fiveHourPercent != "" { - value, err := strconv.ParseFloat(fiveHourPercent, 64) - if err == nil { - c.state.fiveHourUtilization = value - } - } + fiveHourResetChanged := false fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at") if fiveHourResetAt != "" { value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) if err == nil { - c.state.fiveHourReset = time.Unix(value, 0) + newReset := time.Unix(value, 0) + if newReset.After(c.state.fiveHourReset) { + fiveHourResetChanged = true + c.state.fiveHourReset = newReset + } } } - weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") - if weeklyPercent != "" { - value, err := strconv.ParseFloat(weeklyPercent, 64) + fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") + if fiveHourPercent != "" { + value, err := strconv.ParseFloat(fiveHourPercent, 64) if err == nil { - c.state.weeklyUtilization = value + if value >= c.state.fiveHourUtilization || fiveHourResetChanged { + c.state.fiveHourUtilization = value + } } } + + weeklyResetChanged := false weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at") if weeklyResetAt != "" { value, err := strconv.ParseInt(weeklyResetAt, 10, 64) if err == nil { - c.state.weeklyReset = time.Unix(value, 0) + newReset := time.Unix(value, 0) + if newReset.After(c.state.weeklyReset) { + weeklyResetChanged = true + c.state.weeklyReset = newReset + } + } + } + weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") + if weeklyPercent != "" { + value, err := strconv.ParseFloat(weeklyPercent, 64) + if err == nil { + if value >= c.state.weeklyUtilization || weeklyResetChanged { + c.state.weeklyUtilization = value + } } } c.state.lastUpdated = time.Now() From 74bf20d349433557d744f0ad19d4326daf69b85c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 21:31:23 +0800 Subject: [PATCH 20/96] ccm,ocm: fix reverse session shutdown race --- service/ccm/credential_external.go | 24 +++++++++++++++++++----- service/ccm/reverse.go | 5 ++++- service/ocm/credential_external.go | 24 +++++++++++++++++++----- service/ocm/reverse.go | 5 ++++- 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 141c893c34..a0350a9fdd 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -50,6 +50,7 @@ type externalCredential struct { reverse bool reverseSession *yamux.Session reverseAccess sync.RWMutex + closed bool reverseContext context.Context reverseCancel context.CancelFunc connectorDialer N.Dialer @@ -542,12 +543,16 @@ func (c *externalCredential) httpTransport() *http.Client { } func (c *externalCredential) close() { + var session *yamux.Session c.reverseAccess.Lock() - if c.reverseCancel != nil { - c.reverseCancel() + if !c.closed { + c.closed = true + if c.reverseCancel != nil { + c.reverseCancel() + } + session = c.reverseSession + c.reverseSession = nil } - session := c.reverseSession - c.reverseSession = nil c.reverseAccess.Unlock() if session != nil { session.Close() @@ -567,14 +572,19 @@ func (c *externalCredential) getReverseSession() *yamux.Session { return c.reverseSession } -func (c *externalCredential) setReverseSession(session *yamux.Session) { +func (c *externalCredential) setReverseSession(session *yamux.Session) bool { c.reverseAccess.Lock() + if c.closed { + c.reverseAccess.Unlock() + return false + } old := c.reverseSession c.reverseSession = session c.reverseAccess.Unlock() if old != nil { old.Close() } + return true } func (c *externalCredential) clearReverseSession(session *yamux.Session) { @@ -593,6 +603,10 @@ func (c *externalCredential) getReverseContext() context.Context { func (c *externalCredential) resetReverseContext() { c.reverseAccess.Lock() + if c.closed { + c.reverseAccess.Unlock() + return + } c.reverseCancel() c.reverseContext, c.reverseCancel = context.WithCancel(context.Background()) c.reverseAccess.Unlock() diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index ae00df79fd..625e55a9dc 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -110,7 +110,10 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { return } - receiverCredential.setReverseSession(session) + if !receiverCredential.setReverseSession(session) { + session.Close() + return + } s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) go func() { diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 0d6e6b4b1b..5c42350d7c 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -52,6 +52,7 @@ type externalCredential struct { reverse bool reverseSession *yamux.Session reverseAccess sync.RWMutex + closed bool reverseContext context.Context reverseCancel context.CancelFunc connectorDialer N.Dialer @@ -595,12 +596,16 @@ func (c *externalCredential) ocmGetBaseURL() string { } func (c *externalCredential) close() { + var session *yamux.Session c.reverseAccess.Lock() - if c.reverseCancel != nil { - c.reverseCancel() + if !c.closed { + c.closed = true + if c.reverseCancel != nil { + c.reverseCancel() + } + session = c.reverseSession + c.reverseSession = nil } - session := c.reverseSession - c.reverseSession = nil c.reverseAccess.Unlock() if session != nil { session.Close() @@ -620,14 +625,19 @@ func (c *externalCredential) getReverseSession() *yamux.Session { return c.reverseSession } -func (c *externalCredential) setReverseSession(session *yamux.Session) { +func (c *externalCredential) setReverseSession(session *yamux.Session) bool { c.reverseAccess.Lock() + if c.closed { + c.reverseAccess.Unlock() + return false + } old := c.reverseSession c.reverseSession = session c.reverseAccess.Unlock() if old != nil { old.Close() } + return true } func (c *externalCredential) clearReverseSession(session *yamux.Session) { @@ -646,6 +656,10 @@ func (c *externalCredential) getReverseContext() context.Context { func (c *externalCredential) resetReverseContext() { c.reverseAccess.Lock() + if c.closed { + c.reverseAccess.Unlock() + return + } c.reverseCancel() c.reverseContext, c.reverseCancel = context.WithCancel(context.Background()) c.reverseAccess.Unlock() diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index 25cf017e3e..906778df58 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -110,7 +110,10 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { return } - receiverCredential.setReverseSession(session) + if !receiverCredential.setReverseSession(session) { + session.Close() + return + } s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) go func() { From 22376472d04aeaea52c80fa91c4275d6fbeb756f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 21:42:05 +0800 Subject: [PATCH 21/96] ccm,ocm: fix passive usage update for WebSocket connections WebSocket 101 upgrade responses do not include utilization headers (confirmed via codex CLI source). Rate limit data is delivered exclusively through in-band events (codex.rate_limits and error events with status 429). Previously, updateStateFromHeaders unconditionally bumped lastUpdated even when no utilization headers were found, which suppressed polling and left credential utilization permanently stale during WebSocket sessions. - Only bump lastUpdated when actual utilization data is parsed - Parse in-band codex.rate_limits events to update credential state - Detect in-band 429 error events to markRateLimited - Fix WebSocket 429 retry to update old credential state before retry --- service/ccm/credential_external.go | 9 +- service/ccm/credential_state.go | 9 +- service/ocm/credential_external.go | 9 +- service/ocm/credential_state.go | 9 +- service/ocm/service_websocket.go | 151 ++++++++++++++++++++++------- 5 files changed, 147 insertions(+), 40 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index a0350a9fdd..8a550719b4 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -368,27 +368,34 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization + hadData := false if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { + hadData = true c.state.fiveHourReset = value } if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { value, err := strconv.ParseFloat(utilization, 64) if err == nil { + hadData = true c.state.fiveHourUtilization = value * 100 } } if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { + hadData = true c.state.weeklyReset = value } if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { value, err := strconv.ParseFloat(utilization, 64) if err == nil { + hadData = true c.state.weeklyUtilization = value * 100 } } - c.state.lastUpdated = time.Now() + if hadData { + c.state.lastUpdated = time.Now() + } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") } diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 7880585479..35fb52dcef 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -337,9 +337,11 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization + hadData := false fiveHourResetChanged := false if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { + hadData = true if value.After(c.state.fiveHourReset) { fiveHourResetChanged = true c.state.fiveHourReset = value @@ -348,6 +350,7 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { value, err := strconv.ParseFloat(utilization, 64) if err == nil { + hadData = true newValue := math.Ceil(value * 100) if newValue >= c.state.fiveHourUtilization || fiveHourResetChanged { c.state.fiveHourUtilization = newValue @@ -357,6 +360,7 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { weeklyResetChanged := false if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { + hadData = true if value.After(c.state.weeklyReset) { weeklyResetChanged = true c.state.weeklyReset = value @@ -365,13 +369,16 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { value, err := strconv.ParseFloat(utilization, 64) if err == nil { + hadData = true newValue := math.Ceil(value * 100) if newValue >= c.state.weeklyUtilization || weeklyResetChanged { c.state.weeklyUtilization = newValue } } } - c.state.lastUpdated = time.Now() + if hadData { + c.state.lastUpdated = time.Now() + } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 5c42350d7c..edc369edda 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -390,6 +390,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization + hadData := false activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) if activeLimitIdentifier == "" { @@ -400,6 +401,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { if fiveHourResetAt != "" { value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) if err == nil { + hadData = true c.state.fiveHourReset = time.Unix(value, 0) } } @@ -407,6 +409,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { if fiveHourPercent != "" { value, err := strconv.ParseFloat(fiveHourPercent, 64) if err == nil { + hadData = true c.state.fiveHourUtilization = value } } @@ -415,6 +418,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { if weeklyResetAt != "" { value, err := strconv.ParseInt(weeklyResetAt, 10, 64) if err == nil { + hadData = true c.state.weeklyReset = time.Unix(value, 0) } } @@ -422,10 +426,13 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { if weeklyPercent != "" { value, err := strconv.ParseFloat(weeklyPercent, 64) if err == nil { + hadData = true c.state.weeklyUtilization = value } } - c.state.lastUpdated = time.Now() + if hadData { + c.state.lastUpdated = time.Now() + } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") } diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 81019336b5..2de63b960c 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -339,6 +339,7 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization + hadData := false activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) if activeLimitIdentifier == "" { @@ -350,6 +351,7 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { if fiveHourResetAt != "" { value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) if err == nil { + hadData = true newReset := time.Unix(value, 0) if newReset.After(c.state.fiveHourReset) { fiveHourResetChanged = true @@ -361,6 +363,7 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { if fiveHourPercent != "" { value, err := strconv.ParseFloat(fiveHourPercent, 64) if err == nil { + hadData = true if value >= c.state.fiveHourUtilization || fiveHourResetChanged { c.state.fiveHourUtilization = value } @@ -372,6 +375,7 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { if weeklyResetAt != "" { value, err := strconv.ParseInt(weeklyResetAt, 10, 64) if err == nil { + hadData = true newReset := time.Unix(value, 0) if newReset.After(c.state.weeklyReset) { weeklyResetChanged = true @@ -383,12 +387,15 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { if weeklyPercent != "" { value, err := strconv.ParseFloat(weeklyPercent, 64) if err == nil { + hadData = true if value >= c.state.weeklyUtilization || weeklyResetChanged { c.state.weeklyUtilization = value } } } - c.state.lastUpdated = time.Now() + if hadData { + c.state.lastUpdated = time.Now() + } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") } diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 7aa68499cb..d3f2535c05 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/textproto" + "strconv" "strings" "sync" "time" @@ -174,8 +175,8 @@ func (s *Service) handleWebSocket( if statusCode == http.StatusTooManyRequests { resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders) nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) + selectedCredential.updateStateFromHeaders(upstreamResponseHeaders) if nextCredential == nil { - selectedCredential.updateStateFromHeaders(upstreamResponseHeaders) writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") return } @@ -298,44 +299,27 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite return } - if opCode == ws.OpText && usageTracker != nil { - select { - case model := <-modelChannel: - requestModel = model - default: - } - + if opCode == ws.OpText { var event struct { - Type string `json:"type"` + Type string `json:"type"` + StatusCode int `json:"status_code"` } - if json.Unmarshal(data, &event) == nil && event.Type == "response.completed" { - var streamEvent responses.ResponseStreamEventUnion - if json.Unmarshal(data, &streamEvent) == nil { - completedEvent := streamEvent.AsResponseCompleted() - responseModel := string(completedEvent.Response.Model) - serviceTier := string(completedEvent.Response.ServiceTier) - inputTokens := completedEvent.Response.Usage.InputTokens - outputTokens := completedEvent.Response.Usage.OutputTokens - cachedTokens := completedEvent.Response.Usage.InputTokensDetails.CachedTokens - - if inputTokens > 0 || outputTokens > 0 { - if responseModel == "" { - responseModel = requestModel - } - if responseModel != "" { - contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - usageTracker.AddUsageWithCycleHint( - responseModel, - contextWindow, - inputTokens, - outputTokens, - cachedTokens, - serviceTier, - username, - time.Now(), - weeklyCycleHint, - ) + if json.Unmarshal(data, &event) == nil { + switch event.Type { + case "codex.rate_limits": + s.handleWebSocketRateLimitsEvent(data, selectedCredential) + case "error": + if event.StatusCode == http.StatusTooManyRequests { + s.handleWebSocketErrorRateLimited(data, selectedCredential) + } + case "response.completed": + if usageTracker != nil { + select { + case model := <-modelChannel: + requestModel = model + default: } + s.handleWebSocketResponseCompleted(data, usageTracker, requestModel, username, weeklyCycleHint) } } } @@ -350,3 +334,98 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite } } } + +func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential credential) { + var rateLimitsEvent struct { + RateLimits struct { + Primary *struct { + UsedPercent float64 `json:"used_percent"` + ResetAt int64 `json:"reset_at"` + } `json:"primary"` + Secondary *struct { + UsedPercent float64 `json:"used_percent"` + ResetAt int64 `json:"reset_at"` + } `json:"secondary"` + } `json:"rate_limits"` + LimitName string `json:"limit_name"` + MeteredLimitName string `json:"metered_limit_name"` + } + err := json.Unmarshal(data, &rateLimitsEvent) + if err != nil { + return + } + identifier := rateLimitsEvent.MeteredLimitName + if identifier == "" { + identifier = rateLimitsEvent.LimitName + } + if identifier == "" { + identifier = "codex" + } + identifier = normalizeRateLimitIdentifier(identifier) + + headers := make(http.Header) + headers.Set("x-codex-active-limit", identifier) + if w := rateLimitsEvent.RateLimits.Primary; w != nil { + headers.Set("x-"+identifier+"-primary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64)) + if w.ResetAt > 0 { + headers.Set("x-"+identifier+"-primary-reset-at", strconv.FormatInt(w.ResetAt, 10)) + } + } + if w := rateLimitsEvent.RateLimits.Secondary; w != nil { + headers.Set("x-"+identifier+"-secondary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64)) + if w.ResetAt > 0 { + headers.Set("x-"+identifier+"-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10)) + } + } + selectedCredential.updateStateFromHeaders(headers) +} + +func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredential credential) { + var errorEvent struct { + Headers map[string]string `json:"headers"` + } + err := json.Unmarshal(data, &errorEvent) + if err != nil { + return + } + headers := make(http.Header) + for key, value := range errorEvent.Headers { + headers.Set(key, value) + } + selectedCredential.updateStateFromHeaders(headers) + resetAt := parseOCMRateLimitResetFromHeaders(headers) + selectedCredential.markRateLimited(resetAt) +} + +func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) { + var streamEvent responses.ResponseStreamEventUnion + if json.Unmarshal(data, &streamEvent) != nil { + return + } + completedEvent := streamEvent.AsResponseCompleted() + responseModel := string(completedEvent.Response.Model) + serviceTier := string(completedEvent.Response.ServiceTier) + inputTokens := completedEvent.Response.Usage.InputTokens + outputTokens := completedEvent.Response.Usage.OutputTokens + cachedTokens := completedEvent.Response.Usage.InputTokensDetails.CachedTokens + + if inputTokens > 0 || outputTokens > 0 { + if responseModel == "" { + responseModel = requestModel + } + if responseModel != "" { + contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) + usageTracker.AddUsageWithCycleHint( + responseModel, + contextWindow, + inputTokens, + outputTokens, + cachedTokens, + serviceTier, + username, + time.Now(), + weeklyCycleHint, + ) + } + } +} From 1993da37359503d5b9238aeffe86e4c06be0981a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 21:54:47 +0800 Subject: [PATCH 22/96] ocm: rewrite codex.rate_limits WebSocket events for external users The HTTP path rewrites utilization headers for external users via rewriteResponseHeadersForExternalUser to show aggregated values. The WebSocket upgrade headers were also rewritten, but in-band codex.rate_limits events were forwarded unmodified, leaking per-credential utilization to external users. --- service/ocm/service_websocket.go | 42 ++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index d3f2535c05..b35733b698 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -249,7 +249,7 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, modelChannel, username, weeklyCycleHint) + s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) }() waitGroup.Wait() } @@ -287,7 +287,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo } } -func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { usageTracker := selectedCredential.usageTrackerOrNil() var requestModel string for { @@ -308,6 +308,12 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite switch event.Type { case "codex.rate_limits": s.handleWebSocketRateLimitsEvent(data, selectedCredential) + if userConfig != nil && userConfig.ExternalCredential != "" { + rewritten, rewriteErr := s.rewriteWebSocketRateLimitsForExternalUser(data, provider, userConfig) + if rewriteErr == nil { + data = rewritten + } + } case "error": if event.StatusCode == http.StatusTooManyRequests { s.handleWebSocketErrorRateLimited(data, selectedCredential) @@ -397,6 +403,38 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia selectedCredential.markRateLimited(resetAt) } +func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) { + var event struct { + Type string `json:"type"` + RateLimits struct { + Primary *struct { + UsedPercent float64 `json:"used_percent"` + WindowMinutes int64 `json:"window_minutes,omitempty"` + ResetAt int64 `json:"reset_at,omitempty"` + } `json:"primary,omitempty"` + Secondary *struct { + UsedPercent float64 `json:"used_percent"` + WindowMinutes int64 `json:"window_minutes,omitempty"` + ResetAt int64 `json:"reset_at,omitempty"` + } `json:"secondary,omitempty"` + } `json:"rate_limits"` + LimitName string `json:"limit_name,omitempty"` + MeteredLimitName string `json:"metered_limit_name,omitempty"` + } + err := json.Unmarshal(data, &event) + if err != nil { + return nil, err + } + averageFiveHour, averageWeekly := s.computeAggregatedUtilization(provider, userConfig) + if event.RateLimits.Primary != nil { + event.RateLimits.Primary.UsedPercent = averageFiveHour + } + if event.RateLimits.Secondary != nil { + event.RateLimits.Secondary.UsedPercent = averageWeekly + } + return json.Marshal(event) +} + func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) { var streamEvent responses.ResponseStreamEventUnion if json.Unmarshal(data, &streamEvent) != nil { From df6e47f5f110e9c503d361c43fa71cb64c9142c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 22:24:08 +0800 Subject: [PATCH 23/96] ocm: preserve websocket rate limit event fields --- service/ocm/service_websocket.go | 73 +++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index b35733b698..8cbe3056c9 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -404,37 +404,68 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia } func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) { - var event struct { - Type string `json:"type"` - RateLimits struct { - Primary *struct { - UsedPercent float64 `json:"used_percent"` - WindowMinutes int64 `json:"window_minutes,omitempty"` - ResetAt int64 `json:"reset_at,omitempty"` - } `json:"primary,omitempty"` - Secondary *struct { - UsedPercent float64 `json:"used_percent"` - WindowMinutes int64 `json:"window_minutes,omitempty"` - ResetAt int64 `json:"reset_at,omitempty"` - } `json:"secondary,omitempty"` - } `json:"rate_limits"` - LimitName string `json:"limit_name,omitempty"` - MeteredLimitName string `json:"metered_limit_name,omitempty"` - } + var event map[string]json.RawMessage err := json.Unmarshal(data, &event) if err != nil { return nil, err } + + rateLimitsData, exists := event["rate_limits"] + if !exists || len(rateLimitsData) == 0 || string(rateLimitsData) == "null" { + return data, nil + } + + var rateLimits map[string]json.RawMessage + err = json.Unmarshal(rateLimitsData, &rateLimits) + if err != nil { + return nil, err + } + averageFiveHour, averageWeekly := s.computeAggregatedUtilization(provider, userConfig) - if event.RateLimits.Primary != nil { - event.RateLimits.Primary.UsedPercent = averageFiveHour + + primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], averageFiveHour) + if err != nil { + return nil, err } - if event.RateLimits.Secondary != nil { - event.RateLimits.Secondary.UsedPercent = averageWeekly + if primaryData != nil { + rateLimits["primary"] = primaryData + } + + secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], averageWeekly) + if err != nil { + return nil, err } + if secondaryData != nil { + rateLimits["secondary"] = secondaryData + } + + event["rate_limits"], err = json.Marshal(rateLimits) + if err != nil { + return nil, err + } + return json.Marshal(event) } +func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64) (json.RawMessage, error) { + if len(data) == 0 || string(data) == "null" { + return nil, nil + } + + var window map[string]json.RawMessage + err := json.Unmarshal(data, &window) + if err != nil { + return nil, err + } + + window["used_percent"], err = json.Marshal(usedPercent) + if err != nil { + return nil, err + } + + return json.Marshal(window) +} + func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) { var streamEvent responses.ResponseStreamEventUnion if json.Unmarshal(data, &streamEvent) != nil { From 7f93c76b1aed3c503ee77a9b1cfae18a18524b31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 23:13:44 +0800 Subject: [PATCH 24/96] ccm,ocm: add limit options and fix aggregated utilization scaling Add limit_5h and limit_weekly options as alternatives to reserve_5h and reserve_weekly for capping credential utilization. The two are mutually exclusive per window. Fix computeAggregatedUtilization to scale per-credential utilization relative to each credential's cap before averaging, so external users see correct available capacity regardless of per-credential caps. Fix pickLeastUsed to compare remaining capacity (cap - utilization) instead of raw utilization, ensuring fair comparison across credentials with different caps. --- option/ccm.go | 2 + option/ocm.go | 2 + service/ccm/credential_external.go | 20 ++++++- service/ccm/credential_state.go | 92 ++++++++++++++++++++++++++---- service/ccm/service.go | 49 +++++++++++++--- service/ocm/credential_external.go | 20 ++++++- service/ocm/credential_state.go | 92 ++++++++++++++++++++++++++---- service/ocm/service.go | 45 +++++++++++---- 8 files changed, 279 insertions(+), 43 deletions(-) diff --git a/option/ccm.go b/option/ccm.go index ae80cc64b2..dd55a4ba4e 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -86,6 +86,8 @@ type CCMDefaultCredentialOptions struct { Detour string `json:"detour,omitempty"` Reserve5h uint8 `json:"reserve_5h"` ReserveWeekly uint8 `json:"reserve_weekly"` + Limit5h uint8 `json:"limit_5h,omitempty"` + LimitWeekly uint8 `json:"limit_weekly,omitempty"` } type CCMBalancerCredentialOptions struct { diff --git a/option/ocm.go b/option/ocm.go index 20cafee123..e508abae7e 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -86,6 +86,8 @@ type OCMDefaultCredentialOptions struct { Detour string `json:"detour,omitempty"` Reserve5h uint8 `json:"reserve_5h"` ReserveWeekly uint8 `json:"reserve_weekly"` + Limit5h uint8 `json:"limit_5h,omitempty"` + LimitWeekly uint8 `json:"limit_weekly,omitempty"` } type OCMBalancerCredentialOptions struct { diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 8a550719b4..74ce6617ec 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -271,6 +271,14 @@ func (c *externalCredential) weeklyUtilization() float64 { return c.state.weeklyUtilization } +func (c *externalCredential) fiveHourCap() float64 { + return 100 +} + +func (c *externalCredential) weeklyCap() float64 { + return 100 +} + func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) c.stateMutex.Lock() @@ -397,7 +405,11 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.lastUpdated = time.Now() } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() @@ -507,7 +519,11 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.state.hardRateLimited = false } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 35fb52dcef..d681c222bb 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -89,6 +89,8 @@ type defaultCredential struct { watcherAccess sync.Mutex reserve5h uint8 reserveWeekly uint8 + cap5h float64 + capWeekly float64 usageTracker *AggregatedUsage httpClient *http.Client logger log.ContextLogger @@ -129,6 +131,8 @@ type credential interface { isExternal() bool fiveHourUtilization() float64 weeklyUtilization() float64 + fiveHourCap() float64 + weeklyCap() float64 markRateLimited(resetAt time.Time) earliestReset() time.Time unavailableError() error @@ -180,6 +184,18 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef if reserveWeekly == 0 { reserveWeekly = 10 } + var cap5h float64 + if options.Limit5h > 0 { + cap5h = float64(options.Limit5h) + } else { + cap5h = float64(100 - reserve5h) + } + var capWeekly float64 + if options.LimitWeekly > 0 { + capWeekly = float64(options.LimitWeekly) + } else { + capWeekly = float64(100 - reserveWeekly) + } requestContext, cancelRequests := context.WithCancel(context.Background()) credential := &defaultCredential{ tag: tag, @@ -187,6 +203,8 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef credentialPath: options.CredentialPath, reserve5h: reserve5h, reserveWeekly: reserveWeekly, + cap5h: cap5h, + capWeekly: capWeekly, httpClient: httpClient, logger: logger, requestContext: requestContext, @@ -380,7 +398,11 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { c.state.lastUpdated = time.Now() } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() @@ -429,10 +451,10 @@ func (c *defaultCredential) isUsable() bool { } func (c *defaultCredential) checkReservesLocked() bool { - if c.state.fiveHourUtilization >= float64(100-c.reserve5h) { + if c.state.fiveHourUtilization >= c.cap5h { return false } - if c.state.weeklyUtilization >= float64(100-c.reserveWeekly) { + if c.state.weeklyUtilization >= c.capWeekly { return false } return true @@ -633,7 +655,11 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { c.state.hardRateLimited = false } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() @@ -672,6 +698,14 @@ func (c *defaultCredential) fiveHourUtilization() float64 { return c.state.fiveHourUtilization } +func (c *defaultCredential) fiveHourCap() float64 { + return c.cap5h +} + +func (c *defaultCredential) weeklyCap() float64 { + return c.capWeekly +} + func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage { return c.usageTracker } @@ -736,10 +770,12 @@ type credentialProvider interface { // singleCredentialProvider wraps a single credential (legacy or single default). type singleCredentialProvider struct { - cred credential + cred credential + sessionAccess sync.RWMutex + sessions map[string]time.Time } -func (p *singleCredentialProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) { +func (p *singleCredentialProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { if filter != nil && !filter(p.cred) { return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") } @@ -749,7 +785,20 @@ func (p *singleCredentialProvider) selectCredential(_ string, filter func(creden if !p.cred.isUsable() { return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") } - return p.cred, false, nil + var isNew bool + if sessionID != "" { + p.sessionAccess.Lock() + if p.sessions == nil { + p.sessions = make(map[string]time.Time) + } + _, exists := p.sessions[sessionID] + if !exists { + p.sessions[sessionID] = time.Now() + isNew = true + } + p.sessionAccess.Unlock() + } + return p.cred, isNew, nil } func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential { @@ -758,6 +807,15 @@ func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, rese } func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionAccess.Lock() + for id, createdAt := range p.sessions { + if now.Sub(createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() + if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { p.cred.pollUsage(ctx) } @@ -861,7 +919,7 @@ func (p *balancerProvider) pickCredential(filter func(credential) bool) credenti func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { var best credential - bestUtilization := float64(101) + bestRemaining := float64(-1) for _, cred := range p.credentials { if filter != nil && !filter(cred) { continue @@ -869,9 +927,9 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia if !cred.isUsable() { continue } - utilization := cred.weeklyUtilization() - if utilization < bestUtilization { - bestUtilization = utilization + remaining := cred.weeklyCap() - cred.weeklyUtilization() + if remaining > bestRemaining { + bestRemaining = remaining best = cred } } @@ -1140,6 +1198,18 @@ func validateCCMOptions(options option.CCMServiceOptions) error { if cred.DefaultOptions.ReserveWeekly > 99 { return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") } + if cred.DefaultOptions.Limit5h > 100 { + return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") + } + if cred.DefaultOptions.LimitWeekly > 100 { + return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") + } + if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { + return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") + } + if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { + return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") + } } if cred.Type == "external" { if cred.ExternalOptions.Token == "" { diff --git a/service/ccm/service.go b/service/ccm/service.go index 2e7685e71d..6f1500fb8c 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -306,6 +306,15 @@ func isExtendedContextRequest(betaHeader string) bool { return false } +func isFastModeRequest(betaHeader string) bool { + for _, feature := range strings.Split(betaHeader, ",") { + if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") { + return true + } + } + return false +} + func detectContextWindow(betaHeader string, totalInputTokens int64) int { if totalInputTokens > premiumContextThreshold { if isExtendedContextRequest(betaHeader) { @@ -414,6 +423,14 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + if isFastModeRequest(anthropicBetaHeader) { + if _, isSingle := provider.(*singleCredentialProvider); !isSingle { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "fast mode requests will consume Extra usage, please use a default credential directly") + return + } + } + var credentialFilter func(credential) bool if userConfig != nil && !userConfig.AllowExternalUsage { credentialFilter = func(c credential) bool { return !c.isExternal() } @@ -424,13 +441,23 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) return } + var logParts []any if isNew { - if username != "" { - s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID, " by user ", username) - } else { - s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID) - } + logParts = append(logParts, "assigned credential ") + } else { + logParts = append(logParts, "credential ") } + logParts = append(logParts, selectedCredential.tagName()) + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if isNew && username != "" { + logParts = append(logParts, " by user ", username) + } + if requestModel != "" { + logParts = append(logParts, ", model=", requestModel) + } + s.logger.Debug(logParts...) if isExtendedContextRequest(anthropicBetaHeader) && selectedCredential.isExternal() { writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", @@ -771,8 +798,16 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user if !userConfig.AllowExternalUsage && cred.isExternal() { continue } - totalFiveHour += cred.fiveHourUtilization() - totalWeekly += cred.weeklyUtilization() + scaledFiveHour := cred.fiveHourUtilization() / cred.fiveHourCap() * 100 + if scaledFiveHour > 100 { + scaledFiveHour = 100 + } + scaledWeekly := cred.weeklyUtilization() / cred.weeklyCap() * 100 + if scaledWeekly > 100 { + scaledWeekly = 100 + } + totalFiveHour += scaledFiveHour + totalWeekly += scaledWeekly count++ } if count == 0 { diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index edc369edda..2864c684e4 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -293,6 +293,14 @@ func (c *externalCredential) weeklyUtilization() float64 { return c.state.weeklyUtilization } +func (c *externalCredential) fiveHourCap() float64 { + return 100 +} + +func (c *externalCredential) weeklyCap() float64 { + return 100 +} + func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) c.stateMutex.Lock() @@ -434,7 +442,11 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.lastUpdated = time.Now() } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() @@ -544,7 +556,11 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.state.hardRateLimited = false } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 2de63b960c..b3564d9b03 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -89,6 +89,8 @@ type defaultCredential struct { watcherAccess sync.Mutex reserve5h uint8 reserveWeekly uint8 + cap5h float64 + capWeekly float64 usageTracker *AggregatedUsage dialer N.Dialer httpClient *http.Client @@ -130,6 +132,8 @@ type credential interface { isExternal() bool fiveHourUtilization() float64 weeklyUtilization() float64 + fiveHourCap() float64 + weeklyCap() float64 markRateLimited(resetAt time.Time) earliestReset() time.Time unavailableError() error @@ -188,6 +192,18 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef if reserveWeekly == 0 { reserveWeekly = 10 } + var cap5h float64 + if options.Limit5h > 0 { + cap5h = float64(options.Limit5h) + } else { + cap5h = float64(100 - reserve5h) + } + var capWeekly float64 + if options.LimitWeekly > 0 { + capWeekly = float64(options.LimitWeekly) + } else { + capWeekly = float64(100 - reserveWeekly) + } requestContext, cancelRequests := context.WithCancel(context.Background()) credential := &defaultCredential{ tag: tag, @@ -195,6 +211,8 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef credentialPath: options.CredentialPath, reserve5h: reserve5h, reserveWeekly: reserveWeekly, + cap5h: cap5h, + capWeekly: capWeekly, dialer: credentialDialer, httpClient: httpClient, logger: logger, @@ -397,7 +415,11 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { c.state.lastUpdated = time.Now() } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() @@ -446,10 +468,10 @@ func (c *defaultCredential) isUsable() bool { } func (c *defaultCredential) checkReservesLocked() bool { - if c.state.fiveHourUtilization >= float64(100-c.reserve5h) { + if c.state.fiveHourUtilization >= c.cap5h { return false } - if c.state.weeklyUtilization >= float64(100-c.reserveWeekly) { + if c.state.weeklyUtilization >= c.capWeekly { return false } return true @@ -671,7 +693,11 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { c.state.hardRateLimited = false } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() @@ -714,6 +740,14 @@ func (c *defaultCredential) fiveHourUtilization() float64 { return c.state.fiveHourUtilization } +func (c *defaultCredential) fiveHourCap() float64 { + return c.cap5h +} + +func (c *defaultCredential) weeklyCap() float64 { + return c.capWeekly +} + func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage { return c.usageTracker } @@ -796,10 +830,12 @@ type credentialProvider interface { } type singleCredentialProvider struct { - cred credential + cred credential + sessionAccess sync.RWMutex + sessions map[string]time.Time } -func (p *singleCredentialProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) { +func (p *singleCredentialProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { if filter != nil && !filter(p.cred) { return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") } @@ -809,7 +845,20 @@ func (p *singleCredentialProvider) selectCredential(_ string, filter func(creden if !p.cred.isUsable() { return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") } - return p.cred, false, nil + var isNew bool + if sessionID != "" { + p.sessionAccess.Lock() + if p.sessions == nil { + p.sessions = make(map[string]time.Time) + } + _, exists := p.sessions[sessionID] + if !exists { + p.sessions[sessionID] = time.Now() + isNew = true + } + p.sessionAccess.Unlock() + } + return p.cred, isNew, nil } func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential { @@ -818,6 +867,15 @@ func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, rese } func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionAccess.Lock() + for id, createdAt := range p.sessions { + if now.Sub(createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() + if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { p.cred.pollUsage(ctx) } @@ -924,7 +982,7 @@ func (p *balancerProvider) pickCredential(filter func(credential) bool) credenti func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { var best credential - bestUtilization := float64(101) + bestRemaining := float64(-1) for _, cred := range p.credentials { if filter != nil && !filter(cred) { continue @@ -935,9 +993,9 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia if !cred.isUsable() { continue } - utilization := cred.weeklyUtilization() - if utilization < bestUtilization { - bestUtilization = utilization + remaining := cred.weeklyCap() - cred.weeklyUtilization() + if remaining > bestRemaining { + bestRemaining = remaining best = cred } } @@ -1207,6 +1265,18 @@ func validateOCMOptions(options option.OCMServiceOptions) error { if cred.DefaultOptions.ReserveWeekly > 99 { return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") } + if cred.DefaultOptions.Limit5h > 100 { + return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") + } + if cred.DefaultOptions.LimitWeekly > 100 { + return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") + } + if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { + return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") + } + if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { + return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") + } } if cred.Type == "external" { if cred.ExternalOptions.Token == "" { diff --git a/service/ocm/service.go b/service/ocm/service.go index 3868725ff2..751d03f2bc 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -435,13 +435,6 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error())) return } - if isNew { - if username != "" { - s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID, " by user ", username) - } else { - s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID) - } - } if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter) @@ -465,6 +458,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Read body for model extraction and retry buffer when JSON replay is useful. var bodyBytes []byte var requestModel string + var requestServiceTier string if r.Body != nil && (shouldTrackUsage || canRetryRequest) { mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type")) isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")) @@ -476,15 +470,38 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } var request struct { - Model string `json:"model"` + Model string `json:"model"` + ServiceTier string `json:"service_tier"` } if json.Unmarshal(bodyBytes, &request) == nil { requestModel = request.Model + requestServiceTier = request.ServiceTier } r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } } + var logParts []any + if isNew { + logParts = append(logParts, "assigned credential ") + } else { + logParts = append(logParts, "credential ") + } + logParts = append(logParts, selectedCredential.tagName()) + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if isNew && username != "" { + logParts = append(logParts, " by user ", username) + } + if requestModel != "" { + logParts = append(logParts, ", model=", requestModel) + } + if requestServiceTier == "priority" { + logParts = append(logParts, ", fast") + } + s.logger.Debug(logParts...) + requestContext := selectedCredential.wrapRequestContext(r.Context()) defer func() { requestContext.cancelRequest() @@ -841,8 +858,16 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user if !userConfig.AllowExternalUsage && cred.isExternal() { continue } - totalFiveHour += cred.fiveHourUtilization() - totalWeekly += cred.weeklyUtilization() + scaledFiveHour := cred.fiveHourUtilization() / cred.fiveHourCap() * 100 + if scaledFiveHour > 100 { + scaledFiveHour = 100 + } + scaledWeekly := cred.weeklyUtilization() / cred.weeklyCap() * 100 + if scaledWeekly > 100 { + scaledWeekly = 100 + } + totalFiveHour += scaledFiveHour + totalWeekly += scaledWeekly count++ } if count == 0 { From ce543a935fa7d5ea8cbec6a28fbdb08458519fd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 23:29:06 +0800 Subject: [PATCH 25/96] ccm,ocm: fix reserveWeekly default and remove dead reserve fields --- service/ccm/credential_state.go | 6 +----- service/ocm/credential_state.go | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index d681c222bb..a8f7d2f62c 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -87,8 +87,6 @@ type defaultCredential struct { pollAccess sync.Mutex reloadAccess sync.Mutex watcherAccess sync.Mutex - reserve5h uint8 - reserveWeekly uint8 cap5h float64 capWeekly float64 usageTracker *AggregatedUsage @@ -182,7 +180,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef } reserveWeekly := options.ReserveWeekly if reserveWeekly == 0 { - reserveWeekly = 10 + reserveWeekly = 1 } var cap5h float64 if options.Limit5h > 0 { @@ -201,8 +199,6 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef tag: tag, serviceContext: ctx, credentialPath: options.CredentialPath, - reserve5h: reserve5h, - reserveWeekly: reserveWeekly, cap5h: cap5h, capWeekly: capWeekly, httpClient: httpClient, diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index b3564d9b03..8dacf40758 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -87,8 +87,6 @@ type defaultCredential struct { pollAccess sync.Mutex reloadAccess sync.Mutex watcherAccess sync.Mutex - reserve5h uint8 - reserveWeekly uint8 cap5h float64 capWeekly float64 usageTracker *AggregatedUsage @@ -190,7 +188,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef } reserveWeekly := options.ReserveWeekly if reserveWeekly == 0 { - reserveWeekly = 10 + reserveWeekly = 1 } var cap5h float64 if options.Limit5h > 0 { @@ -209,8 +207,6 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef tag: tag, serviceContext: ctx, credentialPath: options.CredentialPath, - reserve5h: reserve5h, - reserveWeekly: reserveWeekly, cap5h: cap5h, capWeekly: capWeekly, dialer: credentialDialer, From a09174a9a272c986d616c3e0f20b5d143689876b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 23:31:19 +0800 Subject: [PATCH 26/96] service/ccm: reject fast-mode external credentials --- service/ccm/service.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/service/ccm/service.go b/service/ccm/service.go index 6f1500fb8c..b41eb6da2a 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -464,6 +464,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { "extended context (1m) requests cannot be proxied through external credentials") return } + if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "fast mode requests cannot be proxied through external credentials") + return + } requestContext := selectedCredential.wrapRequestContext(r.Context()) defer func() { From ee65b375cba4e8ccf970b9b17c2ae1100ced3d9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 13:48:23 +0800 Subject: [PATCH 27/96] service/ccm: allow extended context (1m) for all credentials 1m context is now available to all subscribers and no longer consumes Extra Usage. --- service/ccm/service.go | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/service/ccm/service.go b/service/ccm/service.go index b41eb6da2a..85d333d4bc 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -415,14 +415,6 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { provider.pollIfStale(s.ctx) anthropicBetaHeader := r.Header.Get("anthropic-beta") - if isExtendedContextRequest(anthropicBetaHeader) { - if _, isSingle := provider.(*singleCredentialProvider); !isSingle { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "extended context (1m) requests will consume Extra usage, please use a default credential directly") - return - } - } - if isFastModeRequest(anthropicBetaHeader) { if _, isSingle := provider.(*singleCredentialProvider); !isSingle { writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", @@ -459,11 +451,6 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } s.logger.Debug(logParts...) - if isExtendedContextRequest(anthropicBetaHeader) && selectedCredential.isExternal() { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "extended context (1m) requests cannot be proxied through external credentials") - return - } if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() { writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "fast mode requests cannot be proxied through external credentials") From 162827250729d249be788de4ae0b8be519145ade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 14:05:10 +0800 Subject: [PATCH 28/96] ccm,ocm: mark credentials unusable on usage poll failure and trigger poll on upstream error --- service/ccm/credential_external.go | 41 ++++++++++++++++++++---------- service/ccm/credential_state.go | 37 ++++++++++++++++++--------- service/ccm/service.go | 1 + service/ocm/credential_external.go | 41 ++++++++++++++++++++---------- service/ocm/credential_state.go | 37 ++++++++++++++++++--------- service/ocm/service.go | 1 + 6 files changed, 108 insertions(+), 50 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 74ce6617ec..d6eb4c1026 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -239,6 +239,10 @@ func (c *externalCredential) isUsable() bool { return false } c.stateMutex.RLock() + if c.state.consecutivePollFailures > 0 { + c.stateMutex.RUnlock() + return false + } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { c.stateMutex.RUnlock() @@ -402,6 +406,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } } if hadData { + c.state.consecutivePollFailures = 0 c.state.lastUpdated = time.Now() } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { @@ -419,7 +424,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } func (c *externalCredential) checkTransitionLocked() bool { - unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 + unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0 if unusable && !c.interrupted { c.interrupted = true return true @@ -479,19 +484,24 @@ func (c *externalCredential) pollUsage(ctx context.Context) { }) if err != nil { c.logger.Error("poll usage for ", c.tag, ": ", err) - c.stateMutex.Lock() - c.state.consecutivePollFailures++ - c.stateMutex.Unlock() + c.incrementPollFailures() return } defer response.Body.Close() if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) - c.stateMutex.Lock() - c.state.consecutivePollFailures++ - c.stateMutex.Unlock() c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + // 404 means the remote does not have a status endpoint yet; + // usage will be updated passively from response headers. + if response.StatusCode == http.StatusNotFound { + c.stateMutex.Lock() + c.state.consecutivePollFailures = 0 + c.checkTransitionLocked() + c.stateMutex.Unlock() + } else { + c.incrementPollFailures() + } return } @@ -501,10 +511,8 @@ func (c *externalCredential) pollUsage(ctx context.Context) { } err = json.NewDecoder(response.Body).Decode(&statusResponse) if err != nil { - c.stateMutex.Lock() - c.state.consecutivePollFailures++ - c.stateMutex.Unlock() c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + c.incrementPollFailures() return } @@ -551,10 +559,17 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati if failures <= 0 { return baseInterval } - if failures > 4 { - failures = 4 + return failedPollRetryInterval +} + +func (c *externalCredential) incrementPollFailures() { + c.stateMutex.Lock() + c.state.consecutivePollFailures++ + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() } - return baseInterval * time.Duration(1< 0 { + c.stateMutex.RUnlock() + return false + } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { c.stateMutex.RUnlock() @@ -459,7 +467,7 @@ func (c *defaultCredential) checkReservesLocked() bool { // checkTransitionLocked detects usable→unusable transition. // Must be called with stateMutex write lock held. func (c *defaultCredential) checkTransitionLocked() bool { - unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() + unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0 if unusable && !c.interrupted { c.interrupted = true return true @@ -534,6 +542,16 @@ func (c *defaultCredential) markUsagePollAttempted() { c.state.lastUpdated = time.Now() } +func (c *defaultCredential) incrementPollFailures() { + c.stateMutex.Lock() + c.state.consecutivePollFailures++ + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { c.stateMutex.RLock() failures := c.state.consecutivePollFailures @@ -541,10 +559,7 @@ func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duratio if failures <= 0 { return baseInterval } - if failures > 4 { - failures = 4 - } - return baseInterval * time.Duration(1< 0 { + c.stateMutex.RUnlock() + return false + } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { c.stateMutex.RUnlock() @@ -439,6 +443,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } } if hadData { + c.state.consecutivePollFailures = 0 c.state.lastUpdated = time.Now() } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { @@ -456,7 +461,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } func (c *externalCredential) checkTransitionLocked() bool { - unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 + unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0 if unusable && !c.interrupted { c.interrupted = true return true @@ -516,19 +521,24 @@ func (c *externalCredential) pollUsage(ctx context.Context) { }) if err != nil { c.logger.Error("poll usage for ", c.tag, ": ", err) - c.stateMutex.Lock() - c.state.consecutivePollFailures++ - c.stateMutex.Unlock() + c.incrementPollFailures() return } defer response.Body.Close() if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) - c.stateMutex.Lock() - c.state.consecutivePollFailures++ - c.stateMutex.Unlock() c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + // 404 means the remote does not have a status endpoint yet; + // usage will be updated passively from response headers. + if response.StatusCode == http.StatusNotFound { + c.stateMutex.Lock() + c.state.consecutivePollFailures = 0 + c.checkTransitionLocked() + c.stateMutex.Unlock() + } else { + c.incrementPollFailures() + } return } @@ -538,10 +548,8 @@ func (c *externalCredential) pollUsage(ctx context.Context) { } err = json.NewDecoder(response.Body).Decode(&statusResponse) if err != nil { - c.stateMutex.Lock() - c.state.consecutivePollFailures++ - c.stateMutex.Unlock() c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + c.incrementPollFailures() return } @@ -588,10 +596,17 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati if failures <= 0 { return baseInterval } - if failures > 4 { - failures = 4 + return failedPollRetryInterval +} + +func (c *externalCredential) incrementPollFailures() { + c.stateMutex.Lock() + c.state.consecutivePollFailures++ + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() } - return baseInterval * time.Duration(1< 0 { + c.stateMutex.RUnlock() + return false + } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { c.stateMutex.RUnlock() @@ -476,7 +484,7 @@ func (c *defaultCredential) checkReservesLocked() bool { // checkTransitionLocked detects usable→unusable transition. // Must be called with stateMutex write lock held. func (c *defaultCredential) checkTransitionLocked() bool { - unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() + unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0 if unusable && !c.interrupted { c.interrupted = true return true @@ -551,6 +559,16 @@ func (c *defaultCredential) markUsagePollAttempted() { c.state.lastUpdated = time.Now() } +func (c *defaultCredential) incrementPollFailures() { + c.stateMutex.Lock() + c.state.consecutivePollFailures++ + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { c.stateMutex.RLock() failures := c.state.consecutivePollFailures @@ -558,10 +576,7 @@ func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duratio if failures <= 0 { return baseInterval } - if failures > 4 { - failures = 4 - } - return baseInterval * time.Duration(1< Date: Sat, 14 Mar 2026 14:14:34 +0800 Subject: [PATCH 29/96] service/ccm: only log new credential assignments and show context window in model --- service/ccm/service.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/service/ccm/service.go b/service/ccm/service.go index 7680083c42..4fd880f8ac 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -433,23 +433,23 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) return } - var logParts []any if isNew { - logParts = append(logParts, "assigned credential ") - } else { - logParts = append(logParts, "credential ") - } - logParts = append(logParts, selectedCredential.tagName()) - if sessionID != "" { - logParts = append(logParts, " for session ", sessionID) - } - if isNew && username != "" { - logParts = append(logParts, " by user ", username) - } - if requestModel != "" { - logParts = append(logParts, ", model=", requestModel) + logParts := []any{"assigned credential ", selectedCredential.tagName()} + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if username != "" { + logParts = append(logParts, " by user ", username) + } + if requestModel != "" { + modelDisplay := requestModel + if isExtendedContextRequest(anthropicBetaHeader) { + modelDisplay += "[1m]" + } + logParts = append(logParts, ", model=", modelDisplay) + } + s.logger.Debug(logParts...) } - s.logger.Debug(logParts...) if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() { writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", From 25a9e4ce5916f01ed2f6c5a2e78b6e284badd237 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 14:44:24 +0800 Subject: [PATCH 30/96] service/ocm: only log new credential assignments and add websocket logging --- service/ocm/service.go | 34 ++++++++++++++------------------ service/ocm/service_websocket.go | 23 +++++++++++++++++---- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/service/ocm/service.go b/service/ocm/service.go index db2467a284..9dded47406 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -437,7 +437,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { - s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter) + s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew) return } @@ -481,26 +481,22 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - var logParts []any if isNew { - logParts = append(logParts, "assigned credential ") - } else { - logParts = append(logParts, "credential ") - } - logParts = append(logParts, selectedCredential.tagName()) - if sessionID != "" { - logParts = append(logParts, " for session ", sessionID) - } - if isNew && username != "" { - logParts = append(logParts, " by user ", username) - } - if requestModel != "" { - logParts = append(logParts, ", model=", requestModel) - } - if requestServiceTier == "priority" { - logParts = append(logParts, ", fast") + logParts := []any{"assigned credential ", selectedCredential.tagName()} + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if username != "" { + logParts = append(logParts, " by user ", username) + } + if requestModel != "" { + logParts = append(logParts, ", model=", requestModel) + } + if requestServiceTier == "priority" { + logParts = append(logParts, ", fast") + } + s.logger.Debug(logParts...) } - s.logger.Debug(logParts...) requestContext := selectedCredential.wrapRequestContext(r.Context()) defer func() { diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 8cbe3056c9..f348f7fa44 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -91,7 +91,19 @@ func (s *Service) handleWebSocket( provider credentialProvider, selectedCredential credential, credentialFilter func(credential) bool, + isNew bool, ) { + if isNew { + logParts := []any{"assigned credential ", selectedCredential.tagName()} + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if username != "" { + logParts = append(logParts, " by user ", username) + } + s.logger.Debug(logParts...) + } + var ( err error upstreamConn net.Conn @@ -264,15 +276,18 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo return } - if opCode == ws.OpText && selectedCredential.usageTrackerOrNil() != nil { + if opCode == ws.OpText { var request struct { Type string `json:"type"` Model string `json:"model"` } if json.Unmarshal(data, &request) == nil && request.Type == "response.create" && request.Model != "" { - select { - case modelChannel <- request.Model: - default: + s.logger.Debug("model=", request.Model) + if selectedCredential.usageTrackerOrNil() != nil { + select { + case modelChannel <- request.Model: + default: + } } } } From 8984b45dedf9d0d7813489938bbd4f371ca2ee5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 15:10:13 +0800 Subject: [PATCH 31/96] ccm,ocm: improve balancer least_used with plan-weighted scoring and reset urgency Scale remaining capacity by plan weight (Pro=1, Max 5x=5, Max 20x=10 for CCM; Plus=1, Pro=10 for OCM) so higher-tier accounts contribute proportionally more. Factor in weekly reset proximity so credentials about to reset are preferred ("use it or lose it"). Auto-detect plan weight from subscriptionType + rateLimitTier (CCM) or plan_type (OCM). Fetch /api/oauth/profile when rateLimitTier is missing from the credential file. External credentials accept a manual plan_weight option. --- option/ccm.go | 1 + option/ocm.go | 1 + service/ccm/credential.go | 2 + service/ccm/credential_external.go | 55 +++++++++----- service/ccm/credential_file.go | 2 + service/ccm/credential_state.go | 117 ++++++++++++++++++++++++++++- service/ocm/credential_external.go | 57 +++++++++----- service/ocm/credential_state.go | 43 ++++++++++- 8 files changed, 233 insertions(+), 45 deletions(-) diff --git a/option/ccm.go b/option/ccm.go index dd55a4ba4e..b4be72ea76 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -102,6 +102,7 @@ type CCMExternalCredentialOptions struct { Token string `json:"token"` Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` + PlanWeight float64 `json:"plan_weight,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/option/ocm.go b/option/ocm.go index e508abae7e..0f364821f9 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -102,6 +102,7 @@ type OCMExternalCredentialOptions struct { Token string `json:"token"` Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` + PlanWeight float64 `json:"plan_weight,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 75ae62f972..8bfd27c23d 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -133,6 +133,7 @@ type oauthCredentials struct { ExpiresAt int64 `json:"expiresAt"` Scopes []string `json:"scopes,omitempty"` SubscriptionType string `json:"subscriptionType,omitempty"` + RateLimitTier string `json:"rateLimitTier,omitempty"` IsMax bool `json:"isMax,omitempty"` } @@ -219,5 +220,6 @@ func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { left.ExpiresAt == right.ExpiresAt && slices.Equal(left.Scopes, right.Scopes) && left.SubscriptionType == right.SubscriptionType && + left.RateLimitTier == right.RateLimitTier && left.IsMax == right.IsMax } diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index d6eb4c1026..807a06fe8d 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -29,16 +29,17 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + httpClient *http.Client + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + configuredPlanWeight float64 + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -112,16 +113,22 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) + configuredPlanWeight := options.PlanWeight + if configuredPlanWeight <= 0 { + configuredPlanWeight = 1 + } + cred := &externalCredential{ - tag: tag, - token: options.Token, - pollInterval: pollInterval, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - reverse: options.Reverse, - reverseContext: reverseContext, - reverseCancel: reverseCancel, + tag: tag, + token: options.Token, + pollInterval: pollInterval, + configuredPlanWeight: configuredPlanWeight, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, } if options.URL == "" { @@ -283,6 +290,16 @@ func (c *externalCredential) weeklyCap() float64 { return 100 } +func (c *externalCredential) planWeight() float64 { + return c.configuredPlanWeight +} + +func (c *externalCredential) weeklyResetTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyReset +} + func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) c.stateMutex.Lock() diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index da13fae10e..eba9207268 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -114,6 +114,7 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.state.unavailable = false c.state.lastCredentialLoadError = "" c.state.accountType = credentials.SubscriptionType + c.state.rateLimitTier = credentials.RateLimitTier c.checkTransitionLocked() c.stateMutex.Unlock() @@ -130,6 +131,7 @@ func (c *defaultCredential) markCredentialsUnavailable(err error) error { c.state.unavailable = true c.state.lastCredentialLoadError = err.Error() c.state.accountType = "" + c.state.rateLimitTier = "" shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 5eacd19ee3..87c9afde2e 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -71,6 +71,7 @@ type credentialState struct { hardRateLimited bool rateLimitResetAt time.Time accountType string + rateLimitTier string lastUpdated time.Time consecutivePollFailures int unavailable bool @@ -134,6 +135,8 @@ type credential interface { weeklyUtilization() float64 fiveHourCap() float64 weeklyCap() float64 + planWeight() float64 + weeklyResetTime() time.Time markRateLimited(resetAt time.Time) earliestReset() time.Time unavailableError() error @@ -294,6 +297,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" c.state.accountType = latestCredentials.SubscriptionType + c.state.rateLimitTier = latestCredentials.RateLimitTier c.checkTransitionLocked() c.stateMutex.Unlock() if !latestCredentials.needsRefresh() { @@ -308,6 +312,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" c.state.accountType = newCredentials.SubscriptionType + c.state.rateLimitTier = newCredentials.RateLimitTier c.checkTransitionLocked() c.stateMutex.Unlock() @@ -510,6 +515,18 @@ func (c *defaultCredential) weeklyUtilization() float64 { return c.state.weeklyUtilization } +func (c *defaultCredential) planWeight() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) +} + +func (c *defaultCredential) weeklyResetTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyReset +} + func (c *defaultCredential) isAvailable() bool { c.retryCredentialReloadIfNeeded() @@ -670,11 +687,72 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { } c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } + needsProfileFetch := c.state.rateLimitTier == "" shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() if shouldInterrupt { c.interruptConnections() } + + if needsProfileFetch { + c.fetchProfile(ctx, httpClient, accessToken) + } +} + +func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.Client, accessToken string) { + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+accessToken) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + return request, nil + }) + if err != nil { + c.logger.Debug("fetch profile for ", c.tag, ": ", err) + return + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return + } + + var profileResponse struct { + Organization *struct { + OrganizationType string `json:"organization_type"` + RateLimitTier string `json:"rate_limit_tier"` + } `json:"organization"` + } + err = json.NewDecoder(response.Body).Decode(&profileResponse) + if err != nil || profileResponse.Organization == nil { + return + } + + accountType := "" + switch profileResponse.Organization.OrganizationType { + case "claude_pro": + accountType = "pro" + case "claude_max": + accountType = "max" + case "claude_team": + accountType = "team" + case "claude_enterprise": + accountType = "enterprise" + } + rateLimitTier := profileResponse.Organization.RateLimitTier + + c.stateMutex.Lock() + if accountType != "" && c.state.accountType == "" { + c.state.accountType = accountType + } + if rateLimitTier != "" { + c.state.rateLimitTier = rateLimitTier + } + c.stateMutex.Unlock() + c.logger.Info("fetched profile for ", c.tag, ": type=", c.state.accountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(c.state.accountType, rateLimitTier)) } func (c *defaultCredential) close() { @@ -928,7 +1006,8 @@ func (p *balancerProvider) pickCredential(filter func(credential) bool) credenti func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { var best credential - bestRemaining := float64(-1) + bestScore := float64(-1) + now := time.Now() for _, cred := range p.credentials { if filter != nil && !filter(cred) { continue @@ -937,14 +1016,46 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia continue } remaining := cred.weeklyCap() - cred.weeklyUtilization() - if remaining > bestRemaining { - bestRemaining = remaining + score := remaining * cred.planWeight() + resetTime := cred.weeklyResetTime() + if !resetTime.IsZero() { + timeUntilReset := resetTime.Sub(now) + if timeUntilReset < time.Hour { + timeUntilReset = time.Hour + } + score *= weeklyWindowDuration / timeUntilReset.Hours() + } + if score > bestScore { + bestScore = score best = cred } } return best } +const weeklyWindowDuration = 7 * 24 // hours + +func ccmPlanWeight(accountType string, rateLimitTier string) float64 { + switch accountType { + case "max": + switch rateLimitTier { + case "default_claude_max_20x": + return 10 + case "default_claude_max_5x": + return 5 + default: + return 5 + } + case "team": + if rateLimitTier == "default_claude_max_5x" { + return 5 + } + return 1 + default: + return 1 + } +} + func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { start := int(p.roundRobinIndex.Add(1) - 1) count := len(p.credentials) diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 52c1e7210c..2c9dce46bd 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -30,17 +30,18 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - credDialer N.Dialer - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + credDialer N.Dialer + httpClient *http.Client + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + configuredPlanWeight float64 + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -129,16 +130,22 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) + configuredPlanWeight := options.PlanWeight + if configuredPlanWeight <= 0 { + configuredPlanWeight = 1 + } + cred := &externalCredential{ - tag: tag, - token: options.Token, - pollInterval: pollInterval, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - reverse: options.Reverse, - reverseContext: reverseContext, - reverseCancel: reverseCancel, + tag: tag, + token: options.Token, + pollInterval: pollInterval, + configuredPlanWeight: configuredPlanWeight, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, } if options.URL == "" { @@ -305,6 +312,16 @@ func (c *externalCredential) weeklyCap() float64 { return 100 } +func (c *externalCredential) planWeight() float64 { + return c.configuredPlanWeight +} + +func (c *externalCredential) weeklyResetTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyReset +} + func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) c.stateMutex.Lock() diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index db06d05a92..3cb1f48b94 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -135,6 +135,8 @@ type credential interface { weeklyUtilization() float64 fiveHourCap() float64 weeklyCap() float64 + planWeight() float64 + weeklyResetTime() time.Time markRateLimited(resetAt time.Time) earliestReset() time.Time unavailableError() error @@ -527,6 +529,18 @@ func (c *defaultCredential) weeklyUtilization() float64 { return c.state.weeklyUtilization } +func (c *defaultCredential) planWeight() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return ocmPlanWeight(c.state.accountType) +} + +func (c *defaultCredential) weeklyResetTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyReset +} + func (c *defaultCredential) isAvailable() bool { c.retryCredentialReloadIfNeeded() @@ -991,7 +1005,8 @@ func (p *balancerProvider) pickCredential(filter func(credential) bool) credenti func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { var best credential - bestRemaining := float64(-1) + bestScore := float64(-1) + now := time.Now() for _, cred := range p.credentials { if filter != nil && !filter(cred) { continue @@ -1003,14 +1018,36 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia continue } remaining := cred.weeklyCap() - cred.weeklyUtilization() - if remaining > bestRemaining { - bestRemaining = remaining + score := remaining * cred.planWeight() + resetTime := cred.weeklyResetTime() + if !resetTime.IsZero() { + timeUntilReset := resetTime.Sub(now) + if timeUntilReset < time.Hour { + timeUntilReset = time.Hour + } + score *= weeklyWindowDuration / timeUntilReset.Hours() + } + if score > bestScore { + bestScore = score best = cred } } return best } +const weeklyWindowDuration = 7 * 24 // hours + +func ocmPlanWeight(accountType string) float64 { + switch accountType { + case "pro": + return 10 + case "plus": + return 1 + default: + return 1 + } +} + func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { start := int(p.roundRobinIndex.Add(1) - 1) count := len(p.credentials) From 80d5432654d1401174de3b08c6f342389fb224f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 15:19:59 +0800 Subject: [PATCH 32/96] service/ccm: update oauth token URL and remove unnecessary Accept header --- service/ccm/credential.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 8bfd27c23d..da559c173d 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -20,7 +20,7 @@ import ( const ( oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" - oauth2TokenURL = "https://console.anthropic.com/v1/oauth/token" + oauth2TokenURL = "https://platform.claude.com/v1/oauth/token" claudeAPIBaseURL = "https://api.anthropic.com" tokenRefreshBufferMs = 60000 anthropicBetaOAuthValue = "oauth-2025-04-20" @@ -164,7 +164,6 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau return nil, err } request.Header.Set("Content-Type", "application/json") - request.Header.Set("Accept", "application/json") request.Header.Set("User-Agent", ccmUserAgentValue) return request, nil }) From 6f433937bae10b7b59c256ad6c088f9432088434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 15:32:30 +0800 Subject: [PATCH 33/96] ccm,ocm: auto-detect plan weight for external credentials via status endpoint --- option/ccm.go | 1 - option/ocm.go | 1 - service/ccm/credential_external.go | 62 ++++++++++++++++------------- service/ccm/credential_state.go | 1 + service/ccm/service.go | 42 +++++++++++--------- service/ocm/credential_external.go | 64 +++++++++++++++++------------- service/ocm/credential_state.go | 1 + service/ocm/service.go | 40 +++++++++++-------- service/ocm/service_websocket.go | 14 +++++-- 9 files changed, 130 insertions(+), 96 deletions(-) diff --git a/option/ccm.go b/option/ccm.go index b4be72ea76..dd55a4ba4e 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -102,7 +102,6 @@ type CCMExternalCredentialOptions struct { Token string `json:"token"` Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` - PlanWeight float64 `json:"plan_weight,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/option/ocm.go b/option/ocm.go index 0f364821f9..e508abae7e 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -102,7 +102,6 @@ type OCMExternalCredentialOptions struct { Token string `json:"token"` Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` - PlanWeight float64 `json:"plan_weight,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 807a06fe8d..a5781d6f6d 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -29,17 +29,16 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - configuredPlanWeight float64 - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + httpClient *http.Client + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -113,22 +112,16 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) - configuredPlanWeight := options.PlanWeight - if configuredPlanWeight <= 0 { - configuredPlanWeight = 1 - } - cred := &externalCredential{ - tag: tag, - token: options.Token, - pollInterval: pollInterval, - configuredPlanWeight: configuredPlanWeight, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - reverse: options.Reverse, - reverseContext: reverseContext, - reverseCancel: reverseCancel, + tag: tag, + token: options.Token, + pollInterval: pollInterval, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, } if options.URL == "" { @@ -291,7 +284,12 @@ func (c *externalCredential) weeklyCap() float64 { } func (c *externalCredential) planWeight() float64 { - return c.configuredPlanWeight + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + if c.state.remotePlanWeight > 0 { + return c.state.remotePlanWeight + } + return 10 } func (c *externalCredential) weeklyResetTime() time.Time { @@ -422,6 +420,12 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.weeklyUtilization = value * 100 } } + if planWeight := headers.Get("X-CCM-Plan-Weight"); planWeight != "" { + value, err := strconv.ParseFloat(planWeight, 64) + if err == nil && value > 0 { + c.state.remotePlanWeight = value + } + } if hadData { c.state.consecutivePollFailures = 0 c.state.lastUpdated = time.Now() @@ -525,6 +529,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { var statusResponse struct { FiveHourUtilization float64 `json:"five_hour_utilization"` WeeklyUtilization float64 `json:"weekly_utilization"` + PlanWeight float64 `json:"plan_weight"` } err = json.NewDecoder(response.Body).Decode(&statusResponse) if err != nil { @@ -540,6 +545,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.state.consecutivePollFailures = 0 c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.PlanWeight > 0 { + c.state.remotePlanWeight = statusResponse.PlanWeight + } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 87c9afde2e..83db5a1977 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -72,6 +72,7 @@ type credentialState struct { rateLimitResetAt time.Time accountType string rateLimitTier string + remotePlanWeight float64 lastUpdated time.Time consecutivePollFailures int unavailable bool diff --git a/service/ccm/service.go b/service/ccm/service.go index 4fd880f8ac..f940d7309f 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -766,47 +766,48 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]float64{ "five_hour_utilization": avgFiveHour, "weekly_utilization": avgWeekly, + "plan_weight": totalWeight, }) } -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64) { - var totalFiveHour, totalWeekly float64 - var count int +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) { + var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 for _, cred := range provider.allCredentials() { if !cred.isAvailable() { continue } - // Exclude the user's own external_credential (their contribution to us) if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { continue } - // If user doesn't allow external usage, exclude all external credentials if !userConfig.AllowExternalUsage && cred.isExternal() { continue } - scaledFiveHour := cred.fiveHourUtilization() / cred.fiveHourCap() * 100 - if scaledFiveHour > 100 { - scaledFiveHour = 100 + weight := cred.planWeight() + remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() + if remaining5h < 0 { + remaining5h = 0 } - scaledWeekly := cred.weeklyUtilization() / cred.weeklyCap() * 100 - if scaledWeekly > 100 { - scaledWeekly = 100 + remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() + if remainingWeekly < 0 { + remainingWeekly = 0 } - totalFiveHour += scaledFiveHour - totalWeekly += scaledWeekly - count++ + totalWeightedRemaining5h += remaining5h * weight + totalWeightedRemainingWeekly += remainingWeekly * weight + totalWeight += weight } - if count == 0 { - return 100, 100 + if totalWeight == 0 { + return 100, 100, 0 } - return totalFiveHour / float64(count), totalWeekly / float64(count) + return 100 - totalWeightedRemaining5h/totalWeight, + 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight } func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) { @@ -815,11 +816,14 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use return } - avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) // Rewrite utilization headers to aggregated average (convert back to 0.0-1.0 range) headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64)) headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) + if totalWeight > 0 { + headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) + } } func (s *Service) InterfaceUpdated() { diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 2c9dce46bd..2a09c84ad6 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -30,18 +30,17 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - credDialer N.Dialer - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - configuredPlanWeight float64 - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + credDialer N.Dialer + httpClient *http.Client + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -130,22 +129,16 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) - configuredPlanWeight := options.PlanWeight - if configuredPlanWeight <= 0 { - configuredPlanWeight = 1 - } - cred := &externalCredential{ - tag: tag, - token: options.Token, - pollInterval: pollInterval, - configuredPlanWeight: configuredPlanWeight, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - reverse: options.Reverse, - reverseContext: reverseContext, - reverseCancel: reverseCancel, + tag: tag, + token: options.Token, + pollInterval: pollInterval, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, } if options.URL == "" { @@ -313,7 +306,12 @@ func (c *externalCredential) weeklyCap() float64 { } func (c *externalCredential) planWeight() float64 { - return c.configuredPlanWeight + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + if c.state.remotePlanWeight > 0 { + return c.state.remotePlanWeight + } + return 10 } func (c *externalCredential) weeklyResetTime() time.Time { @@ -459,6 +457,12 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.weeklyUtilization = value } } + if planWeight := headers.Get("X-OCM-Plan-Weight"); planWeight != "" { + value, err := strconv.ParseFloat(planWeight, 64) + if err == nil && value > 0 { + c.state.remotePlanWeight = value + } + } if hadData { c.state.consecutivePollFailures = 0 c.state.lastUpdated = time.Now() @@ -562,6 +566,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { var statusResponse struct { FiveHourUtilization float64 `json:"five_hour_utilization"` WeeklyUtilization float64 `json:"weekly_utilization"` + PlanWeight float64 `json:"plan_weight"` } err = json.NewDecoder(response.Body).Decode(&statusResponse) if err != nil { @@ -577,6 +582,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.state.consecutivePollFailures = 0 c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.PlanWeight > 0 { + c.state.remotePlanWeight = statusResponse.PlanWeight + } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 3cb1f48b94..ea7c621bf1 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -71,6 +71,7 @@ type credentialState struct { hardRateLimited bool rateLimitResetAt time.Time accountType string + remotePlanWeight float64 lastUpdated time.Time consecutivePollFailures int unavailable bool diff --git a/service/ocm/service.go b/service/ocm/service.go index 9dded47406..1f93b4d5c0 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -832,19 +832,19 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]float64{ "five_hour_utilization": avgFiveHour, "weekly_utilization": avgWeekly, + "plan_weight": totalWeight, }) } -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64) { - var totalFiveHour, totalWeekly float64 - var count int +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) { + var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 for _, cred := range provider.allCredentials() { if !cred.isAvailable() { continue @@ -855,22 +855,25 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user if !userConfig.AllowExternalUsage && cred.isExternal() { continue } - scaledFiveHour := cred.fiveHourUtilization() / cred.fiveHourCap() * 100 - if scaledFiveHour > 100 { - scaledFiveHour = 100 + weight := cred.planWeight() + remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() + if remaining5h < 0 { + remaining5h = 0 } - scaledWeekly := cred.weeklyUtilization() / cred.weeklyCap() * 100 - if scaledWeekly > 100 { - scaledWeekly = 100 + remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() + if remainingWeekly < 0 { + remainingWeekly = 0 } - totalFiveHour += scaledFiveHour - totalWeekly += scaledWeekly - count++ + totalWeightedRemaining5h += remaining5h * weight + totalWeightedRemainingWeekly += remainingWeekly * weight + totalWeight += weight } - if count == 0 { - return 100, 100 + if totalWeight == 0 { + return 100, 100, 0 } - return totalFiveHour / float64(count), totalWeekly / float64(count) + return 100 - totalWeightedRemaining5h/totalWeight, + 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight } func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) { @@ -879,7 +882,7 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use return } - avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) if activeLimitIdentifier == "" { @@ -888,6 +891,9 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64)) headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) + if totalWeight > 0 { + headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) + } } func (s *Service) InterfaceUpdated() { diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index f348f7fa44..2f4911959c 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -368,8 +368,9 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential ResetAt int64 `json:"reset_at"` } `json:"secondary"` } `json:"rate_limits"` - LimitName string `json:"limit_name"` - MeteredLimitName string `json:"metered_limit_name"` + LimitName string `json:"limit_name"` + MeteredLimitName string `json:"metered_limit_name"` + PlanWeight float64 `json:"plan_weight"` } err := json.Unmarshal(data, &rateLimitsEvent) if err != nil { @@ -398,6 +399,9 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential headers.Set("x-"+identifier+"-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10)) } } + if rateLimitsEvent.PlanWeight > 0 { + headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(rateLimitsEvent.PlanWeight, 'f', -1, 64)) + } selectedCredential.updateStateFromHeaders(headers) } @@ -436,7 +440,11 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide return nil, err } - averageFiveHour, averageWeekly := s.computeAggregatedUtilization(provider, userConfig) + averageFiveHour, averageWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + + if totalWeight > 0 { + event["plan_weight"], _ = json.Marshal(totalWeight) + } primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], averageFiveHour) if err != nil { From 8fe8e238b37a76d782f3976aca2eacb1ce6fba4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 15:52:27 +0800 Subject: [PATCH 34/96] service/ocm: unify websocket logging with HTTP request logging --- service/ocm/service_websocket.go | 37 ++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 2f4911959c..365472f595 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -93,17 +93,6 @@ func (s *Service) handleWebSocket( credentialFilter func(credential) bool, isNew bool, ) { - if isNew { - logParts := []any{"assigned credential ", selectedCredential.tagName()} - if sessionID != "" { - logParts = append(logParts, " for session ", sessionID) - } - if username != "" { - logParts = append(logParts, " by user ", username) - } - s.logger.Debug(logParts...) - } - var ( err error upstreamConn net.Conn @@ -256,7 +245,7 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel) + s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID) }() go func() { defer waitGroup.Done() @@ -266,7 +255,8 @@ func (s *Service) handleWebSocket( waitGroup.Wait() } -func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string) { +func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) { + logged := false for { data, opCode, err := wsutil.ReadClientData(clientConn) if err != nil { @@ -278,11 +268,26 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo if opCode == ws.OpText { var request struct { - Type string `json:"type"` - Model string `json:"model"` + Type string `json:"type"` + Model string `json:"model"` + ServiceTier string `json:"service_tier"` } if json.Unmarshal(data, &request) == nil && request.Type == "response.create" && request.Model != "" { - s.logger.Debug("model=", request.Model) + if isNew && !logged { + logged = true + logParts := []any{"assigned credential ", selectedCredential.tagName()} + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if username != "" { + logParts = append(logParts, " by user ", username) + } + logParts = append(logParts, ", model=", request.Model) + if request.ServiceTier == "priority" { + logParts = append(logParts, ", fast") + } + s.logger.Debug(logParts...) + } if selectedCredential.usageTrackerOrNil() != nil { select { case modelChannel <- request.Model: From f4aaf33bf295d6b9d5a2a682be5542baa84c828e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 15:58:33 +0800 Subject: [PATCH 35/96] ccm,ocm: strip reverse proxy headers from upstream responses --- service/ccm/service.go | 2 +- service/ocm/service.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/service/ccm/service.go b/service/ccm/service.go index f940d7309f..e8aaebdc4c 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -538,7 +538,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } for key, values := range response.Header { - if !isHopByHopHeader(key) { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { w.Header()[key] = values } } diff --git a/service/ocm/service.go b/service/ocm/service.go index 1f93b4d5c0..858595b1fb 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -580,7 +580,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } for key, values := range response.Header { - if !isHopByHopHeader(key) { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { w.Header()[key] = values } } From badeeb91fe2934c3f3451e5b39632381b966a17b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 16:07:58 +0800 Subject: [PATCH 36/96] service/ocm: add default OpenAI-Beta header and log websocket error body The upstream OpenAI WebSocket endpoint requires the OpenAI-Beta: responses_websockets=2026-02-06 header. Set it automatically when the client doesn't provide it. Also capture and log the response body on non-429 WebSocket handshake failures to surface the actual error from upstream. --- service/ocm/service_websocket.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 365472f595..21b25bafcb 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -99,6 +99,7 @@ func (s *Service) handleWebSocket( upstreamBufferedReader *bufio.Reader upstreamResponseHeaders http.Header statusCode int + statusResponseBody string ) for { @@ -135,9 +136,13 @@ func (s *Service) handleWebSocket( if accountID := selectedCredential.ocmGetAccountID(); accountID != "" { upstreamHeaders.Set("ChatGPT-Account-Id", accountID) } + if upstreamHeaders.Get("OpenAI-Beta") == "" { + upstreamHeaders.Set("OpenAI-Beta", "responses_websockets=2026-02-06") + } upstreamResponseHeaders = make(http.Header) statusCode = 0 + statusResponseBody = "" upstreamDialer := ws.Dialer{ NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { return selectedCredential.ocmDialer().DialContext(ctx, network, M.ParseSocksaddr(addr)) @@ -162,6 +167,10 @@ func (s *Service) handleWebSocket( if readErr == nil { upstreamResponseHeaders = http.Header(mimeHeader) } + body, readErr := io.ReadAll(io.LimitReader(bufferedResponse, 4096)) + if readErr == nil && len(body) > 0 { + statusResponseBody = string(body) + } }, OnHeader: func(key, value []byte) error { upstreamResponseHeaders.Add(string(key), string(value)) @@ -185,7 +194,11 @@ func (s *Service) handleWebSocket( selectedCredential = nextCredential continue } - s.logger.Error("dial upstream websocket: ", err) + if statusCode > 0 && statusResponseBody != "" { + s.logger.Error("dial upstream websocket: status ", statusCode, " body: ", statusResponseBody) + } else { + s.logger.Error("dial upstream websocket: ", err) + } writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed") return } From b97b9d9cfd67debe9979ac216c1867a41ffca242 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 16:14:24 +0800 Subject: [PATCH 37/96] ccm,ocm: add request ID context to HTTP request logging --- service/ccm/reverse.go | 18 +++++++------- service/ccm/service.go | 39 +++++++++++++++--------------- service/ocm/reverse.go | 18 +++++++------- service/ocm/service.go | 41 ++++++++++++++++---------------- service/ocm/service_websocket.go | 29 +++++++++++----------- 5 files changed, 74 insertions(+), 71 deletions(-) diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index 625e55a9dc..62a1011171 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr { return l.session.Addr() } -func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { +func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) { if r.Header.Get("Upgrade") != "reverse-proxy" { writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header") return @@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { receiverCredential := s.findReceiverCredential(clientToken) if receiverCredential == nil { - s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") + s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token") return } hijacker, ok := w.(http.Hijacker) if !ok { - s.logger.Error("reverse connect: hijack not supported") + s.logger.ErrorContext(ctx, "reverse connect: hijack not supported") writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported") return } conn, bufferedReadWriter, err := hijacker.Hijack() if err != nil { - s.logger.Error("reverse connect: hijack: ", err) + s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err) return } @@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { _, err = bufferedReadWriter.WriteString(response) if err != nil { conn.Close() - s.logger.Error("reverse connect: write upgrade response: ", err) + s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err) return } err = bufferedReadWriter.Flush() if err != nil { conn.Close() - s.logger.Error("reverse connect: flush upgrade response: ", err) + s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err) return } session, err := yamux.Client(conn, reverseYamuxConfig()) if err != nil { conn.Close() - s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) + s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) return } @@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { session.Close() return } - s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) + s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) go func() { <-session.CloseChan() receiverCredential.clearReverseSession(session) - s.logger.Warn("reverse connection lost for ", receiverCredential.tagName()) + s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName()) }() } diff --git a/service/ccm/service.go b/service/ccm/service.go index e8aaebdc4c..81e3b38a5c 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -325,13 +325,14 @@ func detectContextWindow(betaHeader string, totalInputTokens int64) int { } func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := log.ContextWithNewID(r.Context()) if r.URL.Path == "/ccm/v1/status" { s.handleStatusEndpoint(w, r) return } if r.URL.Path == "/ccm/v1/reverse" { - s.handleReverseConnect(w, r) + s.handleReverseConnect(ctx, w, r) return } @@ -344,20 +345,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if len(s.options.Users) > 0 { authHeader := r.Header.Get("Authorization") if authHeader == "" { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") return } clientToken := strings.TrimPrefix(authHeader, "Bearer ") if clientToken == authHeader { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") return } var ok bool username, ok = s.userManager.Authenticate(clientToken) if !ok { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") return } @@ -373,7 +374,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { var err error bodyBytes, err = io.ReadAll(r.Body) if err != nil { - s.logger.Error("read request body: ", err) + s.logger.ErrorContext(ctx, "read request body: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") return } @@ -400,7 +401,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { var err error provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) if err != nil { - s.logger.Error("resolve credential: ", err) + s.logger.ErrorContext(ctx, "resolve credential: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) return } @@ -448,7 +449,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } logParts = append(logParts, ", model=", modelDisplay) } - s.logger.Debug(logParts...) + s.logger.DebugContext(ctx, logParts...) } if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() { @@ -463,7 +464,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if err != nil { - s.logger.Error("create proxy request: ", err) + s.logger.ErrorContext(ctx, "create proxy request: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") return } @@ -493,12 +494,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } response.Body.Close() - s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(r.Context()) retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { - s.logger.Error("retry request: ", buildErr) + s.logger.ErrorContext(ctx, "retry request: ", buildErr) writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) return } @@ -511,7 +512,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") return } - s.logger.Error("retry request: ", retryErr) + s.logger.ErrorContext(ctx, "retry request: ", retryErr) writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) return } @@ -525,7 +526,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) - s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) + s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) go selectedCredential.pollUsage(s.ctx) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) @@ -546,7 +547,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { usageTracker := selectedCredential.usageTrackerOrNil() if usageTracker != nil && response.StatusCode == http.StatusOK { - s.handleResponseWithTracking(w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) + s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) } else { mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) if err == nil && mediaType != "text/event-stream" { @@ -555,7 +556,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } flusher, ok := w.(http.Flusher) if !ok { - s.logger.Error("streaming not supported") + s.logger.ErrorContext(ctx, "streaming not supported") return } buffer := make([]byte, buf.BufferSize) @@ -564,7 +565,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if n > 0 { _, writeError := w.Write(buffer[:n]) if writeError != nil { - s.logger.Error("write streaming response: ", writeError) + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } flusher.Flush() @@ -576,7 +577,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { +func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { weeklyCycleHint := extractWeeklyCycleHint(response.Header) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) isStreaming := err == nil && mediaType == "text/event-stream" @@ -584,7 +585,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if !isStreaming { bodyBytes, err := io.ReadAll(response.Body) if err != nil { - s.logger.Error("read response body: ", err) + s.logger.ErrorContext(ctx, "read response body: ", err) return } @@ -627,7 +628,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons flusher, ok := writer.(http.Flusher) if !ok { - s.logger.Error("streaming not supported") + s.logger.ErrorContext(ctx, "streaming not supported") return } @@ -690,7 +691,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons _, writeError := writer.Write(buffer[:n]) if writeError != nil { - s.logger.Error("write streaming response: ", writeError) + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } flusher.Flush() diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index 906778df58..1ed274f6d2 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr { return l.session.Addr() } -func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { +func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) { if r.Header.Get("Upgrade") != "reverse-proxy" { writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header") return @@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { receiverCredential := s.findReceiverCredential(clientToken) if receiverCredential == nil { - s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") + s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token") return } hijacker, ok := w.(http.Hijacker) if !ok { - s.logger.Error("reverse connect: hijack not supported") + s.logger.ErrorContext(ctx, "reverse connect: hijack not supported") writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported") return } conn, bufferedReadWriter, err := hijacker.Hijack() if err != nil { - s.logger.Error("reverse connect: hijack: ", err) + s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err) return } @@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { _, err = bufferedReadWriter.WriteString(response) if err != nil { conn.Close() - s.logger.Error("reverse connect: write upgrade response: ", err) + s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err) return } err = bufferedReadWriter.Flush() if err != nil { conn.Close() - s.logger.Error("reverse connect: flush upgrade response: ", err) + s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err) return } session, err := yamux.Client(conn, reverseYamuxConfig()) if err != nil { conn.Close() - s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) + s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) return } @@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { session.Close() return } - s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) + s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) go func() { <-session.CloseChan() receiverCredential.clearReverseSession(session) - s.logger.Warn("reverse connection lost for ", receiverCredential.tagName()) + s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName()) }() } diff --git a/service/ocm/service.go b/service/ocm/service.go index 858595b1fb..94a98c665e 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -362,13 +362,14 @@ func (s *Service) resolveCredentialProvider(username string) (credentialProvider } func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := log.ContextWithNewID(r.Context()) if r.URL.Path == "/ocm/v1/status" { s.handleStatusEndpoint(w, r) return } if r.URL.Path == "/ocm/v1/reverse" { - s.handleReverseConnect(w, r) + s.handleReverseConnect(ctx, w, r) return } @@ -382,20 +383,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if len(s.options.Users) > 0 { authHeader := r.Header.Get("Authorization") if authHeader == "" { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") return } clientToken := strings.TrimPrefix(authHeader, "Bearer ") if clientToken == authHeader { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") return } var ok bool username, ok = s.userManager.Authenticate(clientToken) if !ok { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") return } @@ -411,7 +412,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { var err error provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) if err != nil { - s.logger.Error("resolve credential: ", err) + s.logger.ErrorContext(ctx, "resolve credential: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) return } @@ -437,7 +438,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { - s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew) + s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew) return } @@ -465,7 +466,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if isJSONRequest { bodyBytes, err = io.ReadAll(r.Body) if err != nil { - s.logger.Error("read request body: ", err) + s.logger.ErrorContext(ctx, "read request body: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") return } @@ -495,7 +496,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if requestServiceTier == "priority" { logParts = append(logParts, ", fast") } - s.logger.Debug(logParts...) + s.logger.DebugContext(ctx, logParts...) } requestContext := selectedCredential.wrapRequestContext(r.Context()) @@ -504,7 +505,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if err != nil { - s.logger.Error("create proxy request: ", err) + s.logger.ErrorContext(ctx, "create proxy request: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") return } @@ -535,12 +536,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } response.Body.Close() - s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(r.Context()) retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { - s.logger.Error("retry request: ", buildErr) + s.logger.ErrorContext(ctx, "retry request: ", buildErr) writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) return } @@ -553,7 +554,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") return } - s.logger.Error("retry request: ", retryErr) + s.logger.ErrorContext(ctx, "retry request: ", retryErr) writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) return } @@ -567,7 +568,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) - s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) + s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) go selectedCredential.pollUsage(s.ctx) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) @@ -589,7 +590,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { usageTracker := selectedCredential.usageTrackerOrNil() if usageTracker != nil && response.StatusCode == http.StatusOK && (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) { - s.handleResponseWithTracking(w, response, usageTracker, path, requestModel, username) + s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username) } else { mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) if err == nil && mediaType != "text/event-stream" { @@ -598,7 +599,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } flusher, ok := w.(http.Flusher) if !ok { - s.logger.Error("streaming not supported") + s.logger.ErrorContext(ctx, "streaming not supported") return } buffer := make([]byte, buf.BufferSize) @@ -607,7 +608,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if n > 0 { _, writeError := w.Write(buffer[:n]) if writeError != nil { - s.logger.Error("write streaming response: ", writeError) + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } flusher.Flush() @@ -619,7 +620,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { +func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { isChatCompletions := path == "/v1/chat/completions" weeklyCycleHint := extractWeeklyCycleHint(response.Header) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) @@ -630,7 +631,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if !isStreaming { bodyBytes, err := io.ReadAll(response.Body) if err != nil { - s.logger.Error("read response body: ", err) + s.logger.ErrorContext(ctx, "read response body: ", err) return } @@ -683,7 +684,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons flusher, ok := writer.(http.Flusher) if !ok { - s.logger.Error("streaming not supported") + s.logger.ErrorContext(ctx, "streaming not supported") return } @@ -760,7 +761,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons _, writeError := writer.Write(buffer[:n]) if writeError != nil { - s.logger.Error("write streaming response: ", writeError) + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } flusher.Flush() diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 21b25bafcb..17178e8c2d 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -82,6 +82,7 @@ func isForwardableWebSocketRequestHeader(key string) bool { } func (s *Service) handleWebSocket( + ctx context.Context, w http.ResponseWriter, r *http.Request, path string, @@ -105,7 +106,7 @@ func (s *Service) handleWebSocket( for { accessToken, accessErr := selectedCredential.getAccessToken() if accessErr != nil { - s.logger.Error("get access token for websocket: ", accessErr) + s.logger.ErrorContext(ctx, "get access token for websocket: ", accessErr) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed") return } @@ -190,14 +191,14 @@ func (s *Service) handleWebSocket( writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") return } - s.logger.Info("retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) selectedCredential = nextCredential continue } if statusCode > 0 && statusResponseBody != "" { - s.logger.Error("dial upstream websocket: status ", statusCode, " body: ", statusResponseBody) + s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody) } else { - s.logger.Error("dial upstream websocket: ", err) + s.logger.ErrorContext(ctx, "dial upstream websocket: ", err) } writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed") return @@ -226,7 +227,7 @@ func (s *Service) handleWebSocket( } clientConn, _, _, err := clientUpgrader.Upgrade(r, w) if err != nil { - s.logger.Error("upgrade client websocket: ", err) + s.logger.ErrorContext(ctx, "upgrade client websocket: ", err) upstreamConn.Close() return } @@ -258,23 +259,23 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID) + s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID) }() go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) + s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) }() waitGroup.Wait() } -func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) { +func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) { logged := false for { data, opCode, err := wsutil.ReadClientData(clientConn) if err != nil { if !E.IsClosedOrCanceled(err) { - s.logger.Debug("read client websocket: ", err) + s.logger.DebugContext(ctx, "read client websocket: ", err) } return } @@ -299,7 +300,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo if request.ServiceTier == "priority" { logParts = append(logParts, ", fast") } - s.logger.Debug(logParts...) + s.logger.DebugContext(ctx, logParts...) } if selectedCredential.usageTrackerOrNil() != nil { select { @@ -313,21 +314,21 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo err = wsutil.WriteClientMessage(upstreamConn, opCode, data) if err != nil { if !E.IsClosedOrCanceled(err) { - s.logger.Debug("write upstream websocket: ", err) + s.logger.DebugContext(ctx, "write upstream websocket: ", err) } return } } } -func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { usageTracker := selectedCredential.usageTrackerOrNil() var requestModel string for { data, opCode, err := wsutil.ReadServerData(upstreamReadWriter) if err != nil { if !E.IsClosedOrCanceled(err) { - s.logger.Debug("read upstream websocket: ", err) + s.logger.DebugContext(ctx, "read upstream websocket: ", err) } return } @@ -367,7 +368,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite err = wsutil.WriteServerMessage(clientConn, opCode, data) if err != nil { if !E.IsClosedOrCanceled(err) { - s.logger.Debug("write client websocket: ", err) + s.logger.DebugContext(ctx, "write client websocket: ", err) } return } From f871113832fed2d9b3a352354de33be148fca8ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 17:15:32 +0800 Subject: [PATCH 38/96] ccm,ocm: add balancer session rebalancing with per-credential interrupt When a sticky session's credential utilization exceeds the least-used credential by a weight-adjusted threshold, force reassign all sessions on that credential and cancel in-flight requests scoped to the balancer. Threshold formula: effective = rebalance_threshold / planWeight, so a config value of 20 triggers at 2% delta for Max 20x (w=10), 4% for Max 5x (w=5), and 20% for Pro (w=1). --- option/ccm.go | 7 +- option/ocm.go | 7 +- service/ccm/credential_external.go | 6 +- service/ccm/credential_state.go | 120 +++++++++++++++++++++++------ service/ccm/service.go | 2 + service/ocm/credential_external.go | 6 +- service/ocm/credential_state.go | 116 ++++++++++++++++++++++------ service/ocm/service.go | 2 + 8 files changed, 208 insertions(+), 58 deletions(-) diff --git a/option/ccm.go b/option/ccm.go index dd55a4ba4e..6d9d14423c 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -91,9 +91,10 @@ type CCMDefaultCredentialOptions struct { } type CCMBalancerCredentialOptions struct { - Strategy string `json:"strategy,omitempty"` - Credentials badoption.Listable[string] `json:"credentials"` - PollInterval badoption.Duration `json:"poll_interval,omitempty"` + Strategy string `json:"strategy,omitempty"` + Credentials badoption.Listable[string] `json:"credentials"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` + RebalanceThreshold float64 `json:"rebalance_threshold,omitempty"` } type CCMExternalCredentialOptions struct { diff --git a/option/ocm.go b/option/ocm.go index e508abae7e..6015b506cd 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -91,9 +91,10 @@ type OCMDefaultCredentialOptions struct { } type OCMBalancerCredentialOptions struct { - Strategy string `json:"strategy,omitempty"` - Credentials badoption.Listable[string] `json:"credentials"` - PollInterval badoption.Duration `json:"poll_interval,omitempty"` + Strategy string `json:"strategy,omitempty"` + Credentials badoption.Listable[string] `json:"credentials"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` + RebalanceThreshold float64 `json:"rebalance_threshold,omitempty"` } type OCMExternalCredentialOptions struct { diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index a5781d6f6d..ade2d8361c 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -465,9 +465,9 @@ func (c *externalCredential) wrapRequestContext(parent context.Context) *credent cancel() }) return &credentialRequestContext{ - Context: derived, - releaseFunc: stop, - cancelFunc: cancel, + Context: derived, + releaseFuncs: []func() bool{stop}, + cancelFunc: cancel, } } diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 83db5a1977..490c6148f7 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -110,15 +110,21 @@ type defaultCredential struct { type credentialRequestContext struct { context.Context - releaseOnce sync.Once - cancelOnce sync.Once - releaseFunc func() bool - cancelFunc context.CancelFunc + releaseOnce sync.Once + cancelOnce sync.Once + releaseFuncs []func() bool + cancelFunc context.CancelFunc +} + +func (c *credentialRequestContext) addInterruptLink(stop func() bool) { + c.releaseFuncs = append(c.releaseFuncs, stop) } func (c *credentialRequestContext) releaseCredentialInterrupt() { c.releaseOnce.Do(func() { - c.releaseFunc() + for _, f := range c.releaseFuncs { + f() + } }) } @@ -504,9 +510,9 @@ func (c *defaultCredential) wrapRequestContext(parent context.Context) *credenti cancel() }) return &credentialRequestContext{ - Context: derived, - releaseFunc: stop, - cancelFunc: cancel, + Context: derived, + releaseFuncs: []func() bool{stop}, + cancelFunc: cancel, } } @@ -851,6 +857,7 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt type credentialProvider interface { selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential + wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) pollIfStale(ctx context.Context) allCredentials() []credential close() @@ -913,6 +920,8 @@ func (p *singleCredentialProvider) allCredentials() []credential { return []credential{p.cred} } +func (p *singleCredentialProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} + func (p *singleCredentialProvider) close() {} const sessionExpiry = 24 * time.Hour @@ -922,27 +931,37 @@ type sessionEntry struct { createdAt time.Time } -// balancerProvider assigns sessions to credentials based on a configurable strategy. -type balancerProvider struct { - credentials []credential - strategy string - roundRobinIndex atomic.Uint64 - pollInterval time.Duration - sessionMutex sync.RWMutex - sessions map[string]sessionEntry - logger log.ContextLogger +type credentialInterruptEntry struct { + context context.Context + cancel context.CancelFunc } -func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { +// balancerProvider assigns sessions to credentials based on a configurable strategy. +type balancerProvider struct { + credentials []credential + strategy string + roundRobinIndex atomic.Uint64 + pollInterval time.Duration + rebalanceThreshold float64 + sessionMutex sync.RWMutex + sessions map[string]sessionEntry + interruptAccess sync.Mutex + credentialInterrupts map[string]credentialInterruptEntry + logger log.ContextLogger +} + +func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { if pollInterval <= 0 { pollInterval = defaultPollInterval } return &balancerProvider{ - credentials: credentials, - strategy: strategy, - pollInterval: pollInterval, - sessions: make(map[string]sessionEntry), - logger: logger, + credentials: credentials, + strategy: strategy, + pollInterval: pollInterval, + rebalanceThreshold: rebalanceThreshold, + sessions: make(map[string]sessionEntry), + credentialInterrupts: make(map[string]credentialInterruptEntry), + logger: logger, } } @@ -954,6 +973,20 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden if exists { for _, cred := range p.credentials { if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { + better := p.pickLeastUsed(filter) + if better != nil && better.tagName() != cred.tagName() { + effectiveThreshold := p.rebalanceThreshold / cred.planWeight() + delta := cred.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", cred.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", cred.planWeight(), ")") + p.rebalanceCredential(cred.tagName()) + break + } + } + } return cred, false, nil } } @@ -977,6 +1010,40 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden return best, isNew, nil } +func (p *balancerProvider) rebalanceCredential(tag string) { + p.interruptAccess.Lock() + if entry, loaded := p.credentialInterrupts[tag]; loaded { + entry.cancel() + } + ctx, cancel := context.WithCancel(context.Background()) + p.credentialInterrupts[tag] = credentialInterruptEntry{context: ctx, cancel: cancel} + p.interruptAccess.Unlock() + + p.sessionMutex.Lock() + for id, entry := range p.sessions { + if entry.tag == tag { + delete(p.sessions, id) + } + } + p.sessionMutex.Unlock() +} + +func (p *balancerProvider) wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) { + tag := cred.tagName() + p.interruptAccess.Lock() + entry, loaded := p.credentialInterrupts[tag] + if !loaded { + ctx, cancel := context.WithCancel(context.Background()) + entry = credentialInterruptEntry{context: ctx, cancel: cancel} + p.credentialInterrupts[tag] = entry + } + p.interruptAccess.Unlock() + stop := context.AfterFunc(entry.context, func() { + requestContext.cancelOnce.Do(requestContext.cancelFunc) + }) + requestContext.addInterruptLink(stop) +} + func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential { cred.markRateLimited(resetAt) if sessionID != "" { @@ -1166,6 +1233,8 @@ func (p *fallbackProvider) allCredentials() []credential { return p.credentials } +func (p *fallbackProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} + func (p *fallbackProvider) close() {} func allCredentialsUnavailableError(credentials []credential) error { @@ -1247,7 +1316,7 @@ func buildCredentialProviders( if err != nil { return nil, nil, err } - providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger) + providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) case "fallback": subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, allCredentialMap, credOpt.Tag) if err != nil { @@ -1346,6 +1415,9 @@ func validateCCMOptions(options option.CCMServiceOptions) error { default: return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) } + if cred.BalancerOptions.RebalanceThreshold < 0 { + return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") + } } } diff --git a/service/ccm/service.go b/service/ccm/service.go index 81e3b38a5c..2760c348aa 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -459,6 +459,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } requestContext := selectedCredential.wrapRequestContext(r.Context()) + provider.wrapProviderInterrupt(selectedCredential, requestContext) defer func() { requestContext.cancelRequest() }() @@ -497,6 +498,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(r.Context()) + provider.wrapProviderInterrupt(nextCredential, requestContext) retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { s.logger.ErrorContext(ctx, "retry request: ", buildErr) diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 2a09c84ad6..3d46f8fa89 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -502,9 +502,9 @@ func (c *externalCredential) wrapRequestContext(parent context.Context) *credent cancel() }) return &credentialRequestContext{ - Context: derived, - releaseFunc: stop, - cancelFunc: cancel, + Context: derived, + releaseFuncs: []func() bool{stop}, + cancelFunc: cancel, } } diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index ea7c621bf1..7a6c8ef5ed 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -110,15 +110,21 @@ type defaultCredential struct { type credentialRequestContext struct { context.Context - releaseOnce sync.Once - cancelOnce sync.Once - releaseFunc func() bool - cancelFunc context.CancelFunc + releaseOnce sync.Once + cancelOnce sync.Once + releaseFuncs []func() bool + cancelFunc context.CancelFunc +} + +func (c *credentialRequestContext) addInterruptLink(stop func() bool) { + c.releaseFuncs = append(c.releaseFuncs, stop) } func (c *credentialRequestContext) releaseCredentialInterrupt() { c.releaseOnce.Do(func() { - c.releaseFunc() + for _, f := range c.releaseFuncs { + f() + } }) } @@ -518,9 +524,9 @@ func (c *defaultCredential) wrapRequestContext(parent context.Context) *credenti cancel() }) return &credentialRequestContext{ - Context: derived, - releaseFunc: stop, - cancelFunc: cancel, + Context: derived, + releaseFuncs: []func() bool{stop}, + cancelFunc: cancel, } } @@ -848,6 +854,7 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt type credentialProvider interface { selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential + wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) pollIfStale(ctx context.Context) allCredentials() []credential close() @@ -909,6 +916,8 @@ func (p *singleCredentialProvider) allCredentials() []credential { return []credential{p.cred} } +func (p *singleCredentialProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} + func (p *singleCredentialProvider) close() {} const sessionExpiry = 24 * time.Hour @@ -918,30 +927,40 @@ type sessionEntry struct { createdAt time.Time } +type credentialInterruptEntry struct { + context context.Context + cancel context.CancelFunc +} + type balancerProvider struct { - credentials []credential - strategy string - roundRobinIndex atomic.Uint64 - pollInterval time.Duration - sessionMutex sync.RWMutex - sessions map[string]sessionEntry - logger log.ContextLogger + credentials []credential + strategy string + roundRobinIndex atomic.Uint64 + pollInterval time.Duration + rebalanceThreshold float64 + sessionMutex sync.RWMutex + sessions map[string]sessionEntry + interruptAccess sync.Mutex + credentialInterrupts map[string]credentialInterruptEntry + logger log.ContextLogger } func compositeCredentialSelectable(cred credential) bool { return !cred.ocmIsAPIKeyMode() } -func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { +func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { if pollInterval <= 0 { pollInterval = defaultPollInterval } return &balancerProvider{ - credentials: credentials, - strategy: strategy, - pollInterval: pollInterval, - sessions: make(map[string]sessionEntry), - logger: logger, + credentials: credentials, + strategy: strategy, + pollInterval: pollInterval, + rebalanceThreshold: rebalanceThreshold, + sessions: make(map[string]sessionEntry), + credentialInterrupts: make(map[string]credentialInterruptEntry), + logger: logger, } } @@ -953,6 +972,20 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden if exists { for _, cred := range p.credentials { if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && (filter == nil || filter(cred)) && cred.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { + better := p.pickLeastUsed(filter) + if better != nil && better.tagName() != cred.tagName() { + effectiveThreshold := p.rebalanceThreshold / cred.planWeight() + delta := cred.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", cred.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", cred.planWeight(), ")") + p.rebalanceCredential(cred.tagName()) + break + } + } + } return cred, false, nil } } @@ -976,6 +1009,40 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden return best, isNew, nil } +func (p *balancerProvider) rebalanceCredential(tag string) { + p.interruptAccess.Lock() + if entry, loaded := p.credentialInterrupts[tag]; loaded { + entry.cancel() + } + ctx, cancel := context.WithCancel(context.Background()) + p.credentialInterrupts[tag] = credentialInterruptEntry{context: ctx, cancel: cancel} + p.interruptAccess.Unlock() + + p.sessionMutex.Lock() + for id, entry := range p.sessions { + if entry.tag == tag { + delete(p.sessions, id) + } + } + p.sessionMutex.Unlock() +} + +func (p *balancerProvider) wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) { + tag := cred.tagName() + p.interruptAccess.Lock() + entry, loaded := p.credentialInterrupts[tag] + if !loaded { + ctx, cancel := context.WithCancel(context.Background()) + entry = credentialInterruptEntry{context: ctx, cancel: cancel} + p.credentialInterrupts[tag] = entry + } + p.interruptAccess.Unlock() + stop := context.AfterFunc(entry.context, func() { + requestContext.cancelOnce.Do(requestContext.cancelFunc) + }) + requestContext.addInterruptLink(stop) +} + func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential { cred.markRateLimited(resetAt) if sessionID != "" { @@ -1169,6 +1236,8 @@ func (p *fallbackProvider) allCredentials() []credential { return p.credentials } +func (p *fallbackProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} + func (p *fallbackProvider) close() {} func allRateLimitedError(credentials []credential) error { @@ -1232,7 +1301,7 @@ func buildOCMCredentialProviders( if err != nil { return nil, nil, err } - providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger) + providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) case "fallback": subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, allCredentialMap, credOpt.Tag) if err != nil { @@ -1339,6 +1408,9 @@ func validateOCMOptions(options option.OCMServiceOptions) error { default: return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) } + if cred.BalancerOptions.RebalanceThreshold < 0 { + return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") + } } } diff --git a/service/ocm/service.go b/service/ocm/service.go index 94a98c665e..376659c4b2 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -500,6 +500,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } requestContext := selectedCredential.wrapRequestContext(r.Context()) + provider.wrapProviderInterrupt(selectedCredential, requestContext) defer func() { requestContext.cancelRequest() }() @@ -539,6 +540,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(r.Context()) + provider.wrapProviderInterrupt(nextCredential, requestContext) retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { s.logger.ErrorContext(ctx, "retry request: ", buildErr) From d2300353fdf02e091991e84d38aa50905705bbea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 17:26:18 +0800 Subject: [PATCH 39/96] Propagate request context to upstream requests --- service/ccm/service.go | 4 ++-- service/ocm/service.go | 4 ++-- service/ocm/service_websocket.go | 13 ++++++++++++- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/service/ccm/service.go b/service/ccm/service.go index 2760c348aa..58bdd37874 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -458,7 +458,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - requestContext := selectedCredential.wrapRequestContext(r.Context()) + requestContext := selectedCredential.wrapRequestContext(ctx) provider.wrapProviderInterrupt(selectedCredential, requestContext) defer func() { requestContext.cancelRequest() @@ -497,7 +497,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { response.Body.Close() s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() - requestContext = nextCredential.wrapRequestContext(r.Context()) + requestContext = nextCredential.wrapRequestContext(ctx) provider.wrapProviderInterrupt(nextCredential, requestContext) retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { diff --git a/service/ocm/service.go b/service/ocm/service.go index 376659c4b2..cd7909dd49 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -499,7 +499,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.logger.DebugContext(ctx, logParts...) } - requestContext := selectedCredential.wrapRequestContext(r.Context()) + requestContext := selectedCredential.wrapRequestContext(ctx) provider.wrapProviderInterrupt(selectedCredential, requestContext) defer func() { requestContext.cancelRequest() @@ -539,7 +539,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { response.Body.Close() s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() - requestContext = nextCredential.wrapRequestContext(r.Context()) + requestContext = nextCredential.wrapRequestContext(ctx) provider.wrapProviderInterrupt(nextCredential, requestContext) retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 17178e8c2d..fcffaae96a 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -96,12 +96,18 @@ func (s *Service) handleWebSocket( ) { var ( err error + requestContext *credentialRequestContext upstreamConn net.Conn upstreamBufferedReader *bufio.Reader upstreamResponseHeaders http.Header statusCode int statusResponseBody string ) + defer func() { + if requestContext != nil { + requestContext.cancelRequest() + } + }() for { accessToken, accessErr := selectedCredential.getAccessToken() @@ -179,10 +185,15 @@ func (s *Service) handleWebSocket( }, } - upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(s.ctx, upstreamURL) + requestContext = selectedCredential.wrapRequestContext(ctx) + provider.wrapProviderInterrupt(selectedCredential, requestContext) + upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(requestContext, upstreamURL) if err == nil { + requestContext.releaseCredentialInterrupt() break } + requestContext.cancelRequest() + requestContext = nil if statusCode == http.StatusTooManyRequests { resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders) nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) From 2c907bef2ca46a0dc0059f6e67204d4a5a5718c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 17:38:40 +0800 Subject: [PATCH 40/96] Fix scoped rebalance interrupts --- service/ccm/credential_state.go | 145 +++++++++++++++++++++---------- service/ccm/service.go | 48 ++++++---- service/ocm/credential_state.go | 145 +++++++++++++++++++++---------- service/ocm/service.go | 50 +++++++---- service/ocm/service_websocket.go | 59 +++++++++---- 5 files changed, 304 insertions(+), 143 deletions(-) diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 490c6148f7..6b1a766f2c 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -855,14 +855,37 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt // credentialProvider is the interface for all credential types. type credentialProvider interface { - selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) - onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential - wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) + selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) + onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential + linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool pollIfStale(ctx context.Context) allCredentials() []credential close() } +type credentialSelectionScope string + +const ( + credentialSelectionScopeAll credentialSelectionScope = "all" + credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" +) + +type credentialSelection struct { + scope credentialSelectionScope + filter func(credential) bool +} + +func (s credentialSelection) allows(cred credential) bool { + return s.filter == nil || s.filter(cred) +} + +func (s credentialSelection) scopeOrDefault() credentialSelectionScope { + if s.scope == "" { + return credentialSelectionScopeAll + } + return s.scope +} + // singleCredentialProvider wraps a single credential (legacy or single default). type singleCredentialProvider struct { cred credential @@ -870,8 +893,8 @@ type singleCredentialProvider struct { sessions map[string]time.Time } -func (p *singleCredentialProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { - if filter != nil && !filter(p.cred) { +func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if !selection.allows(p.cred) { return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") } if !p.cred.isAvailable() { @@ -896,7 +919,7 @@ func (p *singleCredentialProvider) selectCredential(sessionID string, filter fun return p.cred, isNew, nil } -func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential { +func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { cred.markRateLimited(resetAt) return nil } @@ -920,15 +943,25 @@ func (p *singleCredentialProvider) allCredentials() []credential { return []credential{p.cred} } -func (p *singleCredentialProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} +func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} func (p *singleCredentialProvider) close() {} const sessionExpiry = 24 * time.Hour type sessionEntry struct { - tag string - createdAt time.Time + tag string + selectionScope credentialSelectionScope + createdAt time.Time +} + +type credentialInterruptKey struct { + tag string + selectionScope credentialSelectionScope } type credentialInterruptEntry struct { @@ -946,7 +979,7 @@ type balancerProvider struct { sessionMutex sync.RWMutex sessions map[string]sessionEntry interruptAccess sync.Mutex - credentialInterrupts map[string]credentialInterruptEntry + credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry logger log.ContextLogger } @@ -960,34 +993,37 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval pollInterval: pollInterval, rebalanceThreshold: rebalanceThreshold, sessions: make(map[string]sessionEntry), - credentialInterrupts: make(map[string]credentialInterruptEntry), + credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), logger: logger, } } -func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { +func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + selectionScope := selection.scopeOrDefault() if sessionID != "" { p.sessionMutex.RLock() entry, exists := p.sessions[sessionID] p.sessionMutex.RUnlock() if exists { - for _, cred := range p.credentials { - if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { - better := p.pickLeastUsed(filter) - if better != nil && better.tagName() != cred.tagName() { - effectiveThreshold := p.rebalanceThreshold / cred.planWeight() - delta := cred.weeklyUtilization() - better.weeklyUtilization() - if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", cred.tagName(), - ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", cred.planWeight(), ")") - p.rebalanceCredential(cred.tagName()) - break + if entry.selectionScope == selectionScope { + for _, cred := range p.credentials { + if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { + better := p.pickLeastUsed(selection.filter) + if better != nil && better.tagName() != cred.tagName() { + effectiveThreshold := p.rebalanceThreshold / cred.planWeight() + delta := cred.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", cred.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", cred.planWeight(), ")") + p.rebalanceCredential(cred.tagName(), selectionScope) + break + } } } + return cred, false, nil } - return cred, false, nil } } p.sessionMutex.Lock() @@ -996,7 +1032,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden } } - best := p.pickCredential(filter) + best := p.pickCredential(selection.filter) if best == nil { return nil, false, allCredentialsUnavailableError(p.credentials) } @@ -1004,47 +1040,52 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden isNew := sessionID != "" if isNew { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selectionScope, + createdAt: time.Now(), + } p.sessionMutex.Unlock() } return best, isNew, nil } -func (p *balancerProvider) rebalanceCredential(tag string) { +func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { + key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} p.interruptAccess.Lock() - if entry, loaded := p.credentialInterrupts[tag]; loaded { + if entry, loaded := p.credentialInterrupts[key]; loaded { entry.cancel() } ctx, cancel := context.WithCancel(context.Background()) - p.credentialInterrupts[tag] = credentialInterruptEntry{context: ctx, cancel: cancel} + p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} p.interruptAccess.Unlock() p.sessionMutex.Lock() for id, entry := range p.sessions { - if entry.tag == tag { + if entry.tag == tag && entry.selectionScope == selectionScope { delete(p.sessions, id) } } p.sessionMutex.Unlock() } -func (p *balancerProvider) wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) { - tag := cred.tagName() +func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { + key := credentialInterruptKey{ + tag: cred.tagName(), + selectionScope: selection.scopeOrDefault(), + } p.interruptAccess.Lock() - entry, loaded := p.credentialInterrupts[tag] + entry, loaded := p.credentialInterrupts[key] if !loaded { ctx, cancel := context.WithCancel(context.Background()) entry = credentialInterruptEntry{context: ctx, cancel: cancel} - p.credentialInterrupts[tag] = entry + p.credentialInterrupts[key] = entry } p.interruptAccess.Unlock() - stop := context.AfterFunc(entry.context, func() { - requestContext.cancelOnce.Do(requestContext.cancelFunc) - }) - requestContext.addInterruptLink(stop) + return context.AfterFunc(entry.context, onInterrupt) } -func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential { +func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { cred.markRateLimited(resetAt) if sessionID != "" { p.sessionMutex.Lock() @@ -1052,10 +1093,14 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese p.sessionMutex.Unlock() } - best := p.pickCredential(filter) + best := p.pickCredential(selection.filter) if best != nil && sessionID != "" { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selection.scopeOrDefault(), + createdAt: time.Now(), + } p.sessionMutex.Unlock() } return best @@ -1196,9 +1241,9 @@ func newFallbackProvider(credentials []credential, pollInterval time.Duration, l } } -func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) { +func (p *fallbackProvider) selectCredential(_ string, selection credentialSelection) (credential, bool, error) { for _, cred := range p.credentials { - if filter != nil && !filter(cred) { + if !selection.allows(cred) { continue } if cred.isUsable() { @@ -1208,10 +1253,10 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo return nil, false, allCredentialsUnavailableError(p.credentials) } -func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential { +func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, selection credentialSelection) credential { cred.markRateLimited(resetAt) for _, candidate := range p.credentials { - if filter != nil && !filter(candidate) { + if !selection.allows(candidate) { continue } if candidate.isUsable() { @@ -1233,7 +1278,11 @@ func (p *fallbackProvider) allCredentials() []credential { return p.credentials } -func (p *fallbackProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} +func (p *fallbackProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} func (p *fallbackProvider) close() {} diff --git a/service/ccm/service.go b/service/ccm/service.go index 58bdd37874..6a2aa2b740 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -67,7 +67,7 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro }) } -func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool { +func hasAlternativeCredential(provider credentialProvider, currentCredential credential, selection credentialSelection) bool { if provider == nil || currentCredential == nil { return false } @@ -75,7 +75,7 @@ func hasAlternativeCredential(provider credentialProvider, currentCredential cre if cred == currentCredential { continue } - if filter != nil && !filter(cred) { + if !selection.allows(cred) { continue } if cred.isUsable() { @@ -109,16 +109,27 @@ func writeCredentialUnavailableError( r *http.Request, provider credentialProvider, currentCredential credential, - filter func(credential) bool, + selection credentialSelection, fallback string, ) { - if hasAlternativeCredential(provider, currentCredential, filter) { + if hasAlternativeCredential(provider, currentCredential, selection) { writeRetryableUsageError(w, r) return } writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback)) } +func credentialSelectionForUser(userConfig *option.CCMUser) credentialSelection { + selection := credentialSelection{scope: credentialSelectionScopeAll} + if userConfig != nil && !userConfig.AllowExternalUsage { + selection.scope = credentialSelectionScopeNonExternal + selection.filter = func(cred credential) bool { + return !cred.isExternal() + } + } + return selection +} + func isHopByHopHeader(header string) bool { switch strings.ToLower(header) { case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host": @@ -424,12 +435,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - var credentialFilter func(credential) bool - if userConfig != nil && !userConfig.AllowExternalUsage { - credentialFilter = func(c credential) bool { return !c.isExternal() } - } + selection := credentialSelectionForUser(userConfig) - selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter) + selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) if err != nil { writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) return @@ -459,7 +467,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } requestContext := selectedCredential.wrapRequestContext(ctx) - provider.wrapProviderInterrupt(selectedCredential, requestContext) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } defer func() { requestContext.cancelRequest() }() @@ -476,7 +489,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request") + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") return } writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) @@ -487,18 +500,23 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Transparent 429 retry for response.StatusCode == http.StatusTooManyRequests { resetAt := parseRateLimitResetFromHeaders(response.Header) - nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) selectedCredential.updateStateFromHeaders(response.Header) if bodyBytes == nil || nextCredential == nil { response.Body.Close() - writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") return } response.Body.Close() s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(ctx) - provider.wrapProviderInterrupt(nextCredential, requestContext) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { s.logger.ErrorContext(ctx, "retry request: ", buildErr) @@ -511,7 +529,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") + writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") return } s.logger.ErrorContext(ctx, "retry request: ", retryErr) diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 7a6c8ef5ed..d8f2e826ac 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -852,22 +852,45 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt } type credentialProvider interface { - selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) - onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential - wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) + selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) + onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential + linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool pollIfStale(ctx context.Context) allCredentials() []credential close() } +type credentialSelectionScope string + +const ( + credentialSelectionScopeAll credentialSelectionScope = "all" + credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" +) + +type credentialSelection struct { + scope credentialSelectionScope + filter func(credential) bool +} + +func (s credentialSelection) allows(cred credential) bool { + return s.filter == nil || s.filter(cred) +} + +func (s credentialSelection) scopeOrDefault() credentialSelectionScope { + if s.scope == "" { + return credentialSelectionScopeAll + } + return s.scope +} + type singleCredentialProvider struct { cred credential sessionAccess sync.RWMutex sessions map[string]time.Time } -func (p *singleCredentialProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { - if filter != nil && !filter(p.cred) { +func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if !selection.allows(p.cred) { return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") } if !p.cred.isAvailable() { @@ -892,7 +915,7 @@ func (p *singleCredentialProvider) selectCredential(sessionID string, filter fun return p.cred, isNew, nil } -func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential { +func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { cred.markRateLimited(resetAt) return nil } @@ -916,15 +939,25 @@ func (p *singleCredentialProvider) allCredentials() []credential { return []credential{p.cred} } -func (p *singleCredentialProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} +func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} func (p *singleCredentialProvider) close() {} const sessionExpiry = 24 * time.Hour type sessionEntry struct { - tag string - createdAt time.Time + tag string + selectionScope credentialSelectionScope + createdAt time.Time +} + +type credentialInterruptKey struct { + tag string + selectionScope credentialSelectionScope } type credentialInterruptEntry struct { @@ -941,7 +974,7 @@ type balancerProvider struct { sessionMutex sync.RWMutex sessions map[string]sessionEntry interruptAccess sync.Mutex - credentialInterrupts map[string]credentialInterruptEntry + credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry logger log.ContextLogger } @@ -959,34 +992,37 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval pollInterval: pollInterval, rebalanceThreshold: rebalanceThreshold, sessions: make(map[string]sessionEntry), - credentialInterrupts: make(map[string]credentialInterruptEntry), + credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), logger: logger, } } -func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { +func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + selectionScope := selection.scopeOrDefault() if sessionID != "" { p.sessionMutex.RLock() entry, exists := p.sessions[sessionID] p.sessionMutex.RUnlock() if exists { - for _, cred := range p.credentials { - if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && (filter == nil || filter(cred)) && cred.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { - better := p.pickLeastUsed(filter) - if better != nil && better.tagName() != cred.tagName() { - effectiveThreshold := p.rebalanceThreshold / cred.planWeight() - delta := cred.weeklyUtilization() - better.weeklyUtilization() - if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", cred.tagName(), - ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", cred.planWeight(), ")") - p.rebalanceCredential(cred.tagName()) - break + if entry.selectionScope == selectionScope { + for _, cred := range p.credentials { + if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { + better := p.pickLeastUsed(selection.filter) + if better != nil && better.tagName() != cred.tagName() { + effectiveThreshold := p.rebalanceThreshold / cred.planWeight() + delta := cred.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", cred.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", cred.planWeight(), ")") + p.rebalanceCredential(cred.tagName(), selectionScope) + break + } } } + return cred, false, nil } - return cred, false, nil } } p.sessionMutex.Lock() @@ -995,7 +1031,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden } } - best := p.pickCredential(filter) + best := p.pickCredential(selection.filter) if best == nil { return nil, false, allRateLimitedError(p.credentials) } @@ -1003,47 +1039,52 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden isNew := sessionID != "" if isNew { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selectionScope, + createdAt: time.Now(), + } p.sessionMutex.Unlock() } return best, isNew, nil } -func (p *balancerProvider) rebalanceCredential(tag string) { +func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { + key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} p.interruptAccess.Lock() - if entry, loaded := p.credentialInterrupts[tag]; loaded { + if entry, loaded := p.credentialInterrupts[key]; loaded { entry.cancel() } ctx, cancel := context.WithCancel(context.Background()) - p.credentialInterrupts[tag] = credentialInterruptEntry{context: ctx, cancel: cancel} + p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} p.interruptAccess.Unlock() p.sessionMutex.Lock() for id, entry := range p.sessions { - if entry.tag == tag { + if entry.tag == tag && entry.selectionScope == selectionScope { delete(p.sessions, id) } } p.sessionMutex.Unlock() } -func (p *balancerProvider) wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) { - tag := cred.tagName() +func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { + key := credentialInterruptKey{ + tag: cred.tagName(), + selectionScope: selection.scopeOrDefault(), + } p.interruptAccess.Lock() - entry, loaded := p.credentialInterrupts[tag] + entry, loaded := p.credentialInterrupts[key] if !loaded { ctx, cancel := context.WithCancel(context.Background()) entry = credentialInterruptEntry{context: ctx, cancel: cancel} - p.credentialInterrupts[tag] = entry + p.credentialInterrupts[key] = entry } p.interruptAccess.Unlock() - stop := context.AfterFunc(entry.context, func() { - requestContext.cancelOnce.Do(requestContext.cancelFunc) - }) - requestContext.addInterruptLink(stop) + return context.AfterFunc(entry.context, onInterrupt) } -func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential { +func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { cred.markRateLimited(resetAt) if sessionID != "" { p.sessionMutex.Lock() @@ -1051,10 +1092,14 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese p.sessionMutex.Unlock() } - best := p.pickCredential(filter) + best := p.pickCredential(selection.filter) if best != nil && sessionID != "" { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selection.scopeOrDefault(), + createdAt: time.Now(), + } p.sessionMutex.Unlock() } return best @@ -1193,9 +1238,9 @@ func newFallbackProvider(credentials []credential, pollInterval time.Duration, l } } -func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) { +func (p *fallbackProvider) selectCredential(_ string, selection credentialSelection) (credential, bool, error) { for _, cred := range p.credentials { - if filter != nil && !filter(cred) { + if !selection.allows(cred) { continue } if !compositeCredentialSelectable(cred) { @@ -1208,10 +1253,10 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo return nil, false, allRateLimitedError(p.credentials) } -func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential { +func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, selection credentialSelection) credential { cred.markRateLimited(resetAt) for _, candidate := range p.credentials { - if filter != nil && !filter(candidate) { + if !selection.allows(candidate) { continue } if !compositeCredentialSelectable(candidate) { @@ -1236,7 +1281,11 @@ func (p *fallbackProvider) allCredentials() []credential { return p.credentials } -func (p *fallbackProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} +func (p *fallbackProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} func (p *fallbackProvider) close() {} diff --git a/service/ocm/service.go b/service/ocm/service.go index cd7909dd49..071cec8ccb 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -75,7 +75,7 @@ const ( retryableUsageCode = "credential_usage_exhausted" ) -func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool { +func hasAlternativeCredential(provider credentialProvider, currentCredential credential, selection credentialSelection) bool { if provider == nil || currentCredential == nil { return false } @@ -83,7 +83,7 @@ func hasAlternativeCredential(provider credentialProvider, currentCredential cre if cred == currentCredential { continue } - if filter != nil && !filter(cred) { + if !selection.allows(cred) { continue } if cred.isUsable() { @@ -117,16 +117,27 @@ func writeCredentialUnavailableError( r *http.Request, provider credentialProvider, currentCredential credential, - filter func(credential) bool, + selection credentialSelection, fallback string, ) { - if hasAlternativeCredential(provider, currentCredential, filter) { + if hasAlternativeCredential(provider, currentCredential, selection) { writeRetryableUsageError(w, r) return } writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback)) } +func credentialSelectionForUser(userConfig *option.OCMUser) credentialSelection { + selection := credentialSelection{scope: credentialSelectionScopeAll} + if userConfig != nil && !userConfig.AllowExternalUsage { + selection.scope = credentialSelectionScopeNonExternal + selection.filter = func(cred credential) bool { + return !cred.isExternal() + } + } + return selection +} + func isHopByHopHeader(header string) bool { switch strings.ToLower(header) { case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host": @@ -426,19 +437,16 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { provider.pollIfStale(s.ctx) - var credentialFilter func(credential) bool - if userConfig != nil && !userConfig.AllowExternalUsage { - credentialFilter = func(c credential) bool { return !c.isExternal() } - } + selection := credentialSelectionForUser(userConfig) - selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter) + selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) if err != nil { writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error())) return } if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { - s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew) + s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew) return } @@ -500,7 +508,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } requestContext := selectedCredential.wrapRequestContext(ctx) - provider.wrapProviderInterrupt(selectedCredential, requestContext) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } defer func() { requestContext.cancelRequest() }() @@ -517,7 +530,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request") + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") return } writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) @@ -528,19 +541,24 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Transparent 429 retry for response.StatusCode == http.StatusTooManyRequests { resetAt := parseOCMRateLimitResetFromHeaders(response.Header) - nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete selectedCredential.updateStateFromHeaders(response.Header) if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil { response.Body.Close() - writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") return } response.Body.Close() s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(ctx) - provider.wrapProviderInterrupt(nextCredential, requestContext) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { s.logger.ErrorContext(ctx, "retry request: ", buildErr) @@ -553,7 +571,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") + writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") return } s.logger.ErrorContext(ctx, "retry request: ", retryErr) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index fcffaae96a..4b640d9c5c 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -26,16 +26,24 @@ import ( ) type webSocketSession struct { - clientConn net.Conn - upstreamConn net.Conn - credentialTag string - closeOnce sync.Once + clientConn net.Conn + upstreamConn net.Conn + credentialTag string + releaseProviderInterrupt func() + closeOnce sync.Once } func (s *webSocketSession) Close() { s.closeOnce.Do(func() { - s.clientConn.Close() - s.upstreamConn.Close() + if s.releaseProviderInterrupt != nil { + s.releaseProviderInterrupt() + } + if s.clientConn != nil { + s.clientConn.Close() + } + if s.upstreamConn != nil { + s.upstreamConn.Close() + } }) } @@ -91,12 +99,14 @@ func (s *Service) handleWebSocket( userConfig *option.OCMUser, provider credentialProvider, selectedCredential credential, - credentialFilter func(credential) bool, + selection credentialSelection, isNew bool, ) { var ( err error requestContext *credentialRequestContext + clientConn net.Conn + session *webSocketSession upstreamConn net.Conn upstreamBufferedReader *bufio.Reader upstreamResponseHeaders http.Header @@ -186,20 +196,36 @@ func (s *Service) handleWebSocket( } requestContext = selectedCredential.wrapRequestContext(ctx) - provider.wrapProviderInterrupt(selectedCredential, requestContext) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + if session != nil { + session.Close() + return + } + if clientConn != nil { + clientConn.Close() + } + if upstreamConn != nil { + upstreamConn.Close() + } + })) + } upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(requestContext, upstreamURL) if err == nil { - requestContext.releaseCredentialInterrupt() break } requestContext.cancelRequest() requestContext = nil + upstreamConn = nil + clientConn = nil if statusCode == http.StatusTooManyRequests { resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders) - nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) selectedCredential.updateStateFromHeaders(upstreamResponseHeaders) if nextCredential == nil { - writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") return } s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) @@ -236,16 +262,17 @@ func (s *Service) handleWebSocket( writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down") return } - clientConn, _, _, err := clientUpgrader.Upgrade(r, w) + clientConn, _, _, err = clientUpgrader.Upgrade(r, w) if err != nil { s.logger.ErrorContext(ctx, "upgrade client websocket: ", err) upstreamConn.Close() return } - session := &webSocketSession{ - clientConn: clientConn, - upstreamConn: upstreamConn, - credentialTag: selectedCredential.tagName(), + session = &webSocketSession{ + clientConn: clientConn, + upstreamConn: upstreamConn, + credentialTag: selectedCredential.tagName(), + releaseProviderInterrupt: requestContext.releaseCredentialInterrupt, } if !s.registerWebSocketSession(session) { session.Close() From 4d907bc49da94251f46b365c08e743c5063f1312 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 18:06:40 +0800 Subject: [PATCH 41/96] ccm,ocm: allow URL-based credentials to accept reverse connections Previously, findReceiverCredential required baseURL == reverseProxyBaseURL, so only credentials with no URL could accept incoming reverse connections. Now credentials with a normal URL also accept reverse connections, preferring the reverse session when active and falling back to the direct URL when not. --- service/ccm/credential_external.go | 41 +++++++++++++++++++--- service/ccm/reverse.go | 4 +-- service/ocm/credential_external.go | 55 +++++++++++++++++++++++++++--- service/ocm/reverse.go | 4 +-- 4 files changed, 90 insertions(+), 14 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index ade2d8361c..8d1c1a08d1 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -47,8 +47,9 @@ type externalCredential struct { requestAccess sync.Mutex // Reverse proxy fields - reverse bool - reverseSession *yamux.Session + reverse bool + reverseHttpClient *http.Client + reverseSession *yamux.Session reverseAccess sync.RWMutex closed bool reverseContext context.Context @@ -194,6 +195,14 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx } else { // Normal mode: standard HTTP client for proxying cred.httpClient = &http.Client{Transport: transport} + cred.reverseHttpClient = &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: false, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return cred.openReverseConnection(ctx) + }, + }, + } } } @@ -341,7 +350,14 @@ func (c *externalCredential) getAccessToken() (string, error) { } func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) { - proxyURL := c.baseURL + original.URL.RequestURI() + baseURL := c.baseURL + if c.reverseHttpClient != nil { + session := c.getReverseSession() + if session != nil && !session.IsClosed() { + baseURL = reverseProxyBaseURL + } + } + proxyURL := baseURL + original.URL.RequestURI() var body io.Reader if bodyBytes != nil { body = bytes.NewReader(bodyBytes) @@ -489,9 +505,18 @@ func (c *externalCredential) pollUsage(ctx context.Context) { defer c.pollAccess.Unlock() defer c.markUsagePollAttempted() - statusURL := c.baseURL + "/ccm/v1/status" + activeBaseURL := c.baseURL + activeTransport := c.httpClient.Transport + if c.reverseHttpClient != nil { + session := c.getReverseSession() + if session != nil && !session.IsClosed() { + activeBaseURL = reverseProxyBaseURL + activeTransport = c.reverseHttpClient.Transport + } + } + statusURL := activeBaseURL + "/ccm/v1/status" httpClient := &http.Client{ - Transport: c.httpClient.Transport, + Transport: activeTransport, Timeout: 5 * time.Second, } @@ -602,6 +627,12 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { } func (c *externalCredential) httpTransport() *http.Client { + if c.reverseHttpClient != nil { + session := c.getReverseSession() + if session != nil && !session.IsClosed() { + return c.reverseHttpClient + } + } return c.httpClient } diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index 62a1011171..6ecc224f9a 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -126,10 +126,10 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite func (s *Service) findReceiverCredential(token string) *externalCredential { for _, cred := range s.allCredentials { extCred, ok := cred.(*externalCredential) - if !ok { + if !ok || extCred.connectorURL != nil { continue } - if extCred.baseURL == reverseProxyBaseURL && extCred.token == token { + if extCred.token == token { return extCred } } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 3d46f8fa89..f09716db3a 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -49,8 +49,10 @@ type externalCredential struct { requestAccess sync.Mutex // Reverse proxy fields - reverse bool - reverseSession *yamux.Session + reverse bool + reverseHttpClient *http.Client + reverseCredDialer N.Dialer + reverseSession *yamux.Session reverseAccess sync.RWMutex closed bool reverseContext context.Context @@ -213,6 +215,15 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx // Normal mode: standard HTTP client for proxying cred.credDialer = credentialDialer cred.httpClient = &http.Client{Transport: transport} + cred.reverseCredDialer = reverseSessionDialer{credential: cred} + cred.reverseHttpClient = &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: false, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return cred.openReverseConnection(ctx) + }, + }, + } } } @@ -363,7 +374,14 @@ func (c *externalCredential) getAccessToken() (string, error) { } func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) { - proxyURL := c.baseURL + original.URL.RequestURI() + baseURL := c.baseURL + if c.reverseHttpClient != nil { + session := c.getReverseSession() + if session != nil && !session.IsClosed() { + baseURL = reverseProxyBaseURL + } + } + proxyURL := baseURL + original.URL.RequestURI() var body io.Reader if bodyBytes != nil { body = bytes.NewReader(bodyBytes) @@ -526,9 +544,18 @@ func (c *externalCredential) pollUsage(ctx context.Context) { defer c.pollAccess.Unlock() defer c.markUsagePollAttempted() - statusURL := c.baseURL + "/ocm/v1/status" + activeBaseURL := c.baseURL + activeTransport := c.httpClient.Transport + if c.reverseHttpClient != nil { + session := c.getReverseSession() + if session != nil && !session.IsClosed() { + activeBaseURL = reverseProxyBaseURL + activeTransport = c.reverseHttpClient.Transport + } + } + statusURL := activeBaseURL + "/ocm/v1/status" httpClient := &http.Client{ - Transport: c.httpClient.Transport, + Transport: activeTransport, Timeout: 5 * time.Second, } @@ -639,10 +666,22 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { } func (c *externalCredential) httpTransport() *http.Client { + if c.reverseHttpClient != nil { + session := c.getReverseSession() + if session != nil && !session.IsClosed() { + return c.reverseHttpClient + } + } return c.httpClient } func (c *externalCredential) ocmDialer() N.Dialer { + if c.reverseCredDialer != nil { + session := c.getReverseSession() + if session != nil && !session.IsClosed() { + return c.reverseCredDialer + } + } return c.credDialer } @@ -655,6 +694,12 @@ func (c *externalCredential) ocmGetAccountID() string { } func (c *externalCredential) ocmGetBaseURL() string { + if c.reverseHttpClient != nil { + session := c.getReverseSession() + if session != nil && !session.IsClosed() { + return reverseProxyBaseURL + } + } return c.baseURL } diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index 1ed274f6d2..ab99c77a6e 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -126,10 +126,10 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite func (s *Service) findReceiverCredential(token string) *externalCredential { for _, cred := range s.allCredentials { extCred, ok := cred.(*externalCredential) - if !ok { + if !ok || extCred.connectorURL != nil { continue } - if extCred.baseURL == reverseProxyBaseURL && extCred.token == token { + if extCred.token == token { return extCred } } From d1e5426bc8a9b08abf2937ab50698ad36b7afcfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 18:26:51 +0800 Subject: [PATCH 42/96] ccm,ocm: add exponential backoff with cap for poll retry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace flat 1-minute poll retry interval with exponential backoff (1m → 2m → 4m → 5m cap). Suppress error logs after reaching the cap. --- service/ccm/credential_external.go | 17 +++++++++++++++-- service/ccm/credential_state.go | 22 +++++++++++++++++++--- service/ocm/credential_external.go | 17 +++++++++++++++-- service/ocm/credential_state.go | 22 +++++++++++++++++++--- 4 files changed, 68 insertions(+), 10 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 8d1c1a08d1..e8fe31799d 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -529,7 +529,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) { return request, nil }) if err != nil { - c.logger.Error("poll usage for ", c.tag, ": ", err) + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": ", err) + } c.incrementPollFailures() return } @@ -609,7 +611,18 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati if failures <= 0 { return baseInterval } - return failedPollRetryInterval + backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) + if backoff > httpRetryMaxBackoff { + return httpRetryMaxBackoff + } + return backoff +} + +func (c *externalCredential) isPollBackoffAtCap() bool { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + failures := c.state.consecutivePollFailures + return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff } func (c *externalCredential) incrementPollFailures() { diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 6b1a766f2c..81d559f3c5 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -29,6 +29,7 @@ import ( const ( defaultPollInterval = 60 * time.Minute failedPollRetryInterval = time.Minute + httpRetryMaxBackoff = 5 * time.Minute ) const ( @@ -583,7 +584,18 @@ func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duratio if failures <= 0 { return baseInterval } - return failedPollRetryInterval + backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) + if backoff > httpRetryMaxBackoff { + return httpRetryMaxBackoff + } + return backoff +} + +func (c *defaultCredential) isPollBackoffAtCap() bool { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + failures := c.state.consecutivePollFailures + return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff } func (c *defaultCredential) earliestReset() time.Time { @@ -616,7 +628,9 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { accessToken, err := c.getAccessToken() if err != nil { - c.logger.Error("poll usage for ", c.tag, ": get token: ", err) + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": get token: ", err) + } c.incrementPollFailures() return } @@ -638,7 +652,9 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { return request, nil }) if err != nil { - c.logger.Error("poll usage for ", c.tag, ": ", err) + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": ", err) + } c.incrementPollFailures() return } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index f09716db3a..0b60cff1eb 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -568,7 +568,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) { return request, nil }) if err != nil { - c.logger.Error("poll usage for ", c.tag, ": ", err) + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": ", err) + } c.incrementPollFailures() return } @@ -648,7 +650,18 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati if failures <= 0 { return baseInterval } - return failedPollRetryInterval + backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) + if backoff > httpRetryMaxBackoff { + return httpRetryMaxBackoff + } + return backoff +} + +func (c *externalCredential) isPollBackoffAtCap() bool { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + failures := c.state.consecutivePollFailures + return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff } func (c *externalCredential) incrementPollFailures() { diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index d8f2e826ac..58b7344654 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -29,6 +29,7 @@ import ( const ( defaultPollInterval = 60 * time.Minute failedPollRetryInterval = time.Minute + httpRetryMaxBackoff = 5 * time.Minute ) const ( @@ -597,7 +598,18 @@ func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duratio if failures <= 0 { return baseInterval } - return failedPollRetryInterval + backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) + if backoff > httpRetryMaxBackoff { + return httpRetryMaxBackoff + } + return backoff +} + +func (c *defaultCredential) isPollBackoffAtCap() bool { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + failures := c.state.consecutivePollFailures + return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff } func (c *defaultCredential) earliestReset() time.Time { @@ -633,7 +645,9 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { accessToken, err := c.getAccessToken() if err != nil { - c.logger.Error("poll usage for ", c.tag, ": get token: ", err) + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": get token: ", err) + } c.incrementPollFailures() return } @@ -663,7 +677,9 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { return request, nil }) if err != nil { - c.logger.Error("poll usage for ", c.tag, ": ", err) + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": ", err) + } c.incrementPollFailures() return } From 4d8baf71750e715cc5544a7e2c2066bb534ae635 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 18:36:00 +0800 Subject: [PATCH 43/96] ccm: fix nil pointer in pollUsage for connector-mode credentials Connector-mode credentials (URL + reverse: true) never assigned httpClient, causing a nil dereference when pollUsage accessed httpClient.Transport. Also extract poll request logic into doPollUsageRequest to try reverse transport first (single attempt), then fall back to forward transport with retries if the reverse session disconnects. --- service/ccm/credential_external.go | 64 ++++++++++++++++++++---------- service/ocm/credential_external.go | 64 ++++++++++++++++++++---------- 2 files changed, 84 insertions(+), 44 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index e8fe31799d..97da6bd9ec 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -192,6 +192,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx Time: ntp.TimeFuncFromContext(ctx), } } + cred.httpClient = &http.Client{Transport: transport} } else { // Normal mode: standard HTTP client for proxying cred.httpClient = &http.Client{Transport: transport} @@ -498,36 +499,55 @@ func (c *externalCredential) interruptConnections() { } } -func (c *externalCredential) pollUsage(ctx context.Context) { - if !c.pollAccess.TryLock() { - return +func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Response, error) { + buildRequest := func(baseURL string) func() (*http.Request, error) { + return func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ccm/v1/status", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+c.token) + return request, nil + } } - defer c.pollAccess.Unlock() - defer c.markUsagePollAttempted() - - activeBaseURL := c.baseURL - activeTransport := c.httpClient.Transport + // Try reverse transport first (single attempt, no retry) if c.reverseHttpClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { - activeBaseURL = reverseProxyBaseURL - activeTransport = c.reverseHttpClient.Transport + request, err := buildRequest(reverseProxyBaseURL)() + if err != nil { + return nil, err + } + reverseClient := &http.Client{ + Transport: c.reverseHttpClient.Transport, + Timeout: 5 * time.Second, + } + response, err := reverseClient.Do(request) + if err == nil { + return response, nil + } + // Reverse failed, fall through to forward if available } } - statusURL := activeBaseURL + "/ccm/v1/status" - httpClient := &http.Client{ - Transport: activeTransport, - Timeout: 5 * time.Second, + // Forward transport with retries + if c.httpClient != nil { + forwardClient := &http.Client{ + Transport: c.httpClient.Transport, + Timeout: 5 * time.Second, + } + return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL)) } + return nil, E.New("no transport available") +} - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) - if err != nil { - return nil, err - } - request.Header.Set("Authorization", "Bearer "+c.token) - return request, nil - }) +func (c *externalCredential) pollUsage(ctx context.Context) { + if !c.pollAccess.TryLock() { + return + } + defer c.pollAccess.Unlock() + defer c.markUsagePollAttempted() + + response, err := c.doPollUsageRequest(ctx) if err != nil { if !c.isPollBackoffAtCap() { c.logger.Error("poll usage for ", c.tag, ": ", err) diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 0b60cff1eb..fbe8f11c60 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -211,6 +211,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx Time: ntp.TimeFuncFromContext(ctx), } } + cred.httpClient = &http.Client{Transport: transport} } else { // Normal mode: standard HTTP client for proxying cred.credDialer = credentialDialer @@ -537,36 +538,55 @@ func (c *externalCredential) interruptConnections() { } } -func (c *externalCredential) pollUsage(ctx context.Context) { - if !c.pollAccess.TryLock() { - return +func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Response, error) { + buildRequest := func(baseURL string) func() (*http.Request, error) { + return func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ocm/v1/status", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+c.token) + return request, nil + } } - defer c.pollAccess.Unlock() - defer c.markUsagePollAttempted() - - activeBaseURL := c.baseURL - activeTransport := c.httpClient.Transport + // Try reverse transport first (single attempt, no retry) if c.reverseHttpClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { - activeBaseURL = reverseProxyBaseURL - activeTransport = c.reverseHttpClient.Transport + request, err := buildRequest(reverseProxyBaseURL)() + if err != nil { + return nil, err + } + reverseClient := &http.Client{ + Transport: c.reverseHttpClient.Transport, + Timeout: 5 * time.Second, + } + response, err := reverseClient.Do(request) + if err == nil { + return response, nil + } + // Reverse failed, fall through to forward if available } } - statusURL := activeBaseURL + "/ocm/v1/status" - httpClient := &http.Client{ - Transport: activeTransport, - Timeout: 5 * time.Second, + // Forward transport with retries + if c.httpClient != nil { + forwardClient := &http.Client{ + Transport: c.httpClient.Transport, + Timeout: 5 * time.Second, + } + return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL)) } + return nil, E.New("no transport available") +} - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) - if err != nil { - return nil, err - } - request.Header.Set("Authorization", "Bearer "+c.token) - return request, nil - }) +func (c *externalCredential) pollUsage(ctx context.Context) { + if !c.pollAccess.TryLock() { + return + } + defer c.pollAccess.Unlock() + defer c.markUsagePollAttempted() + + response, err := c.doPollUsageRequest(ctx) if err != nil { if !c.isPollBackoffAtCap() { c.logger.Error("poll usage for ", c.tag, ": ", err) From 51d564c9ff92aa6c86f3da5d8b29e8b892dcf4d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 19:08:32 +0800 Subject: [PATCH 44/96] ccm,ocm: merge fallback into balancer strategy, use hyphenated constant names Merge the fallback credential type into balancer as a strategy (C.BalancerStrategyFallback). Replace raw string literals with C.BalancerStrategyXxx constants and switch to hyphens (least-used, round-robin) per project convention. --- constant/proxy.go | 7 ++ option/ccm.go | 10 --- option/ocm.go | 10 --- service/ccm/credential_external.go | 6 +- service/ccm/credential_state.go | 111 +++++++++----------------- service/ocm/credential_external.go | 8 +- service/ocm/credential_state.go | 121 ++++++++++------------------- 7 files changed, 88 insertions(+), 185 deletions(-) diff --git a/constant/proxy.go b/constant/proxy.go index 278a46c2f6..d46fc0f925 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -38,6 +38,13 @@ const ( TypeURLTest = "urltest" ) +const ( + BalancerStrategyLeastUsed = "least-used" + BalancerStrategyRoundRobin = "round-robin" + BalancerStrategyRandom = "random" + BalancerStrategyFallback = "fallback" +) + func ProxyDisplayName(proxyType string) string { switch proxyType { case TypeTun: diff --git a/option/ccm.go b/option/ccm.go index 6d9d14423c..7a4f0709f3 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -32,7 +32,6 @@ type _CCMCredential struct { DefaultOptions CCMDefaultCredentialOptions `json:"-"` ExternalOptions CCMExternalCredentialOptions `json:"-"` BalancerOptions CCMBalancerCredentialOptions `json:"-"` - FallbackOptions CCMFallbackCredentialOptions `json:"-"` } type CCMCredential _CCMCredential @@ -47,8 +46,6 @@ func (c CCMCredential) MarshalJSON() ([]byte, error) { v = c.ExternalOptions case "balancer": v = c.BalancerOptions - case "fallback": - v = c.FallbackOptions default: return nil, E.New("unknown credential type: ", c.Type) } @@ -72,8 +69,6 @@ func (c *CCMCredential) UnmarshalJSON(bytes []byte) error { v = &c.ExternalOptions case "balancer": v = &c.BalancerOptions - case "fallback": - v = &c.FallbackOptions default: return E.New("unknown credential type: ", c.Type) } @@ -106,8 +101,3 @@ type CCMExternalCredentialOptions struct { UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` } - -type CCMFallbackCredentialOptions struct { - Credentials badoption.Listable[string] `json:"credentials"` - PollInterval badoption.Duration `json:"poll_interval,omitempty"` -} diff --git a/option/ocm.go b/option/ocm.go index 6015b506cd..af937560b8 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -32,7 +32,6 @@ type _OCMCredential struct { DefaultOptions OCMDefaultCredentialOptions `json:"-"` ExternalOptions OCMExternalCredentialOptions `json:"-"` BalancerOptions OCMBalancerCredentialOptions `json:"-"` - FallbackOptions OCMFallbackCredentialOptions `json:"-"` } type OCMCredential _OCMCredential @@ -47,8 +46,6 @@ func (c OCMCredential) MarshalJSON() ([]byte, error) { v = c.ExternalOptions case "balancer": v = c.BalancerOptions - case "fallback": - v = c.FallbackOptions default: return nil, E.New("unknown credential type: ", c.Type) } @@ -72,8 +69,6 @@ func (c *OCMCredential) UnmarshalJSON(bytes []byte) error { v = &c.ExternalOptions case "balancer": v = &c.BalancerOptions - case "fallback": - v = &c.FallbackOptions default: return E.New("unknown credential type: ", c.Type) } @@ -106,8 +101,3 @@ type OCMExternalCredentialOptions struct { UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` } - -type OCMFallbackCredentialOptions struct { - Credentials badoption.Listable[string] `json:"credentials"` - PollInterval badoption.Duration `json:"poll_interval,omitempty"` -} diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 97da6bd9ec..b7e04bad34 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -47,9 +47,9 @@ type externalCredential struct { requestAccess sync.Mutex // Reverse proxy fields - reverse bool - reverseHttpClient *http.Client - reverseSession *yamux.Session + reverse bool + reverseHttpClient *http.Client + reverseSession *yamux.Session reverseAccess sync.RWMutex closed bool reverseContext context.Context diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 81d559f3c5..b07529eb0c 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -19,6 +19,7 @@ import ( "github.com/sagernet/fswatch" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" @@ -1015,6 +1016,14 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval } func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if p.strategy == C.BalancerStrategyFallback { + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allCredentialsUnavailableError(p.credentials) + } + return best, false, nil + } + selectionScope := selection.scopeOrDefault() if sessionID != "" { p.sessionMutex.RLock() @@ -1024,7 +1033,7 @@ func (p *balancerProvider) selectCredential(sessionID string, selection credenti if entry.selectionScope == selectionScope { for _, cred := range p.credentials { if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { better := p.pickLeastUsed(selection.filter) if better != nil && better.tagName() != cred.tagName() { effectiveThreshold := p.rebalanceThreshold / cred.planWeight() @@ -1086,6 +1095,9 @@ func (p *balancerProvider) rebalanceCredential(tag string, selectionScope creden } func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { + if p.strategy == C.BalancerStrategyFallback { + return func() bool { return false } + } key := credentialInterruptKey{ tag: cred.tagName(), selectionScope: selection.scopeOrDefault(), @@ -1103,6 +1115,9 @@ func (p *balancerProvider) linkProviderInterrupt(cred credential, selection cred func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { cred.markRateLimited(resetAt) + if p.strategy == C.BalancerStrategyFallback { + return p.pickCredential(selection.filter) + } if sessionID != "" { p.sessionMutex.Lock() delete(p.sessions, sessionID) @@ -1124,15 +1139,29 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { switch p.strategy { - case "round_robin": + case C.BalancerStrategyRoundRobin: return p.pickRoundRobin(filter) - case "random": + case C.BalancerStrategyRandom: return p.pickRandom(filter) + case C.BalancerStrategyFallback: + return p.pickFallback(filter) default: return p.pickLeastUsed(filter) } } +func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if cred.isUsable() { + return cred + } + } + return nil +} + func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { var best credential bestScore := float64(-1) @@ -1239,69 +1268,6 @@ func (p *balancerProvider) allCredentials() []credential { func (p *balancerProvider) close() {} -// fallbackProvider tries credentials in order. -type fallbackProvider struct { - credentials []credential - pollInterval time.Duration - logger log.ContextLogger -} - -func newFallbackProvider(credentials []credential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider { - if pollInterval <= 0 { - pollInterval = defaultPollInterval - } - return &fallbackProvider{ - credentials: credentials, - pollInterval: pollInterval, - logger: logger, - } -} - -func (p *fallbackProvider) selectCredential(_ string, selection credentialSelection) (credential, bool, error) { - for _, cred := range p.credentials { - if !selection.allows(cred) { - continue - } - if cred.isUsable() { - return cred, false, nil - } - } - return nil, false, allCredentialsUnavailableError(p.credentials) -} - -func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, selection credentialSelection) credential { - cred.markRateLimited(resetAt) - for _, candidate := range p.credentials { - if !selection.allows(candidate) { - continue - } - if candidate.isUsable() { - return candidate - } - } - return nil -} - -func (p *fallbackProvider) pollIfStale(ctx context.Context) { - for _, cred := range p.credentials { - if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { - cred.pollUsage(ctx) - } - } -} - -func (p *fallbackProvider) allCredentials() []credential { - return p.credentials -} - -func (p *fallbackProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { - return func() bool { - return false - } -} - -func (p *fallbackProvider) close() {} - func allCredentialsUnavailableError(credentials []credential) error { var hasUnavailable bool var earliest time.Time @@ -1373,21 +1339,14 @@ func buildCredentialProviders( } } - // Pass 2: create balancer and fallback providers + // Pass 2: create balancer providers for _, credOpt := range options.Credentials { - switch credOpt.Type { - case "balancer": + if credOpt.Type == "balancer" { subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) if err != nil { return nil, nil, err } providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) - case "fallback": - subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, allCredentialMap, credOpt.Tag) - if err != nil { - return nil, nil, err - } - providers[credOpt.Tag] = newFallbackProvider(subCredentials, time.Duration(credOpt.FallbackOptions.PollInterval), logger) } } @@ -1476,7 +1435,7 @@ func validateCCMOptions(options option.CCMServiceOptions) error { } if cred.Type == "balancer" { switch cred.BalancerOptions.Strategy { - case "", "least_used", "round_robin", "random": + case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: default: return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index fbe8f11c60..0e0556be71 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -49,10 +49,10 @@ type externalCredential struct { requestAccess sync.Mutex // Reverse proxy fields - reverse bool - reverseHttpClient *http.Client - reverseCredDialer N.Dialer - reverseSession *yamux.Session + reverse bool + reverseHttpClient *http.Client + reverseCredDialer N.Dialer + reverseSession *yamux.Session reverseAccess sync.RWMutex closed bool reverseContext context.Context diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 58b7344654..181132f09d 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -18,6 +18,7 @@ import ( "github.com/sagernet/fswatch" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" @@ -1014,6 +1015,14 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval } func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if p.strategy == C.BalancerStrategyFallback { + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allRateLimitedError(p.credentials) + } + return best, false, nil + } + selectionScope := selection.scopeOrDefault() if sessionID != "" { p.sessionMutex.RLock() @@ -1023,7 +1032,7 @@ func (p *balancerProvider) selectCredential(sessionID string, selection credenti if entry.selectionScope == selectionScope { for _, cred := range p.credentials { if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { better := p.pickLeastUsed(selection.filter) if better != nil && better.tagName() != cred.tagName() { effectiveThreshold := p.rebalanceThreshold / cred.planWeight() @@ -1085,6 +1094,9 @@ func (p *balancerProvider) rebalanceCredential(tag string, selectionScope creden } func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { + if p.strategy == C.BalancerStrategyFallback { + return func() bool { return false } + } key := credentialInterruptKey{ tag: cred.tagName(), selectionScope: selection.scopeOrDefault(), @@ -1102,6 +1114,9 @@ func (p *balancerProvider) linkProviderInterrupt(cred credential, selection cred func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { cred.markRateLimited(resetAt) + if p.strategy == C.BalancerStrategyFallback { + return p.pickCredential(selection.filter) + } if sessionID != "" { p.sessionMutex.Lock() delete(p.sessions, sessionID) @@ -1123,15 +1138,32 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { switch p.strategy { - case "round_robin": + case C.BalancerStrategyRoundRobin: return p.pickRoundRobin(filter) - case "random": + case C.BalancerStrategyRandom: return p.pickRandom(filter) + case C.BalancerStrategyFallback: + return p.pickFallback(filter) default: return p.pickLeastUsed(filter) } } +func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if !compositeCredentialSelectable(cred) { + continue + } + if cred.isUsable() { + return cred + } + } + return nil +} + func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { var best credential bestScore := float64(-1) @@ -1237,74 +1269,6 @@ func (p *balancerProvider) allCredentials() []credential { func (p *balancerProvider) close() {} -type fallbackProvider struct { - credentials []credential - pollInterval time.Duration - logger log.ContextLogger -} - -func newFallbackProvider(credentials []credential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider { - if pollInterval <= 0 { - pollInterval = defaultPollInterval - } - return &fallbackProvider{ - credentials: credentials, - pollInterval: pollInterval, - logger: logger, - } -} - -func (p *fallbackProvider) selectCredential(_ string, selection credentialSelection) (credential, bool, error) { - for _, cred := range p.credentials { - if !selection.allows(cred) { - continue - } - if !compositeCredentialSelectable(cred) { - continue - } - if cred.isUsable() { - return cred, false, nil - } - } - return nil, false, allRateLimitedError(p.credentials) -} - -func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, selection credentialSelection) credential { - cred.markRateLimited(resetAt) - for _, candidate := range p.credentials { - if !selection.allows(candidate) { - continue - } - if !compositeCredentialSelectable(candidate) { - continue - } - if candidate.isUsable() { - return candidate - } - } - return nil -} - -func (p *fallbackProvider) pollIfStale(ctx context.Context) { - for _, cred := range p.credentials { - if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { - cred.pollUsage(ctx) - } - } -} - -func (p *fallbackProvider) allCredentials() []credential { - return p.credentials -} - -func (p *fallbackProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { - return func() bool { - return false - } -} - -func (p *fallbackProvider) close() {} - func allRateLimitedError(credentials []credential) error { var hasUnavailable bool var earliest time.Time @@ -1358,21 +1322,14 @@ func buildOCMCredentialProviders( } } - // Pass 2: create balancer and fallback providers + // Pass 2: create balancer providers for _, credOpt := range options.Credentials { - switch credOpt.Type { - case "balancer": + if credOpt.Type == "balancer" { subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) if err != nil { return nil, nil, err } providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) - case "fallback": - subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, allCredentialMap, credOpt.Tag) - if err != nil { - return nil, nil, err - } - providers[credOpt.Tag] = newFallbackProvider(subCredentials, time.Duration(credOpt.FallbackOptions.PollInterval), logger) } } @@ -1469,7 +1426,7 @@ func validateOCMOptions(options option.OCMServiceOptions) error { } if cred.Type == "balancer" { switch cred.BalancerOptions.Strategy { - case "", "least_used", "round_robin", "random": + case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: default: return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) } @@ -1505,7 +1462,7 @@ func validateOCMCompositeCredentialModes( providers map[string]credentialProvider, ) error { for _, credOpt := range options.Credentials { - if credOpt.Type != "balancer" && credOpt.Type != "fallback" { + if credOpt.Type != "balancer" { continue } From 04bd63b45573eb4f81153a3158dc9f6de8f50155 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 20:17:23 +0800 Subject: [PATCH 45/96] ccm,ocm: reorganize files and improve naming conventions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split credential_state.go (1500+ lines) into credential.go, credential_default.go, credential_provider.go, credential_builder.go. Split service.go (900+ lines) into service.go, service_handler.go, service_status.go. Rename credential.go to credential_oauth.go to avoid name conflict with the credential interface. Apply naming fixes: accessMutex→access, stateMutex→stateAccess, sessionMutex→sessionAccess, webSocketMutex→webSocketAccess, httpTransport()→httpClient(), httpClient field→forwardHTTPClient, weeklyWindowDuration→weeklyWindowHours. --- service/ccm/credential.go | 313 +++--- service/ccm/credential_builder.go | 192 ++++ service/ccm/credential_default.go | 726 +++++++++++++ service/ccm/credential_external.go | 104 +- service/ccm/credential_file.go | 28 +- service/ccm/credential_oauth.go | 224 ++++ service/ccm/credential_provider.go | 405 ++++++++ service/ccm/credential_state.go | 1506 --------------------------- service/ccm/service.go | 570 +---------- service/ccm/service_handler.go | 499 +++++++++ service/ccm/service_status.go | 109 ++ service/ccm/service_user.go | 10 +- service/ocm/credential.go | 325 +++--- service/ocm/credential_builder.go | 223 ++++ service/ocm/credential_default.go | 749 ++++++++++++++ service/ocm/credential_external.go | 106 +- service/ocm/credential_file.go | 28 +- service/ocm/credential_oauth.go | 225 ++++ service/ocm/credential_provider.go | 411 ++++++++ service/ocm/credential_state.go | 1524 ---------------------------- service/ocm/service.go | 658 +----------- service/ocm/service_handler.go | 504 +++++++++ service/ocm/service_status.go | 114 +++ service/ocm/service_user.go | 10 +- 24 files changed, 4832 insertions(+), 4731 deletions(-) create mode 100644 service/ccm/credential_builder.go create mode 100644 service/ccm/credential_default.go create mode 100644 service/ccm/credential_oauth.go create mode 100644 service/ccm/credential_provider.go delete mode 100644 service/ccm/credential_state.go create mode 100644 service/ccm/service_handler.go create mode 100644 service/ccm/service_status.go create mode 100644 service/ocm/credential_builder.go create mode 100644 service/ocm/credential_default.go create mode 100644 service/ocm/credential_oauth.go create mode 100644 service/ocm/credential_provider.go delete mode 100644 service/ocm/credential_state.go create mode 100644 service/ocm/service_handler.go create mode 100644 service/ocm/service_status.go diff --git a/service/ccm/credential.go b/service/ccm/credential.go index da559c173d..8589676a8c 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -1,224 +1,187 @@ package ccm import ( - "bytes" "context" - "encoding/json" - "io" "net/http" - "os" - "os/user" - "path/filepath" - "runtime" - "slices" + "strconv" "sync" "time" - - "github.com/sagernet/sing-box/log" - E "github.com/sagernet/sing/common/exceptions" ) const ( - oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" - oauth2TokenURL = "https://platform.claude.com/v1/oauth/token" - claudeAPIBaseURL = "https://api.anthropic.com" - tokenRefreshBufferMs = 60000 - anthropicBetaOAuthValue = "oauth-2025-04-20" + defaultPollInterval = 60 * time.Minute + failedPollRetryInterval = time.Minute + httpRetryMaxBackoff = 5 * time.Minute ) -const ccmUserAgentFallback = "claude-code/2.1.72" - -var ( - ccmUserAgentOnce sync.Once - ccmUserAgentValue string +const ( + httpRetryMaxAttempts = 3 + httpRetryInitialDelay = 200 * time.Millisecond ) -func initCCMUserAgent(logger log.ContextLogger) { - ccmUserAgentOnce.Do(func() { - version, err := detectClaudeCodeVersion() +const sessionExpiry = 24 * time.Hour + +func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { + var lastError error + for attempt := range httpRetryMaxAttempts { + if attempt > 0 { + delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) + select { + case <-ctx.Done(): + return nil, lastError + case <-time.After(delay): + } + } + request, err := buildRequest() if err != nil { - logger.Error("detect Claude Code version: ", err) - ccmUserAgentValue = ccmUserAgentFallback - return + return nil, err } - logger.Debug("detected Claude Code version: ", version) - ccmUserAgentValue = "claude-code/" + version - }) -} - -func detectClaudeCodeVersion() (string, error) { - userInfo, err := getRealUser() - if err != nil { - return "", E.Cause(err, "get user") - } - binaryName := "claude" - if runtime.GOOS == "windows" { - binaryName = "claude.exe" - } - linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName) - target, err := os.Readlink(linkPath) - if err != nil { - return "", E.Cause(err, "readlink ", linkPath) - } - if !filepath.IsAbs(target) { - target = filepath.Join(filepath.Dir(linkPath), target) - } - parent := filepath.Base(filepath.Dir(target)) - if parent != "versions" { - return "", E.New("unexpected symlink target: ", target) - } - return filepath.Base(target), nil -} - -func getRealUser() (*user.User, error) { - if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { - sudoUserInfo, err := user.Lookup(sudoUser) + response, err := client.Do(request) if err == nil { - return sudoUserInfo, nil + return response, nil + } + lastError = err + if ctx.Err() != nil { + return nil, lastError } } - return user.Current() + return nil, lastError } -func getDefaultCredentialsPath() (string, error) { - if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" { - return filepath.Join(configDir, ".credentials.json"), nil - } - userInfo, err := getRealUser() - if err != nil { - return "", err - } - return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil +type credentialState struct { + fiveHourUtilization float64 + fiveHourReset time.Time + weeklyUtilization float64 + weeklyReset time.Time + hardRateLimited bool + rateLimitResetAt time.Time + accountType string + rateLimitTier string + remotePlanWeight float64 + lastUpdated time.Time + consecutivePollFailures int + unavailable bool + lastCredentialLoadAttempt time.Time + lastCredentialLoadError string } -func readCredentialsFromFile(path string) (*oauthCredentials, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - var credentialsContainer struct { - ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"` - } - err = json.Unmarshal(data, &credentialsContainer) - if err != nil { - return nil, err - } - if credentialsContainer.ClaudeAIAuth == nil { - return nil, E.New("claudeAiOauth field not found in credentials") - } - return credentialsContainer.ClaudeAIAuth, nil +type credentialRequestContext struct { + context.Context + releaseOnce sync.Once + cancelOnce sync.Once + releaseFuncs []func() bool + cancelFunc context.CancelFunc } -func checkCredentialFileWritable(path string) error { - file, err := os.OpenFile(path, os.O_WRONLY, 0) - if err != nil { - return err - } - return file.Close() +func (c *credentialRequestContext) addInterruptLink(stop func() bool) { + c.releaseFuncs = append(c.releaseFuncs, stop) } -func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error { - data, err := json.MarshalIndent(map[string]any{ - "claudeAiOauth": oauthCredentials, - }, "", " ") - if err != nil { - return err - } - return os.WriteFile(path, data, 0o600) +func (c *credentialRequestContext) releaseCredentialInterrupt() { + c.releaseOnce.Do(func() { + for _, f := range c.releaseFuncs { + f() + } + }) } -type oauthCredentials struct { - AccessToken string `json:"accessToken"` - RefreshToken string `json:"refreshToken"` - ExpiresAt int64 `json:"expiresAt"` - Scopes []string `json:"scopes,omitempty"` - SubscriptionType string `json:"subscriptionType,omitempty"` - RateLimitTier string `json:"rateLimitTier,omitempty"` - IsMax bool `json:"isMax,omitempty"` +func (c *credentialRequestContext) cancelRequest() { + c.releaseCredentialInterrupt() + c.cancelOnce.Do(c.cancelFunc) } -func (c *oauthCredentials) needsRefresh() bool { - if c.ExpiresAt == 0 { - return false - } - return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs +type credential interface { + tagName() string + isAvailable() bool + isUsable() bool + isExternal() bool + fiveHourUtilization() float64 + weeklyUtilization() float64 + fiveHourCap() float64 + weeklyCap() float64 + planWeight() float64 + weeklyResetTime() time.Time + markRateLimited(resetAt time.Time) + earliestReset() time.Time + unavailableError() error + + getAccessToken() (string, error) + buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) + updateStateFromHeaders(header http.Header) + + wrapRequestContext(ctx context.Context) *credentialRequestContext + interruptConnections() + + start() error + pollUsage(ctx context.Context) + lastUpdatedTime() time.Time + pollBackoff(base time.Duration) time.Duration + usageTrackerOrNil() *AggregatedUsage + httpClient() *http.Client + close() } -func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { - if credentials.RefreshToken == "" { - return nil, E.New("refresh token is empty") - } +type credentialSelectionScope string - requestBody, err := json.Marshal(map[string]string{ - "grant_type": "refresh_token", - "refresh_token": credentials.RefreshToken, - "client_id": oauth2ClientID, - }) - if err != nil { - return nil, E.Cause(err, "marshal request") - } +const ( + credentialSelectionScopeAll credentialSelectionScope = "all" + credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" +) - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) - if err != nil { - return nil, err - } - request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - return request, nil - }) - if err != nil { - return nil, err - } - defer response.Body.Close() +type credentialSelection struct { + scope credentialSelectionScope + filter func(credential) bool +} - if response.StatusCode == http.StatusTooManyRequests { - body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) - } - if response.StatusCode != http.StatusOK { - body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh failed: ", response.Status, " ", string(body)) - } +func (s credentialSelection) allows(cred credential) bool { + return s.filter == nil || s.filter(cred) +} - var tokenResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` +func (s credentialSelection) scopeOrDefault() credentialSelectionScope { + if s.scope == "" { + return credentialSelectionScopeAll } - err = json.NewDecoder(response.Body).Decode(&tokenResponse) + return s.scope +} + +// Claude Code's unified rate-limit handling parses these reset headers with +// Number(...), compares them against Date.now()/1000, and renders them via +// new Date(seconds*1000), so keep the wire format pinned to Unix epoch seconds. +func parseAnthropicResetHeaderValue(headerName string, headerValue string) time.Time { + unixEpoch, err := strconv.ParseInt(headerValue, 10, 64) if err != nil { - return nil, E.Cause(err, "decode response") + panic("invalid " + headerName + " header: expected Unix epoch seconds, got " + strconv.Quote(headerValue)) } - - newCredentials := *credentials - newCredentials.AccessToken = tokenResponse.AccessToken - if tokenResponse.RefreshToken != "" { - newCredentials.RefreshToken = tokenResponse.RefreshToken + if unixEpoch <= 0 { + panic("invalid " + headerName + " header: expected positive Unix epoch seconds, got " + strconv.Quote(headerValue)) } - newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000 + return time.Unix(unixEpoch, 0) +} - return &newCredentials, nil +func parseOptionalAnthropicResetHeader(headers http.Header, headerName string) (time.Time, bool) { + headerValue := headers.Get(headerName) + if headerValue == "" { + return time.Time{}, false + } + return parseAnthropicResetHeaderValue(headerName, headerValue), true } -func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { - if credentials == nil { - return nil +func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) time.Time { + headerValue := headers.Get(headerName) + if headerValue == "" { + panic("missing required " + headerName + " header") } - cloned := *credentials - cloned.Scopes = append([]string(nil), credentials.Scopes...) - return &cloned + return parseAnthropicResetHeaderValue(headerName, headerValue) } -func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { - if left == nil || right == nil { - return left == right +func parseRateLimitResetFromHeaders(headers http.Header) time.Time { + claim := headers.Get("anthropic-ratelimit-unified-representative-claim") + switch claim { + case "5h": + return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset") + case "7d": + return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset") + default: + panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim)) } - return left.AccessToken == right.AccessToken && - left.RefreshToken == right.RefreshToken && - left.ExpiresAt == right.ExpiresAt && - slices.Equal(left.Scopes, right.Scopes) && - left.SubscriptionType == right.SubscriptionType && - left.RateLimitTier == right.RateLimitTier && - left.IsMax == right.IsMax } diff --git a/service/ccm/credential_builder.go b/service/ccm/credential_builder.go new file mode 100644 index 0000000000..c49a201950 --- /dev/null +++ b/service/ccm/credential_builder.go @@ -0,0 +1,192 @@ +package ccm + +import ( + "context" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func buildCredentialProviders( + ctx context.Context, + options option.CCMServiceOptions, + logger log.ContextLogger, +) (map[string]credentialProvider, []credential, error) { + allCredentialMap := make(map[string]credential) + var allCreds []credential + providers := make(map[string]credentialProvider) + + // Pass 1: create default and external credentials + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "default": + cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + if err != nil { + return nil, nil, err + } + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + case "external": + cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) + if err != nil { + return nil, nil, err + } + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + } + } + + // Pass 2: create balancer providers + for _, credOpt := range options.Credentials { + if credOpt.Type == "balancer" { + subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) + } + } + + return providers, allCreds, nil +} + +func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { + credentials := make([]credential, 0, len(tags)) + for _, tag := range tags { + cred, exists := allCredentials[tag] + if !exists { + return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) + } + credentials = append(credentials, cred) + } + if len(credentials) == 0 { + return nil, E.New("credential ", parentTag, " has no sub-credentials") + } + return credentials, nil +} + +func validateCCMOptions(options option.CCMServiceOptions) error { + hasCredentials := len(options.Credentials) > 0 + hasLegacyPath := options.CredentialPath != "" + hasLegacyUsages := options.UsagesPath != "" + hasLegacyDetour := options.Detour != "" + + if hasCredentials && hasLegacyPath { + return E.New("credential_path and credentials are mutually exclusive") + } + if hasCredentials && hasLegacyUsages { + return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") + } + if hasCredentials && hasLegacyDetour { + return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") + } + + if hasCredentials { + tags := make(map[string]bool) + credentialTypes := make(map[string]string) + for _, cred := range options.Credentials { + if tags[cred.Tag] { + return E.New("duplicate credential tag: ", cred.Tag) + } + tags[cred.Tag] = true + credentialTypes[cred.Tag] = cred.Type + if cred.Type == "default" || cred.Type == "" { + if cred.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") + } + if cred.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") + } + if cred.DefaultOptions.Limit5h > 100 { + return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") + } + if cred.DefaultOptions.LimitWeekly > 100 { + return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") + } + if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { + return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") + } + if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { + return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") + } + } + if cred.Type == "external" { + if cred.ExternalOptions.Token == "" { + return E.New("credential ", cred.Tag, ": external credential requires token") + } + if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { + return E.New("credential ", cred.Tag, ": reverse external credential requires url") + } + } + if cred.Type == "balancer" { + switch cred.BalancerOptions.Strategy { + case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: + default: + return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) + } + if cred.BalancerOptions.RebalanceThreshold < 0 { + return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") + } + } + } + + for _, user := range options.Users { + if user.Credential == "" { + return E.New("user ", user.Name, " must specify credential in multi-credential mode") + } + if !tags[user.Credential] { + return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + } + if user.ExternalCredential != "" { + if !tags[user.ExternalCredential] { + return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) + } + if credentialTypes[user.ExternalCredential] != "external" { + return E.New("user ", user.Name, ": external_credential must reference an external type credential") + } + } + } + } + + return nil +} + +func credentialForUser( + userConfigMap map[string]*option.CCMUser, + providers map[string]credentialProvider, + legacyProvider credentialProvider, + username string, +) (credentialProvider, error) { + if legacyProvider != nil { + return legacyProvider, nil + } + userConfig, exists := userConfigMap[username] + if !exists { + return nil, E.New("no credential mapping for user: ", username) + } + provider, exists := providers[userConfig.Credential] + if !exists { + return nil, E.New("unknown credential: ", userConfig.Credential) + } + return provider, nil +} + +func noUserCredentialProvider( + providers map[string]credentialProvider, + legacyProvider credentialProvider, + options option.CCMServiceOptions, +) credentialProvider { + if legacyProvider != nil { + return legacyProvider + } + if len(options.Credentials) > 0 { + tag := options.Credentials[0].Tag + return providers[tag] + } + return nil +} diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go new file mode 100644 index 0000000000..c44ec41036 --- /dev/null +++ b/service/ccm/credential_default.go @@ -0,0 +1,726 @@ +package ccm + +import ( + "bytes" + "context" + stdTLS "crypto/tls" + "encoding/json" + "io" + "math" + "net" + "net/http" + "strconv" + "sync" + "time" + + "github.com/sagernet/fswatch" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/ntp" +) + +type defaultCredential struct { + tag string + serviceContext context.Context + credentialPath string + credentialFilePath string + credentials *oauthCredentials + access sync.RWMutex + state credentialState + stateAccess sync.RWMutex + pollAccess sync.Mutex + reloadAccess sync.Mutex + watcherAccess sync.Mutex + cap5h float64 + capWeekly float64 + usageTracker *AggregatedUsage + forwardHTTPClient *http.Client + logger log.ContextLogger + watcher *fswatch.Watcher + watcherRetryAt time.Time + + // Connection interruption + onBecameUnusable func() + interrupted bool + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex +} + +func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + reserve5h := options.Reserve5h + if reserve5h == 0 { + reserve5h = 1 + } + reserveWeekly := options.ReserveWeekly + if reserveWeekly == 0 { + reserveWeekly = 1 + } + var cap5h float64 + if options.Limit5h > 0 { + cap5h = float64(options.Limit5h) + } else { + cap5h = float64(100 - reserve5h) + } + var capWeekly float64 + if options.LimitWeekly > 0 { + capWeekly = float64(options.LimitWeekly) + } else { + capWeekly = float64(100 - reserveWeekly) + } + requestContext, cancelRequests := context.WithCancel(context.Background()) + credential := &defaultCredential{ + tag: tag, + serviceContext: ctx, + credentialPath: options.CredentialPath, + cap5h: cap5h, + capWeekly: capWeekly, + forwardHTTPClient: httpClient, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + } + if options.UsagesPath != "" { + credential.usageTracker = &AggregatedUsage{ + LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + return credential, nil +} + +func (c *defaultCredential) start() error { + credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) + if err != nil { + return E.Cause(err, "resolve credential path for ", c.tag) + } + c.credentialFilePath = credentialFilePath + err = c.ensureCredentialWatcher() + if err != nil { + c.logger.Debug("start credential watcher for ", c.tag, ": ", err) + } + err = c.reloadCredentials(true) + if err != nil { + c.logger.Warn("initial credential load for ", c.tag, ": ", err) + } + if c.usageTracker != nil { + err = c.usageTracker.Load() + if err != nil { + c.logger.Warn("load usage statistics for ", c.tag, ": ", err) + } + } + return nil +} + +func (c *defaultCredential) getAccessToken() (string, error) { + c.retryCredentialReloadIfNeeded() + + c.access.RLock() + if c.credentials != nil && !c.credentials.needsRefresh() { + token := c.credentials.AccessToken + c.access.RUnlock() + return token, nil + } + c.access.RUnlock() + + err := c.reloadCredentials(true) + if err == nil { + c.access.RLock() + if c.credentials != nil && !c.credentials.needsRefresh() { + token := c.credentials.AccessToken + c.access.RUnlock() + return token, nil + } + c.access.RUnlock() + } + + c.access.Lock() + defer c.access.Unlock() + + if c.credentials == nil { + return "", c.unavailableError() + } + if !c.credentials.needsRefresh() { + return c.credentials.AccessToken, nil + } + + err = platformCanWriteCredentials(c.credentialPath) + if err != nil { + return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") + } + + baseCredentials := cloneCredentials(c.credentials) + newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials) + if err != nil { + return "", err + } + + latestCredentials, latestErr := platformReadCredentials(c.credentialPath) + if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { + c.credentials = latestCredentials + c.stateAccess.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.state.accountType = latestCredentials.SubscriptionType + c.state.rateLimitTier = latestCredentials.RateLimitTier + c.checkTransitionLocked() + c.stateAccess.Unlock() + if !latestCredentials.needsRefresh() { + return latestCredentials.AccessToken, nil + } + return "", E.New("credential ", c.tag, " changed while refreshing") + } + + c.credentials = newCredentials + c.stateAccess.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.state.accountType = newCredentials.SubscriptionType + c.state.rateLimitTier = newCredentials.RateLimitTier + c.checkTransitionLocked() + c.stateAccess.Unlock() + + err = platformWriteCredentials(newCredentials, c.credentialPath) + if err != nil { + c.logger.Error("persist refreshed token for ", c.tag, ": ", err) + } + + return newCredentials.AccessToken, nil +} + +func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { + c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + hadData := false + + fiveHourResetChanged := false + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { + hadData = true + if value.After(c.state.fiveHourReset) { + fiveHourResetChanged = true + c.state.fiveHourReset = value + } + } + if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { + value, err := strconv.ParseFloat(utilization, 64) + if err == nil { + hadData = true + newValue := math.Ceil(value * 100) + if newValue >= c.state.fiveHourUtilization || fiveHourResetChanged { + c.state.fiveHourUtilization = newValue + } + } + } + + weeklyResetChanged := false + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { + hadData = true + if value.After(c.state.weeklyReset) { + weeklyResetChanged = true + c.state.weeklyReset = value + } + } + if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { + value, err := strconv.ParseFloat(utilization, 64) + if err == nil { + hadData = true + newValue := math.Ceil(value * 100) + if newValue >= c.state.weeklyUtilization || weeklyResetChanged { + c.state.weeklyUtilization = newValue + } + } + } + if hadData { + c.state.consecutivePollFailures = 0 + c.state.lastUpdated = time.Now() + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) + } + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) markRateLimited(resetAt time.Time) { + c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) + c.stateAccess.Lock() + c.state.hardRateLimited = true + c.state.rateLimitResetAt = resetAt + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) isUsable() bool { + c.retryCredentialReloadIfNeeded() + + c.stateAccess.RLock() + if c.state.unavailable { + c.stateAccess.RUnlock() + return false + } + if c.state.consecutivePollFailures > 0 { + c.stateAccess.RUnlock() + return false + } + if c.state.hardRateLimited { + if time.Now().Before(c.state.rateLimitResetAt) { + c.stateAccess.RUnlock() + return false + } + c.stateAccess.RUnlock() + c.stateAccess.Lock() + if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + usable := c.checkReservesLocked() + c.stateAccess.Unlock() + return usable + } + usable := c.checkReservesLocked() + c.stateAccess.RUnlock() + return usable +} + +func (c *defaultCredential) checkReservesLocked() bool { + if c.state.fiveHourUtilization >= c.cap5h { + return false + } + if c.state.weeklyUtilization >= c.capWeekly { + return false + } + return true +} + +// checkTransitionLocked detects usable→unusable transition. +// Must be called with stateAccess write lock held. +func (c *defaultCredential) checkTransitionLocked() bool { + unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0 + if unusable && !c.interrupted { + c.interrupted = true + return true + } + if !unusable && c.interrupted { + c.interrupted = false + } + return false +} + +func (c *defaultCredential) interruptConnections() { + c.logger.Warn("interrupting connections for ", c.tag) + c.requestAccess.Lock() + c.cancelRequests() + c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) + c.requestAccess.Unlock() + if c.onBecameUnusable != nil { + c.onBecameUnusable() + } +} + +func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { + c.requestAccess.Lock() + credentialContext := c.requestContext + c.requestAccess.Unlock() + derived, cancel := context.WithCancel(parent) + stop := context.AfterFunc(credentialContext, func() { + cancel() + }) + return &credentialRequestContext{ + Context: derived, + releaseFuncs: []func() bool{stop}, + cancelFunc: cancel, + } +} + +func (c *defaultCredential) weeklyUtilization() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.weeklyUtilization +} + +func (c *defaultCredential) planWeight() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) +} + +func (c *defaultCredential) weeklyResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.weeklyReset +} + +func (c *defaultCredential) isAvailable() bool { + c.retryCredentialReloadIfNeeded() + + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return !c.state.unavailable +} + +func (c *defaultCredential) unavailableError() error { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if !c.state.unavailable { + return nil + } + if c.state.lastCredentialLoadError == "" { + return E.New("credential ", c.tag, " is unavailable") + } + return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError) +} + +func (c *defaultCredential) lastUpdatedTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.lastUpdated +} + +func (c *defaultCredential) markUsagePollAttempted() { + c.stateAccess.Lock() + defer c.stateAccess.Unlock() + c.state.lastUpdated = time.Now() +} + +func (c *defaultCredential) incrementPollFailures() { + c.stateAccess.Lock() + c.state.consecutivePollFailures++ + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { + c.stateAccess.RLock() + failures := c.state.consecutivePollFailures + c.stateAccess.RUnlock() + if failures <= 0 { + return baseInterval + } + backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) + if backoff > httpRetryMaxBackoff { + return httpRetryMaxBackoff + } + return backoff +} + +func (c *defaultCredential) isPollBackoffAtCap() bool { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + failures := c.state.consecutivePollFailures + return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff +} + +func (c *defaultCredential) earliestReset() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if c.state.unavailable { + return time.Time{} + } + if c.state.hardRateLimited { + return c.state.rateLimitResetAt + } + earliest := c.state.fiveHourReset + if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) { + earliest = c.state.weeklyReset + } + return earliest +} + +func (c *defaultCredential) pollUsage(ctx context.Context) { + if !c.pollAccess.TryLock() { + return + } + defer c.pollAccess.Unlock() + defer c.markUsagePollAttempted() + + c.retryCredentialReloadIfNeeded() + if !c.isAvailable() { + return + } + + accessToken, err := c.getAccessToken() + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": get token: ", err) + } + c.incrementPollFailures() + return + } + + httpClient := &http.Client{ + Transport: c.forwardHTTPClient.Transport, + Timeout: 5 * time.Second, + } + + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+accessToken) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + request.Header.Set("anthropic-beta", anthropicBetaOAuthValue) + return request, nil + }) + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": ", err) + } + c.incrementPollFailures() + return + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + if response.StatusCode == http.StatusTooManyRequests { + c.logger.Warn("poll usage for ", c.tag, ": rate limited") + } + body, _ := io.ReadAll(response.Body) + c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + c.incrementPollFailures() + return + } + + var usageResponse struct { + FiveHour struct { + Utilization float64 `json:"utilization"` + ResetsAt time.Time `json:"resets_at"` + } `json:"five_hour"` + SevenDay struct { + Utilization float64 `json:"utilization"` + ResetsAt time.Time `json:"resets_at"` + } `json:"seven_day"` + } + err = json.NewDecoder(response.Body).Decode(&usageResponse) + if err != nil { + c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + c.incrementPollFailures() + return + } + + c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + c.state.consecutivePollFailures = 0 + c.state.fiveHourUtilization = usageResponse.FiveHour.Utilization + if !usageResponse.FiveHour.ResetsAt.IsZero() { + c.state.fiveHourReset = usageResponse.FiveHour.ResetsAt + } + c.state.weeklyUtilization = usageResponse.SevenDay.Utilization + if !usageResponse.SevenDay.ResetsAt.IsZero() { + c.state.weeklyReset = usageResponse.SevenDay.ResetsAt + } + if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) + } + needsProfileFetch := c.state.rateLimitTier == "" + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + + if needsProfileFetch { + c.fetchProfile(ctx, httpClient, accessToken) + } +} + +func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.Client, accessToken string) { + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+accessToken) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + return request, nil + }) + if err != nil { + c.logger.Debug("fetch profile for ", c.tag, ": ", err) + return + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return + } + + var profileResponse struct { + Organization *struct { + OrganizationType string `json:"organization_type"` + RateLimitTier string `json:"rate_limit_tier"` + } `json:"organization"` + } + err = json.NewDecoder(response.Body).Decode(&profileResponse) + if err != nil || profileResponse.Organization == nil { + return + } + + accountType := "" + switch profileResponse.Organization.OrganizationType { + case "claude_pro": + accountType = "pro" + case "claude_max": + accountType = "max" + case "claude_team": + accountType = "team" + case "claude_enterprise": + accountType = "enterprise" + } + rateLimitTier := profileResponse.Organization.RateLimitTier + + c.stateAccess.Lock() + if accountType != "" && c.state.accountType == "" { + c.state.accountType = accountType + } + if rateLimitTier != "" { + c.state.rateLimitTier = rateLimitTier + } + c.stateAccess.Unlock() + c.logger.Info("fetched profile for ", c.tag, ": type=", c.state.accountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(c.state.accountType, rateLimitTier)) +} + +func (c *defaultCredential) close() { + if c.watcher != nil { + err := c.watcher.Close() + if err != nil { + c.logger.Error("close credential watcher for ", c.tag, ": ", err) + } + } + if c.usageTracker != nil { + c.usageTracker.cancelPendingSave() + err := c.usageTracker.Save() + if err != nil { + c.logger.Error("save usage statistics for ", c.tag, ": ", err) + } + } +} + +func (c *defaultCredential) tagName() string { + return c.tag +} + +func (c *defaultCredential) isExternal() bool { + return false +} + +func (c *defaultCredential) fiveHourUtilization() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourUtilization +} + +func (c *defaultCredential) fiveHourCap() float64 { + return c.cap5h +} + +func (c *defaultCredential) weeklyCap() float64 { + return c.capWeekly +} + +func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage { + return c.usageTracker +} + +func (c *defaultCredential) httpClient() *http.Client { + return c.forwardHTTPClient +} + +func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) { + accessToken, err := c.getAccessToken() + if err != nil { + return nil, E.Cause(err, "get access token for ", c.tag) + } + + proxyURL := claudeAPIBaseURL + original.URL.RequestURI() + var body io.Reader + if bodyBytes != nil { + body = bytes.NewReader(bodyBytes) + } else { + body = original.Body + } + proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body) + if err != nil { + return nil, err + } + + for key, values := range original.Header { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { + proxyRequest.Header[key] = values + } + } + + serviceOverridesAcceptEncoding := len(serviceHeaders.Values("Accept-Encoding")) > 0 + if c.usageTracker != nil && !serviceOverridesAcceptEncoding { + proxyRequest.Header.Del("Accept-Encoding") + } + + anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta") + if anthropicBetaHeader != "" { + proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) + } else { + proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue) + } + + for key, values := range serviceHeaders { + proxyRequest.Header.Del(key) + proxyRequest.Header[key] = values + } + proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) + + return proxyRequest, nil +} diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index b7e04bad34..24ddf6c4a2 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -29,16 +29,16 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + forwardHTTPClient *http.Client + state credentialState + stateAccess sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -128,7 +128,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx if options.URL == "" { // Receiver mode: no URL, wait for reverse connection cred.baseURL = reverseProxyBaseURL - cred.httpClient = &http.Client{ + cred.forwardHTTPClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -192,10 +192,10 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx Time: ntp.TimeFuncFromContext(ctx), } } - cred.httpClient = &http.Client{Transport: transport} + cred.forwardHTTPClient = &http.Client{Transport: transport} } else { // Normal mode: standard HTTP client for proxying - cred.httpClient = &http.Client{Transport: transport} + cred.forwardHTTPClient = &http.Client{Transport: transport} cred.reverseHttpClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, @@ -248,40 +248,40 @@ func (c *externalCredential) isUsable() bool { if !c.isAvailable() { return false } - c.stateMutex.RLock() + c.stateAccess.RLock() if c.state.consecutivePollFailures > 0 { - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return false } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return false } - c.stateMutex.RUnlock() - c.stateMutex.Lock() + c.stateAccess.RUnlock() + c.stateAccess.Lock() if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } // No reserve for external: only 100% is unusable usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 - c.stateMutex.Unlock() + c.stateAccess.Unlock() return usable } usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return usable } func (c *externalCredential) fiveHourUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.fiveHourUtilization } func (c *externalCredential) weeklyUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.weeklyUtilization } @@ -294,8 +294,8 @@ func (c *externalCredential) weeklyCap() float64 { } func (c *externalCredential) planWeight() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() if c.state.remotePlanWeight > 0 { return c.state.remotePlanWeight } @@ -303,26 +303,26 @@ func (c *externalCredential) planWeight() float64 { } func (c *externalCredential) weeklyResetTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.weeklyReset } func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.hardRateLimited = true c.state.rateLimitResetAt = resetAt shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } } func (c *externalCredential) earliestReset() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() if c.state.hardRateLimited { return c.state.rateLimitResetAt } @@ -408,7 +408,7 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con } func (c *externalCredential) updateStateFromHeaders(headers http.Header) { - c.stateMutex.Lock() + c.stateAccess.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization @@ -455,7 +455,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } @@ -530,9 +530,9 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp } } // Forward transport with retries - if c.httpClient != nil { + if c.forwardHTTPClient != nil { forwardClient := &http.Client{ - Transport: c.httpClient.Transport, + Transport: c.forwardHTTPClient.Transport, Timeout: 5 * time.Second, } return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL)) @@ -563,10 +563,10 @@ func (c *externalCredential) pollUsage(ctx context.Context) { // 404 means the remote does not have a status endpoint yet; // usage will be updated passively from response headers. if response.StatusCode == http.StatusNotFound { - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.consecutivePollFailures = 0 c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() } else { c.incrementPollFailures() } @@ -585,7 +585,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { return } - c.stateMutex.Lock() + c.stateAccess.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization @@ -606,28 +606,28 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } } func (c *externalCredential) lastUpdatedTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.lastUpdated } func (c *externalCredential) markUsagePollAttempted() { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() + c.stateAccess.Lock() + defer c.stateAccess.Unlock() c.state.lastUpdated = time.Now() } func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration { - c.stateMutex.RLock() + c.stateAccess.RLock() failures := c.state.consecutivePollFailures - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if failures <= 0 { return baseInterval } @@ -639,17 +639,17 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati } func (c *externalCredential) isPollBackoffAtCap() bool { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() failures := c.state.consecutivePollFailures return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff } func (c *externalCredential) incrementPollFailures() { - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.consecutivePollFailures++ shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } @@ -659,14 +659,14 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { return c.usageTracker } -func (c *externalCredential) httpTransport() *http.Client { +func (c *externalCredential) httpClient() *http.Client { if c.reverseHttpClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { return c.reverseHttpClient } } - return c.httpClient + return c.forwardHTTPClient } func (c *externalCredential) close() { diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index eba9207268..72d9da0100 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -62,10 +62,10 @@ func (c *defaultCredential) ensureCredentialWatcher() error { } func (c *defaultCredential) retryCredentialReloadIfNeeded() { - c.stateMutex.RLock() + c.stateAccess.RLock() unavailable := c.state.unavailable lastAttempt := c.state.lastCredentialLoadAttempt - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if !unavailable { return } @@ -84,10 +84,10 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.reloadAccess.Lock() defer c.reloadAccess.Unlock() - c.stateMutex.RLock() + c.stateAccess.RLock() unavailable := c.state.unavailable lastAttempt := c.state.lastCredentialLoadAttempt - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if !force { if !unavailable { return nil @@ -97,43 +97,43 @@ func (c *defaultCredential) reloadCredentials(force bool) error { } } - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.lastCredentialLoadAttempt = time.Now() - c.stateMutex.Unlock() + c.stateAccess.Unlock() credentials, err := platformReadCredentials(c.credentialPath) if err != nil { return c.markCredentialsUnavailable(E.Cause(err, "read credentials")) } - c.accessMutex.Lock() + c.access.Lock() c.credentials = credentials - c.accessMutex.Unlock() + c.access.Unlock() - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.unavailable = false c.state.lastCredentialLoadError = "" c.state.accountType = credentials.SubscriptionType c.state.rateLimitTier = credentials.RateLimitTier c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() return nil } func (c *defaultCredential) markCredentialsUnavailable(err error) error { - c.accessMutex.Lock() + c.access.Lock() hadCredentials := c.credentials != nil c.credentials = nil - c.accessMutex.Unlock() + c.access.Unlock() - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.unavailable = true c.state.lastCredentialLoadError = err.Error() c.state.accountType = "" c.state.rateLimitTier = "" shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt && hadCredentials { c.interruptConnections() diff --git a/service/ccm/credential_oauth.go b/service/ccm/credential_oauth.go new file mode 100644 index 0000000000..da559c173d --- /dev/null +++ b/service/ccm/credential_oauth.go @@ -0,0 +1,224 @@ +package ccm + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "os" + "os/user" + "path/filepath" + "runtime" + "slices" + "sync" + "time" + + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + oauth2TokenURL = "https://platform.claude.com/v1/oauth/token" + claudeAPIBaseURL = "https://api.anthropic.com" + tokenRefreshBufferMs = 60000 + anthropicBetaOAuthValue = "oauth-2025-04-20" +) + +const ccmUserAgentFallback = "claude-code/2.1.72" + +var ( + ccmUserAgentOnce sync.Once + ccmUserAgentValue string +) + +func initCCMUserAgent(logger log.ContextLogger) { + ccmUserAgentOnce.Do(func() { + version, err := detectClaudeCodeVersion() + if err != nil { + logger.Error("detect Claude Code version: ", err) + ccmUserAgentValue = ccmUserAgentFallback + return + } + logger.Debug("detected Claude Code version: ", version) + ccmUserAgentValue = "claude-code/" + version + }) +} + +func detectClaudeCodeVersion() (string, error) { + userInfo, err := getRealUser() + if err != nil { + return "", E.Cause(err, "get user") + } + binaryName := "claude" + if runtime.GOOS == "windows" { + binaryName = "claude.exe" + } + linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName) + target, err := os.Readlink(linkPath) + if err != nil { + return "", E.Cause(err, "readlink ", linkPath) + } + if !filepath.IsAbs(target) { + target = filepath.Join(filepath.Dir(linkPath), target) + } + parent := filepath.Base(filepath.Dir(target)) + if parent != "versions" { + return "", E.New("unexpected symlink target: ", target) + } + return filepath.Base(target), nil +} + +func getRealUser() (*user.User, error) { + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + sudoUserInfo, err := user.Lookup(sudoUser) + if err == nil { + return sudoUserInfo, nil + } + } + return user.Current() +} + +func getDefaultCredentialsPath() (string, error) { + if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" { + return filepath.Join(configDir, ".credentials.json"), nil + } + userInfo, err := getRealUser() + if err != nil { + return "", err + } + return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil +} + +func readCredentialsFromFile(path string) (*oauthCredentials, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var credentialsContainer struct { + ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"` + } + err = json.Unmarshal(data, &credentialsContainer) + if err != nil { + return nil, err + } + if credentialsContainer.ClaudeAIAuth == nil { + return nil, E.New("claudeAiOauth field not found in credentials") + } + return credentialsContainer.ClaudeAIAuth, nil +} + +func checkCredentialFileWritable(path string) error { + file, err := os.OpenFile(path, os.O_WRONLY, 0) + if err != nil { + return err + } + return file.Close() +} + +func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error { + data, err := json.MarshalIndent(map[string]any{ + "claudeAiOauth": oauthCredentials, + }, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o600) +} + +type oauthCredentials struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresAt int64 `json:"expiresAt"` + Scopes []string `json:"scopes,omitempty"` + SubscriptionType string `json:"subscriptionType,omitempty"` + RateLimitTier string `json:"rateLimitTier,omitempty"` + IsMax bool `json:"isMax,omitempty"` +} + +func (c *oauthCredentials) needsRefresh() bool { + if c.ExpiresAt == 0 { + return false + } + return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs +} + +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { + if credentials.RefreshToken == "" { + return nil, E.New("refresh token is empty") + } + + requestBody, err := json.Marshal(map[string]string{ + "grant_type": "refresh_token", + "refresh_token": credentials.RefreshToken, + "client_id": oauth2ClientID, + }) + if err != nil { + return nil, E.Cause(err, "marshal request") + } + + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + return request, nil + }) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode == http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + } + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh failed: ", response.Status, " ", string(body)) + } + + var tokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + } + err = json.NewDecoder(response.Body).Decode(&tokenResponse) + if err != nil { + return nil, E.Cause(err, "decode response") + } + + newCredentials := *credentials + newCredentials.AccessToken = tokenResponse.AccessToken + if tokenResponse.RefreshToken != "" { + newCredentials.RefreshToken = tokenResponse.RefreshToken + } + newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000 + + return &newCredentials, nil +} + +func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { + if credentials == nil { + return nil + } + cloned := *credentials + cloned.Scopes = append([]string(nil), credentials.Scopes...) + return &cloned +} + +func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { + if left == nil || right == nil { + return left == right + } + return left.AccessToken == right.AccessToken && + left.RefreshToken == right.RefreshToken && + left.ExpiresAt == right.ExpiresAt && + slices.Equal(left.Scopes, right.Scopes) && + left.SubscriptionType == right.SubscriptionType && + left.RateLimitTier == right.RateLimitTier && + left.IsMax == right.IsMax +} diff --git a/service/ccm/credential_provider.go b/service/ccm/credential_provider.go new file mode 100644 index 0000000000..cd77bfcdc1 --- /dev/null +++ b/service/ccm/credential_provider.go @@ -0,0 +1,405 @@ +package ccm + +import ( + "context" + "math/rand/v2" + "sync" + "sync/atomic" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" +) + +type credentialProvider interface { + selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) + onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential + linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool + pollIfStale(ctx context.Context) + allCredentials() []credential + close() +} + +type singleCredentialProvider struct { + cred credential + sessionAccess sync.RWMutex + sessions map[string]time.Time +} + +func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if !selection.allows(p.cred) { + return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") + } + if !p.cred.isAvailable() { + return nil, false, p.cred.unavailableError() + } + if !p.cred.isUsable() { + return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") + } + var isNew bool + if sessionID != "" { + p.sessionAccess.Lock() + if p.sessions == nil { + p.sessions = make(map[string]time.Time) + } + _, exists := p.sessions[sessionID] + if !exists { + p.sessions[sessionID] = time.Now() + isNew = true + } + p.sessionAccess.Unlock() + } + return p.cred, isNew, nil +} + +func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { + cred.markRateLimited(resetAt) + return nil +} + +func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionAccess.Lock() + for id, createdAt := range p.sessions { + if now.Sub(createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() + + if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { + p.cred.pollUsage(ctx) + } +} + +func (p *singleCredentialProvider) allCredentials() []credential { + return []credential{p.cred} +} + +func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} + +func (p *singleCredentialProvider) close() {} + +type sessionEntry struct { + tag string + selectionScope credentialSelectionScope + createdAt time.Time +} + +type credentialInterruptKey struct { + tag string + selectionScope credentialSelectionScope +} + +type credentialInterruptEntry struct { + context context.Context + cancel context.CancelFunc +} + +type balancerProvider struct { + credentials []credential + strategy string + roundRobinIndex atomic.Uint64 + pollInterval time.Duration + rebalanceThreshold float64 + sessionAccess sync.RWMutex + sessions map[string]sessionEntry + interruptAccess sync.Mutex + credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry + logger log.ContextLogger +} + +func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &balancerProvider{ + credentials: credentials, + strategy: strategy, + pollInterval: pollInterval, + rebalanceThreshold: rebalanceThreshold, + sessions: make(map[string]sessionEntry), + credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), + logger: logger, + } +} + +func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if p.strategy == C.BalancerStrategyFallback { + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allCredentialsUnavailableError(p.credentials) + } + return best, false, nil + } + + selectionScope := selection.scopeOrDefault() + if sessionID != "" { + p.sessionAccess.RLock() + entry, exists := p.sessions[sessionID] + p.sessionAccess.RUnlock() + if exists { + if entry.selectionScope == selectionScope { + for _, cred := range p.credentials { + if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { + better := p.pickLeastUsed(selection.filter) + if better != nil && better.tagName() != cred.tagName() { + effectiveThreshold := p.rebalanceThreshold / cred.planWeight() + delta := cred.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", cred.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", cred.planWeight(), ")") + p.rebalanceCredential(cred.tagName(), selectionScope) + break + } + } + } + return cred, false, nil + } + } + } + p.sessionAccess.Lock() + delete(p.sessions, sessionID) + p.sessionAccess.Unlock() + } + } + + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allCredentialsUnavailableError(p.credentials) + } + + isNew := sessionID != "" + if isNew { + p.sessionAccess.Lock() + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selectionScope, + createdAt: time.Now(), + } + p.sessionAccess.Unlock() + } + return best, isNew, nil +} + +func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { + key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} + p.interruptAccess.Lock() + if entry, loaded := p.credentialInterrupts[key]; loaded { + entry.cancel() + } + ctx, cancel := context.WithCancel(context.Background()) + p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} + p.interruptAccess.Unlock() + + p.sessionAccess.Lock() + for id, entry := range p.sessions { + if entry.tag == tag && entry.selectionScope == selectionScope { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() +} + +func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { + if p.strategy == C.BalancerStrategyFallback { + return func() bool { return false } + } + key := credentialInterruptKey{ + tag: cred.tagName(), + selectionScope: selection.scopeOrDefault(), + } + p.interruptAccess.Lock() + entry, loaded := p.credentialInterrupts[key] + if !loaded { + ctx, cancel := context.WithCancel(context.Background()) + entry = credentialInterruptEntry{context: ctx, cancel: cancel} + p.credentialInterrupts[key] = entry + } + p.interruptAccess.Unlock() + return context.AfterFunc(entry.context, onInterrupt) +} + +func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { + cred.markRateLimited(resetAt) + if p.strategy == C.BalancerStrategyFallback { + return p.pickCredential(selection.filter) + } + if sessionID != "" { + p.sessionAccess.Lock() + delete(p.sessions, sessionID) + p.sessionAccess.Unlock() + } + + best := p.pickCredential(selection.filter) + if best != nil && sessionID != "" { + p.sessionAccess.Lock() + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selection.scopeOrDefault(), + createdAt: time.Now(), + } + p.sessionAccess.Unlock() + } + return best +} + +func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { + switch p.strategy { + case C.BalancerStrategyRoundRobin: + return p.pickRoundRobin(filter) + case C.BalancerStrategyRandom: + return p.pickRandom(filter) + case C.BalancerStrategyFallback: + return p.pickFallback(filter) + default: + return p.pickLeastUsed(filter) + } +} + +func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if cred.isUsable() { + return cred + } + } + return nil +} + +const weeklyWindowHours = 7 * 24 + +func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { + var best credential + bestScore := float64(-1) + now := time.Now() + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if !cred.isUsable() { + continue + } + remaining := cred.weeklyCap() - cred.weeklyUtilization() + score := remaining * cred.planWeight() + resetTime := cred.weeklyResetTime() + if !resetTime.IsZero() { + timeUntilReset := resetTime.Sub(now) + if timeUntilReset < time.Hour { + timeUntilReset = time.Hour + } + score *= weeklyWindowHours / timeUntilReset.Hours() + } + if score > bestScore { + bestScore = score + best = cred + } + } + return best +} + +func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { + start := int(p.roundRobinIndex.Add(1) - 1) + count := len(p.credentials) + for offset := range count { + candidate := p.credentials[(start+offset)%count] + if filter != nil && !filter(candidate) { + continue + } + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { + var usable []credential + for _, candidate := range p.credentials { + if filter != nil && !filter(candidate) { + continue + } + if candidate.isUsable() { + usable = append(usable, candidate) + } + } + if len(usable) == 0 { + return nil + } + return usable[rand.IntN(len(usable))] +} + +func (p *balancerProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionAccess.Lock() + for id, entry := range p.sessions { + if now.Sub(entry.createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() + + for _, cred := range p.credentials { + if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { + cred.pollUsage(ctx) + } + } +} + +func (p *balancerProvider) allCredentials() []credential { + return p.credentials +} + +func (p *balancerProvider) close() {} + +func ccmPlanWeight(accountType string, rateLimitTier string) float64 { + switch accountType { + case "max": + switch rateLimitTier { + case "default_claude_max_20x": + return 10 + case "default_claude_max_5x": + return 5 + default: + return 5 + } + case "team": + if rateLimitTier == "default_claude_max_5x" { + return 5 + } + return 1 + default: + return 1 + } +} + +func allCredentialsUnavailableError(credentials []credential) error { + var hasUnavailable bool + var earliest time.Time + for _, cred := range credentials { + if cred.unavailableError() != nil { + hasUnavailable = true + continue + } + resetAt := cred.earliestReset() + if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { + earliest = resetAt + } + } + if hasUnavailable { + return E.New("all credentials unavailable") + } + if earliest.IsZero() { + return E.New("all credentials rate-limited") + } + return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) +} diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go deleted file mode 100644 index b07529eb0c..0000000000 --- a/service/ccm/credential_state.go +++ /dev/null @@ -1,1506 +0,0 @@ -package ccm - -import ( - "bytes" - "context" - stdTLS "crypto/tls" - "encoding/json" - "io" - "math" - "math/rand/v2" - "net" - "net/http" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/sagernet/fswatch" - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/dialer" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/ntp" -) - -const ( - defaultPollInterval = 60 * time.Minute - failedPollRetryInterval = time.Minute - httpRetryMaxBackoff = 5 * time.Minute -) - -const ( - httpRetryMaxAttempts = 3 - httpRetryInitialDelay = 200 * time.Millisecond -) - -func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { - var lastError error - for attempt := range httpRetryMaxAttempts { - if attempt > 0 { - delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) - select { - case <-ctx.Done(): - return nil, lastError - case <-time.After(delay): - } - } - request, err := buildRequest() - if err != nil { - return nil, err - } - response, err := client.Do(request) - if err == nil { - return response, nil - } - lastError = err - if ctx.Err() != nil { - return nil, lastError - } - } - return nil, lastError -} - -type credentialState struct { - fiveHourUtilization float64 - fiveHourReset time.Time - weeklyUtilization float64 - weeklyReset time.Time - hardRateLimited bool - rateLimitResetAt time.Time - accountType string - rateLimitTier string - remotePlanWeight float64 - lastUpdated time.Time - consecutivePollFailures int - unavailable bool - lastCredentialLoadAttempt time.Time - lastCredentialLoadError string -} - -type defaultCredential struct { - tag string - serviceContext context.Context - credentialPath string - credentialFilePath string - credentials *oauthCredentials - accessMutex sync.RWMutex - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - reloadAccess sync.Mutex - watcherAccess sync.Mutex - cap5h float64 - capWeekly float64 - usageTracker *AggregatedUsage - httpClient *http.Client - logger log.ContextLogger - watcher *fswatch.Watcher - watcherRetryAt time.Time - - // Connection interruption - onBecameUnusable func() - interrupted bool - requestContext context.Context - cancelRequests context.CancelFunc - requestAccess sync.Mutex -} - -type credentialRequestContext struct { - context.Context - releaseOnce sync.Once - cancelOnce sync.Once - releaseFuncs []func() bool - cancelFunc context.CancelFunc -} - -func (c *credentialRequestContext) addInterruptLink(stop func() bool) { - c.releaseFuncs = append(c.releaseFuncs, stop) -} - -func (c *credentialRequestContext) releaseCredentialInterrupt() { - c.releaseOnce.Do(func() { - for _, f := range c.releaseFuncs { - f() - } - }) -} - -func (c *credentialRequestContext) cancelRequest() { - c.releaseCredentialInterrupt() - c.cancelOnce.Do(c.cancelFunc) -} - -type credential interface { - tagName() string - isAvailable() bool - isUsable() bool - isExternal() bool - fiveHourUtilization() float64 - weeklyUtilization() float64 - fiveHourCap() float64 - weeklyCap() float64 - planWeight() float64 - weeklyResetTime() time.Time - markRateLimited(resetAt time.Time) - earliestReset() time.Time - unavailableError() error - - getAccessToken() (string, error) - buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) - updateStateFromHeaders(header http.Header) - - wrapRequestContext(ctx context.Context) *credentialRequestContext - interruptConnections() - - start() error - pollUsage(ctx context.Context) - lastUpdatedTime() time.Time - pollBackoff(base time.Duration) time.Duration - usageTrackerOrNil() *AggregatedUsage - httpTransport() *http.Client - close() -} - -func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { - credentialDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) - if err != nil { - return nil, E.Cause(err, "create dialer for credential ", tag) - } - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSClientConfig: &stdTLS.Config{ - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - reserve5h := options.Reserve5h - if reserve5h == 0 { - reserve5h = 1 - } - reserveWeekly := options.ReserveWeekly - if reserveWeekly == 0 { - reserveWeekly = 1 - } - var cap5h float64 - if options.Limit5h > 0 { - cap5h = float64(options.Limit5h) - } else { - cap5h = float64(100 - reserve5h) - } - var capWeekly float64 - if options.LimitWeekly > 0 { - capWeekly = float64(options.LimitWeekly) - } else { - capWeekly = float64(100 - reserveWeekly) - } - requestContext, cancelRequests := context.WithCancel(context.Background()) - credential := &defaultCredential{ - tag: tag, - serviceContext: ctx, - credentialPath: options.CredentialPath, - cap5h: cap5h, - capWeekly: capWeekly, - httpClient: httpClient, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - } - if options.UsagesPath != "" { - credential.usageTracker = &AggregatedUsage{ - LastUpdated: time.Now(), - Combinations: make([]CostCombination, 0), - filePath: options.UsagesPath, - logger: logger, - } - } - return credential, nil -} - -func (c *defaultCredential) start() error { - credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) - if err != nil { - return E.Cause(err, "resolve credential path for ", c.tag) - } - c.credentialFilePath = credentialFilePath - err = c.ensureCredentialWatcher() - if err != nil { - c.logger.Debug("start credential watcher for ", c.tag, ": ", err) - } - err = c.reloadCredentials(true) - if err != nil { - c.logger.Warn("initial credential load for ", c.tag, ": ", err) - } - if c.usageTracker != nil { - err = c.usageTracker.Load() - if err != nil { - c.logger.Warn("load usage statistics for ", c.tag, ": ", err) - } - } - return nil -} - -func (c *defaultCredential) getAccessToken() (string, error) { - c.retryCredentialReloadIfNeeded() - - c.accessMutex.RLock() - if c.credentials != nil && !c.credentials.needsRefresh() { - token := c.credentials.AccessToken - c.accessMutex.RUnlock() - return token, nil - } - c.accessMutex.RUnlock() - - err := c.reloadCredentials(true) - if err == nil { - c.accessMutex.RLock() - if c.credentials != nil && !c.credentials.needsRefresh() { - token := c.credentials.AccessToken - c.accessMutex.RUnlock() - return token, nil - } - c.accessMutex.RUnlock() - } - - c.accessMutex.Lock() - defer c.accessMutex.Unlock() - - if c.credentials == nil { - return "", c.unavailableError() - } - if !c.credentials.needsRefresh() { - return c.credentials.AccessToken, nil - } - - err = platformCanWriteCredentials(c.credentialPath) - if err != nil { - return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") - } - - baseCredentials := cloneCredentials(c.credentials) - newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials) - if err != nil { - return "", err - } - - latestCredentials, latestErr := platformReadCredentials(c.credentialPath) - if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { - c.credentials = latestCredentials - c.stateMutex.Lock() - c.state.unavailable = false - c.state.lastCredentialLoadAttempt = time.Now() - c.state.lastCredentialLoadError = "" - c.state.accountType = latestCredentials.SubscriptionType - c.state.rateLimitTier = latestCredentials.RateLimitTier - c.checkTransitionLocked() - c.stateMutex.Unlock() - if !latestCredentials.needsRefresh() { - return latestCredentials.AccessToken, nil - } - return "", E.New("credential ", c.tag, " changed while refreshing") - } - - c.credentials = newCredentials - c.stateMutex.Lock() - c.state.unavailable = false - c.state.lastCredentialLoadAttempt = time.Now() - c.state.lastCredentialLoadError = "" - c.state.accountType = newCredentials.SubscriptionType - c.state.rateLimitTier = newCredentials.RateLimitTier - c.checkTransitionLocked() - c.stateMutex.Unlock() - - err = platformWriteCredentials(newCredentials, c.credentialPath) - if err != nil { - c.logger.Error("persist refreshed token for ", c.tag, ": ", err) - } - - return newCredentials.AccessToken, nil -} - -// Claude Code's unified rate-limit handling parses these reset headers with -// Number(...), compares them against Date.now()/1000, and renders them via -// new Date(seconds*1000), so keep the wire format pinned to Unix epoch seconds. -func parseAnthropicResetHeaderValue(headerName string, headerValue string) time.Time { - unixEpoch, err := strconv.ParseInt(headerValue, 10, 64) - if err != nil { - panic("invalid " + headerName + " header: expected Unix epoch seconds, got " + strconv.Quote(headerValue)) - } - if unixEpoch <= 0 { - panic("invalid " + headerName + " header: expected positive Unix epoch seconds, got " + strconv.Quote(headerValue)) - } - return time.Unix(unixEpoch, 0) -} - -func parseOptionalAnthropicResetHeader(headers http.Header, headerName string) (time.Time, bool) { - headerValue := headers.Get(headerName) - if headerValue == "" { - return time.Time{}, false - } - return parseAnthropicResetHeaderValue(headerName, headerValue), true -} - -func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) time.Time { - headerValue := headers.Get(headerName) - if headerValue == "" { - panic("missing required " + headerName + " header") - } - return parseAnthropicResetHeaderValue(headerName, headerValue) -} - -func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { - c.stateMutex.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() - oldFiveHour := c.state.fiveHourUtilization - oldWeekly := c.state.weeklyUtilization - hadData := false - - fiveHourResetChanged := false - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { - hadData = true - if value.After(c.state.fiveHourReset) { - fiveHourResetChanged = true - c.state.fiveHourReset = value - } - } - if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { - value, err := strconv.ParseFloat(utilization, 64) - if err == nil { - hadData = true - newValue := math.Ceil(value * 100) - if newValue >= c.state.fiveHourUtilization || fiveHourResetChanged { - c.state.fiveHourUtilization = newValue - } - } - } - - weeklyResetChanged := false - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { - hadData = true - if value.After(c.state.weeklyReset) { - weeklyResetChanged = true - c.state.weeklyReset = value - } - } - if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { - value, err := strconv.ParseFloat(utilization, 64) - if err == nil { - hadData = true - newValue := math.Ceil(value * 100) - if newValue >= c.state.weeklyUtilization || weeklyResetChanged { - c.state.weeklyUtilization = newValue - } - } - } - if hadData { - c.state.consecutivePollFailures = 0 - c.state.lastUpdated = time.Now() - } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - resetSuffix := "" - if !c.state.weeklyReset.IsZero() { - resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) - } - c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) - } - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) markRateLimited(resetAt time.Time) { - c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) - c.stateMutex.Lock() - c.state.hardRateLimited = true - c.state.rateLimitResetAt = resetAt - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) isUsable() bool { - c.retryCredentialReloadIfNeeded() - - c.stateMutex.RLock() - if c.state.unavailable { - c.stateMutex.RUnlock() - return false - } - if c.state.consecutivePollFailures > 0 { - c.stateMutex.RUnlock() - return false - } - if c.state.hardRateLimited { - if time.Now().Before(c.state.rateLimitResetAt) { - c.stateMutex.RUnlock() - return false - } - c.stateMutex.RUnlock() - c.stateMutex.Lock() - if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { - c.state.hardRateLimited = false - } - usable := c.checkReservesLocked() - c.stateMutex.Unlock() - return usable - } - usable := c.checkReservesLocked() - c.stateMutex.RUnlock() - return usable -} - -func (c *defaultCredential) checkReservesLocked() bool { - if c.state.fiveHourUtilization >= c.cap5h { - return false - } - if c.state.weeklyUtilization >= c.capWeekly { - return false - } - return true -} - -// checkTransitionLocked detects usable→unusable transition. -// Must be called with stateMutex write lock held. -func (c *defaultCredential) checkTransitionLocked() bool { - unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0 - if unusable && !c.interrupted { - c.interrupted = true - return true - } - if !unusable && c.interrupted { - c.interrupted = false - } - return false -} - -func (c *defaultCredential) interruptConnections() { - c.logger.Warn("interrupting connections for ", c.tag) - c.requestAccess.Lock() - c.cancelRequests() - c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) - c.requestAccess.Unlock() - if c.onBecameUnusable != nil { - c.onBecameUnusable() - } -} - -func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { - c.requestAccess.Lock() - credentialContext := c.requestContext - c.requestAccess.Unlock() - derived, cancel := context.WithCancel(parent) - stop := context.AfterFunc(credentialContext, func() { - cancel() - }) - return &credentialRequestContext{ - Context: derived, - releaseFuncs: []func() bool{stop}, - cancelFunc: cancel, - } -} - -func (c *defaultCredential) weeklyUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.weeklyUtilization -} - -func (c *defaultCredential) planWeight() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) -} - -func (c *defaultCredential) weeklyResetTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.weeklyReset -} - -func (c *defaultCredential) isAvailable() bool { - c.retryCredentialReloadIfNeeded() - - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return !c.state.unavailable -} - -func (c *defaultCredential) unavailableError() error { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - if !c.state.unavailable { - return nil - } - if c.state.lastCredentialLoadError == "" { - return E.New("credential ", c.tag, " is unavailable") - } - return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError) -} - -func (c *defaultCredential) lastUpdatedTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.lastUpdated -} - -func (c *defaultCredential) markUsagePollAttempted() { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - c.state.lastUpdated = time.Now() -} - -func (c *defaultCredential) incrementPollFailures() { - c.stateMutex.Lock() - c.state.consecutivePollFailures++ - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { - c.stateMutex.RLock() - failures := c.state.consecutivePollFailures - c.stateMutex.RUnlock() - if failures <= 0 { - return baseInterval - } - backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) - if backoff > httpRetryMaxBackoff { - return httpRetryMaxBackoff - } - return backoff -} - -func (c *defaultCredential) isPollBackoffAtCap() bool { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - failures := c.state.consecutivePollFailures - return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff -} - -func (c *defaultCredential) earliestReset() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - if c.state.unavailable { - return time.Time{} - } - if c.state.hardRateLimited { - return c.state.rateLimitResetAt - } - earliest := c.state.fiveHourReset - if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) { - earliest = c.state.weeklyReset - } - return earliest -} - -func (c *defaultCredential) pollUsage(ctx context.Context) { - if !c.pollAccess.TryLock() { - return - } - defer c.pollAccess.Unlock() - defer c.markUsagePollAttempted() - - c.retryCredentialReloadIfNeeded() - if !c.isAvailable() { - return - } - - accessToken, err := c.getAccessToken() - if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": get token: ", err) - } - c.incrementPollFailures() - return - } - - httpClient := &http.Client{ - Transport: c.httpClient.Transport, - Timeout: 5 * time.Second, - } - - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil) - if err != nil { - return nil, err - } - request.Header.Set("Authorization", "Bearer "+accessToken) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - request.Header.Set("anthropic-beta", anthropicBetaOAuthValue) - return request, nil - }) - if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": ", err) - } - c.incrementPollFailures() - return - } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - if response.StatusCode == http.StatusTooManyRequests { - c.logger.Warn("poll usage for ", c.tag, ": rate limited") - } - body, _ := io.ReadAll(response.Body) - c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - c.incrementPollFailures() - return - } - - var usageResponse struct { - FiveHour struct { - Utilization float64 `json:"utilization"` - ResetsAt time.Time `json:"resets_at"` - } `json:"five_hour"` - SevenDay struct { - Utilization float64 `json:"utilization"` - ResetsAt time.Time `json:"resets_at"` - } `json:"seven_day"` - } - err = json.NewDecoder(response.Body).Decode(&usageResponse) - if err != nil { - c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.incrementPollFailures() - return - } - - c.stateMutex.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() - oldFiveHour := c.state.fiveHourUtilization - oldWeekly := c.state.weeklyUtilization - c.state.consecutivePollFailures = 0 - c.state.fiveHourUtilization = usageResponse.FiveHour.Utilization - if !usageResponse.FiveHour.ResetsAt.IsZero() { - c.state.fiveHourReset = usageResponse.FiveHour.ResetsAt - } - c.state.weeklyUtilization = usageResponse.SevenDay.Utilization - if !usageResponse.SevenDay.ResetsAt.IsZero() { - c.state.weeklyReset = usageResponse.SevenDay.ResetsAt - } - if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { - c.state.hardRateLimited = false - } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - resetSuffix := "" - if !c.state.weeklyReset.IsZero() { - resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) - } - c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) - } - needsProfileFetch := c.state.rateLimitTier == "" - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } - - if needsProfileFetch { - c.fetchProfile(ctx, httpClient, accessToken) - } -} - -func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.Client, accessToken string) { - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil) - if err != nil { - return nil, err - } - request.Header.Set("Authorization", "Bearer "+accessToken) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - return request, nil - }) - if err != nil { - c.logger.Debug("fetch profile for ", c.tag, ": ", err) - return - } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - return - } - - var profileResponse struct { - Organization *struct { - OrganizationType string `json:"organization_type"` - RateLimitTier string `json:"rate_limit_tier"` - } `json:"organization"` - } - err = json.NewDecoder(response.Body).Decode(&profileResponse) - if err != nil || profileResponse.Organization == nil { - return - } - - accountType := "" - switch profileResponse.Organization.OrganizationType { - case "claude_pro": - accountType = "pro" - case "claude_max": - accountType = "max" - case "claude_team": - accountType = "team" - case "claude_enterprise": - accountType = "enterprise" - } - rateLimitTier := profileResponse.Organization.RateLimitTier - - c.stateMutex.Lock() - if accountType != "" && c.state.accountType == "" { - c.state.accountType = accountType - } - if rateLimitTier != "" { - c.state.rateLimitTier = rateLimitTier - } - c.stateMutex.Unlock() - c.logger.Info("fetched profile for ", c.tag, ": type=", c.state.accountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(c.state.accountType, rateLimitTier)) -} - -func (c *defaultCredential) close() { - if c.watcher != nil { - err := c.watcher.Close() - if err != nil { - c.logger.Error("close credential watcher for ", c.tag, ": ", err) - } - } - if c.usageTracker != nil { - c.usageTracker.cancelPendingSave() - err := c.usageTracker.Save() - if err != nil { - c.logger.Error("save usage statistics for ", c.tag, ": ", err) - } - } -} - -func (c *defaultCredential) tagName() string { - return c.tag -} - -func (c *defaultCredential) isExternal() bool { - return false -} - -func (c *defaultCredential) fiveHourUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.fiveHourUtilization -} - -func (c *defaultCredential) fiveHourCap() float64 { - return c.cap5h -} - -func (c *defaultCredential) weeklyCap() float64 { - return c.capWeekly -} - -func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage { - return c.usageTracker -} - -func (c *defaultCredential) httpTransport() *http.Client { - return c.httpClient -} - -func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) { - accessToken, err := c.getAccessToken() - if err != nil { - return nil, E.Cause(err, "get access token for ", c.tag) - } - - proxyURL := claudeAPIBaseURL + original.URL.RequestURI() - var body io.Reader - if bodyBytes != nil { - body = bytes.NewReader(bodyBytes) - } else { - body = original.Body - } - proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body) - if err != nil { - return nil, err - } - - for key, values := range original.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { - proxyRequest.Header[key] = values - } - } - - serviceOverridesAcceptEncoding := len(serviceHeaders.Values("Accept-Encoding")) > 0 - if c.usageTracker != nil && !serviceOverridesAcceptEncoding { - proxyRequest.Header.Del("Accept-Encoding") - } - - anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta") - if anthropicBetaHeader != "" { - proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) - } else { - proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue) - } - - for key, values := range serviceHeaders { - proxyRequest.Header.Del(key) - proxyRequest.Header[key] = values - } - proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) - - return proxyRequest, nil -} - -// credentialProvider is the interface for all credential types. -type credentialProvider interface { - selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) - onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential - linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool - pollIfStale(ctx context.Context) - allCredentials() []credential - close() -} - -type credentialSelectionScope string - -const ( - credentialSelectionScopeAll credentialSelectionScope = "all" - credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" -) - -type credentialSelection struct { - scope credentialSelectionScope - filter func(credential) bool -} - -func (s credentialSelection) allows(cred credential) bool { - return s.filter == nil || s.filter(cred) -} - -func (s credentialSelection) scopeOrDefault() credentialSelectionScope { - if s.scope == "" { - return credentialSelectionScopeAll - } - return s.scope -} - -// singleCredentialProvider wraps a single credential (legacy or single default). -type singleCredentialProvider struct { - cred credential - sessionAccess sync.RWMutex - sessions map[string]time.Time -} - -func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { - if !selection.allows(p.cred) { - return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") - } - if !p.cred.isAvailable() { - return nil, false, p.cred.unavailableError() - } - if !p.cred.isUsable() { - return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") - } - var isNew bool - if sessionID != "" { - p.sessionAccess.Lock() - if p.sessions == nil { - p.sessions = make(map[string]time.Time) - } - _, exists := p.sessions[sessionID] - if !exists { - p.sessions[sessionID] = time.Now() - isNew = true - } - p.sessionAccess.Unlock() - } - return p.cred, isNew, nil -} - -func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { - cred.markRateLimited(resetAt) - return nil -} - -func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { - now := time.Now() - p.sessionAccess.Lock() - for id, createdAt := range p.sessions { - if now.Sub(createdAt) > sessionExpiry { - delete(p.sessions, id) - } - } - p.sessionAccess.Unlock() - - if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { - p.cred.pollUsage(ctx) - } -} - -func (p *singleCredentialProvider) allCredentials() []credential { - return []credential{p.cred} -} - -func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { - return func() bool { - return false - } -} - -func (p *singleCredentialProvider) close() {} - -const sessionExpiry = 24 * time.Hour - -type sessionEntry struct { - tag string - selectionScope credentialSelectionScope - createdAt time.Time -} - -type credentialInterruptKey struct { - tag string - selectionScope credentialSelectionScope -} - -type credentialInterruptEntry struct { - context context.Context - cancel context.CancelFunc -} - -// balancerProvider assigns sessions to credentials based on a configurable strategy. -type balancerProvider struct { - credentials []credential - strategy string - roundRobinIndex atomic.Uint64 - pollInterval time.Duration - rebalanceThreshold float64 - sessionMutex sync.RWMutex - sessions map[string]sessionEntry - interruptAccess sync.Mutex - credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry - logger log.ContextLogger -} - -func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { - if pollInterval <= 0 { - pollInterval = defaultPollInterval - } - return &balancerProvider{ - credentials: credentials, - strategy: strategy, - pollInterval: pollInterval, - rebalanceThreshold: rebalanceThreshold, - sessions: make(map[string]sessionEntry), - credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), - logger: logger, - } -} - -func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { - if p.strategy == C.BalancerStrategyFallback { - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allCredentialsUnavailableError(p.credentials) - } - return best, false, nil - } - - selectionScope := selection.scopeOrDefault() - if sessionID != "" { - p.sessionMutex.RLock() - entry, exists := p.sessions[sessionID] - p.sessionMutex.RUnlock() - if exists { - if entry.selectionScope == selectionScope { - for _, cred := range p.credentials { - if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { - better := p.pickLeastUsed(selection.filter) - if better != nil && better.tagName() != cred.tagName() { - effectiveThreshold := p.rebalanceThreshold / cred.planWeight() - delta := cred.weeklyUtilization() - better.weeklyUtilization() - if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", cred.tagName(), - ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", cred.planWeight(), ")") - p.rebalanceCredential(cred.tagName(), selectionScope) - break - } - } - } - return cred, false, nil - } - } - } - p.sessionMutex.Lock() - delete(p.sessions, sessionID) - p.sessionMutex.Unlock() - } - } - - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allCredentialsUnavailableError(p.credentials) - } - - isNew := sessionID != "" - if isNew { - p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{ - tag: best.tagName(), - selectionScope: selectionScope, - createdAt: time.Now(), - } - p.sessionMutex.Unlock() - } - return best, isNew, nil -} - -func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { - key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} - p.interruptAccess.Lock() - if entry, loaded := p.credentialInterrupts[key]; loaded { - entry.cancel() - } - ctx, cancel := context.WithCancel(context.Background()) - p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} - p.interruptAccess.Unlock() - - p.sessionMutex.Lock() - for id, entry := range p.sessions { - if entry.tag == tag && entry.selectionScope == selectionScope { - delete(p.sessions, id) - } - } - p.sessionMutex.Unlock() -} - -func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { - if p.strategy == C.BalancerStrategyFallback { - return func() bool { return false } - } - key := credentialInterruptKey{ - tag: cred.tagName(), - selectionScope: selection.scopeOrDefault(), - } - p.interruptAccess.Lock() - entry, loaded := p.credentialInterrupts[key] - if !loaded { - ctx, cancel := context.WithCancel(context.Background()) - entry = credentialInterruptEntry{context: ctx, cancel: cancel} - p.credentialInterrupts[key] = entry - } - p.interruptAccess.Unlock() - return context.AfterFunc(entry.context, onInterrupt) -} - -func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { - cred.markRateLimited(resetAt) - if p.strategy == C.BalancerStrategyFallback { - return p.pickCredential(selection.filter) - } - if sessionID != "" { - p.sessionMutex.Lock() - delete(p.sessions, sessionID) - p.sessionMutex.Unlock() - } - - best := p.pickCredential(selection.filter) - if best != nil && sessionID != "" { - p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{ - tag: best.tagName(), - selectionScope: selection.scopeOrDefault(), - createdAt: time.Now(), - } - p.sessionMutex.Unlock() - } - return best -} - -func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { - switch p.strategy { - case C.BalancerStrategyRoundRobin: - return p.pickRoundRobin(filter) - case C.BalancerStrategyRandom: - return p.pickRandom(filter) - case C.BalancerStrategyFallback: - return p.pickFallback(filter) - default: - return p.pickLeastUsed(filter) - } -} - -func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { - continue - } - if cred.isUsable() { - return cred - } - } - return nil -} - -func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { - var best credential - bestScore := float64(-1) - now := time.Now() - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { - continue - } - if !cred.isUsable() { - continue - } - remaining := cred.weeklyCap() - cred.weeklyUtilization() - score := remaining * cred.planWeight() - resetTime := cred.weeklyResetTime() - if !resetTime.IsZero() { - timeUntilReset := resetTime.Sub(now) - if timeUntilReset < time.Hour { - timeUntilReset = time.Hour - } - score *= weeklyWindowDuration / timeUntilReset.Hours() - } - if score > bestScore { - bestScore = score - best = cred - } - } - return best -} - -const weeklyWindowDuration = 7 * 24 // hours - -func ccmPlanWeight(accountType string, rateLimitTier string) float64 { - switch accountType { - case "max": - switch rateLimitTier { - case "default_claude_max_20x": - return 10 - case "default_claude_max_5x": - return 5 - default: - return 5 - } - case "team": - if rateLimitTier == "default_claude_max_5x" { - return 5 - } - return 1 - default: - return 1 - } -} - -func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { - start := int(p.roundRobinIndex.Add(1) - 1) - count := len(p.credentials) - for offset := range count { - candidate := p.credentials[(start+offset)%count] - if filter != nil && !filter(candidate) { - continue - } - if candidate.isUsable() { - return candidate - } - } - return nil -} - -func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { - var usable []credential - for _, candidate := range p.credentials { - if filter != nil && !filter(candidate) { - continue - } - if candidate.isUsable() { - usable = append(usable, candidate) - } - } - if len(usable) == 0 { - return nil - } - return usable[rand.IntN(len(usable))] -} - -func (p *balancerProvider) pollIfStale(ctx context.Context) { - now := time.Now() - p.sessionMutex.Lock() - for id, entry := range p.sessions { - if now.Sub(entry.createdAt) > sessionExpiry { - delete(p.sessions, id) - } - } - p.sessionMutex.Unlock() - - for _, cred := range p.credentials { - if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { - cred.pollUsage(ctx) - } - } -} - -func (p *balancerProvider) allCredentials() []credential { - return p.credentials -} - -func (p *balancerProvider) close() {} - -func allCredentialsUnavailableError(credentials []credential) error { - var hasUnavailable bool - var earliest time.Time - for _, cred := range credentials { - if cred.unavailableError() != nil { - hasUnavailable = true - continue - } - resetAt := cred.earliestReset() - if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { - earliest = resetAt - } - } - if hasUnavailable { - return E.New("all credentials unavailable") - } - if earliest.IsZero() { - return E.New("all credentials rate-limited") - } - return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) -} - -func extractCCMSessionID(bodyBytes []byte) string { - var body struct { - Metadata struct { - UserID string `json:"user_id"` - } `json:"metadata"` - } - err := json.Unmarshal(bodyBytes, &body) - if err != nil { - return "" - } - userID := body.Metadata.UserID - sessionIndex := strings.LastIndex(userID, "_session_") - if sessionIndex < 0 { - return "" - } - return userID[sessionIndex+len("_session_"):] -} - -func buildCredentialProviders( - ctx context.Context, - options option.CCMServiceOptions, - logger log.ContextLogger, -) (map[string]credentialProvider, []credential, error) { - allCredentialMap := make(map[string]credential) - var allCreds []credential - providers := make(map[string]credentialProvider) - - // Pass 1: create default and external credentials - for _, credOpt := range options.Credentials { - switch credOpt.Type { - case "default": - cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) - if err != nil { - return nil, nil, err - } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} - case "external": - cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) - if err != nil { - return nil, nil, err - } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} - } - } - - // Pass 2: create balancer providers - for _, credOpt := range options.Credentials { - if credOpt.Type == "balancer" { - subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) - if err != nil { - return nil, nil, err - } - providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) - } - } - - return providers, allCreds, nil -} - -func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { - credentials := make([]credential, 0, len(tags)) - for _, tag := range tags { - cred, exists := allCredentials[tag] - if !exists { - return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) - } - credentials = append(credentials, cred) - } - if len(credentials) == 0 { - return nil, E.New("credential ", parentTag, " has no sub-credentials") - } - return credentials, nil -} - -func parseRateLimitResetFromHeaders(headers http.Header) time.Time { - claim := headers.Get("anthropic-ratelimit-unified-representative-claim") - switch claim { - case "5h": - return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset") - case "7d": - return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset") - default: - panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim)) - } -} - -func validateCCMOptions(options option.CCMServiceOptions) error { - hasCredentials := len(options.Credentials) > 0 - hasLegacyPath := options.CredentialPath != "" - hasLegacyUsages := options.UsagesPath != "" - hasLegacyDetour := options.Detour != "" - - if hasCredentials && hasLegacyPath { - return E.New("credential_path and credentials are mutually exclusive") - } - if hasCredentials && hasLegacyUsages { - return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") - } - if hasCredentials && hasLegacyDetour { - return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") - } - - if hasCredentials { - tags := make(map[string]bool) - credentialTypes := make(map[string]string) - for _, cred := range options.Credentials { - if tags[cred.Tag] { - return E.New("duplicate credential tag: ", cred.Tag) - } - tags[cred.Tag] = true - credentialTypes[cred.Tag] = cred.Type - if cred.Type == "default" || cred.Type == "" { - if cred.DefaultOptions.Reserve5h > 99 { - return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") - } - if cred.DefaultOptions.ReserveWeekly > 99 { - return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") - } - if cred.DefaultOptions.Limit5h > 100 { - return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") - } - if cred.DefaultOptions.LimitWeekly > 100 { - return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") - } - if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { - return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") - } - if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { - return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") - } - } - if cred.Type == "external" { - if cred.ExternalOptions.Token == "" { - return E.New("credential ", cred.Tag, ": external credential requires token") - } - if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { - return E.New("credential ", cred.Tag, ": reverse external credential requires url") - } - } - if cred.Type == "balancer" { - switch cred.BalancerOptions.Strategy { - case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: - default: - return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) - } - if cred.BalancerOptions.RebalanceThreshold < 0 { - return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") - } - } - } - - for _, user := range options.Users { - if user.Credential == "" { - return E.New("user ", user.Name, " must specify credential in multi-credential mode") - } - if !tags[user.Credential] { - return E.New("user ", user.Name, " references unknown credential: ", user.Credential) - } - if user.ExternalCredential != "" { - if !tags[user.ExternalCredential] { - return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) - } - if credentialTypes[user.ExternalCredential] != "external" { - return E.New("user ", user.Name, ": external_credential must reference an external type credential") - } - } - } - } - - return nil -} - -// credentialForUser finds the credential provider for a user. -// In legacy mode, returns the single provider. -// In multi-credential mode, returns the provider mapped to the user's credential tag. -func credentialForUser( - userConfigMap map[string]*option.CCMUser, - providers map[string]credentialProvider, - legacyProvider credentialProvider, - username string, -) (credentialProvider, error) { - if legacyProvider != nil { - return legacyProvider, nil - } - userConfig, exists := userConfigMap[username] - if !exists { - return nil, E.New("no credential mapping for user: ", username) - } - provider, exists := providers[userConfig.Credential] - if !exists { - return nil, E.New("unknown credential: ", userConfig.Credential) - } - return provider, nil -} - -// noUserCredentialProvider returns the single provider for legacy mode or the first credential in multi-credential mode (no auth). -func noUserCredentialProvider( - providers map[string]credentialProvider, - legacyProvider credentialProvider, - options option.CCMServiceOptions, -) credentialProvider { - if legacyProvider != nil { - return legacyProvider - } - if len(options.Credentials) > 0 { - tag := options.Credentials[0].Tag - return providers[tag] - } - return nil -} diff --git a/service/ccm/service.go b/service/ccm/service.go index 6a2aa2b740..6dce1931bd 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -1,17 +1,12 @@ package ccm import ( - "bytes" "context" "encoding/json" "errors" - "io" - "mime" "net/http" - "strconv" "strings" "sync" - "time" "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" @@ -21,23 +16,16 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" aTLS "github.com/sagernet/sing/common/tls" - "github.com/anthropics/anthropic-sdk-go" "github.com/go-chi/chi/v5" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) -const ( - contextWindowStandard = 200000 - contextWindowPremium = 1000000 - premiumContextThreshold = 200000 - retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential" -) +const retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential" func RegisterService(registry *boxService.Registry) { boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService) @@ -152,23 +140,6 @@ func isReverseProxyHeader(header string) bool { } } -const ( - weeklyWindowSeconds = 604800 - weeklyWindowMinutes = weeklyWindowSeconds / 60 -) - -func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { - resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset") - if !exists { - return nil - } - - return &WeeklyCycleHint{ - WindowMinutes: weeklyWindowMinutes, - ResetAt: resetAt.UTC(), - } -} - type Service struct { boxService.Adapter ctx context.Context @@ -308,545 +279,6 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } -func isExtendedContextRequest(betaHeader string) bool { - for _, feature := range strings.Split(betaHeader, ",") { - if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { - return true - } - } - return false -} - -func isFastModeRequest(betaHeader string) bool { - for _, feature := range strings.Split(betaHeader, ",") { - if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") { - return true - } - } - return false -} - -func detectContextWindow(betaHeader string, totalInputTokens int64) int { - if totalInputTokens > premiumContextThreshold { - if isExtendedContextRequest(betaHeader) { - return contextWindowPremium - } - } - return contextWindowStandard -} - -func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := log.ContextWithNewID(r.Context()) - if r.URL.Path == "/ccm/v1/status" { - s.handleStatusEndpoint(w, r) - return - } - - if r.URL.Path == "/ccm/v1/reverse" { - s.handleReverseConnect(ctx, w, r) - return - } - - if !strings.HasPrefix(r.URL.Path, "/v1/") { - writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found") - return - } - - var username string - if len(s.options.Users) > 0 { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") - return - } - clientToken := strings.TrimPrefix(authHeader, "Bearer ") - if clientToken == authHeader { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") - return - } - var ok bool - username, ok = s.userManager.Authenticate(clientToken) - if !ok { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") - return - } - } - - // Always read body to extract model and session ID - var bodyBytes []byte - var requestModel string - var messagesCount int - var sessionID string - - if r.Body != nil { - var err error - bodyBytes, err = io.ReadAll(r.Body) - if err != nil { - s.logger.ErrorContext(ctx, "read request body: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") - return - } - - var request struct { - Model string `json:"model"` - Messages []anthropic.MessageParam `json:"messages"` - } - err = json.Unmarshal(bodyBytes, &request) - if err == nil { - requestModel = request.Model - messagesCount = len(request.Messages) - } - - sessionID = extractCCMSessionID(bodyBytes) - r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - // Resolve credential provider and user config - var provider credentialProvider - var userConfig *option.CCMUser - if len(s.options.Users) > 0 { - userConfig = s.userConfigMap[username] - var err error - provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - if err != nil { - s.logger.ErrorContext(ctx, "resolve credential: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) - return - } - } else { - provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) - } - if provider == nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") - return - } - - provider.pollIfStale(s.ctx) - - anthropicBetaHeader := r.Header.Get("anthropic-beta") - if isFastModeRequest(anthropicBetaHeader) { - if _, isSingle := provider.(*singleCredentialProvider); !isSingle { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "fast mode requests will consume Extra usage, please use a default credential directly") - return - } - } - - selection := credentialSelectionForUser(userConfig) - - selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) - if err != nil { - writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) - return - } - if isNew { - logParts := []any{"assigned credential ", selectedCredential.tagName()} - if sessionID != "" { - logParts = append(logParts, " for session ", sessionID) - } - if username != "" { - logParts = append(logParts, " by user ", username) - } - if requestModel != "" { - modelDisplay := requestModel - if isExtendedContextRequest(anthropicBetaHeader) { - modelDisplay += "[1m]" - } - logParts = append(logParts, ", model=", modelDisplay) - } - s.logger.DebugContext(ctx, logParts...) - } - - if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "fast mode requests cannot be proxied through external credentials") - return - } - - requestContext := selectedCredential.wrapRequestContext(ctx) - { - currentRequestContext := requestContext - requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { - currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) - })) - } - defer func() { - requestContext.cancelRequest() - }() - proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) - if err != nil { - s.logger.ErrorContext(ctx, "create proxy request: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") - return - } - - response, err := selectedCredential.httpTransport().Do(proxyRequest) - if err != nil { - if r.Context().Err() != nil { - return - } - if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") - return - } - writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) - return - } - requestContext.releaseCredentialInterrupt() - - // Transparent 429 retry - for response.StatusCode == http.StatusTooManyRequests { - resetAt := parseRateLimitResetFromHeaders(response.Header) - nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) - selectedCredential.updateStateFromHeaders(response.Header) - if bodyBytes == nil || nextCredential == nil { - response.Body.Close() - writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") - return - } - response.Body.Close() - s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) - requestContext.cancelRequest() - requestContext = nextCredential.wrapRequestContext(ctx) - { - currentRequestContext := requestContext - requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { - currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) - })) - } - retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) - if buildErr != nil { - s.logger.ErrorContext(ctx, "retry request: ", buildErr) - writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) - return - } - retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest) - if retryErr != nil { - if r.Context().Err() != nil { - return - } - if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") - return - } - s.logger.ErrorContext(ctx, "retry request: ", retryErr) - writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) - return - } - requestContext.releaseCredentialInterrupt() - response = retryResponse - selectedCredential = nextCredential - } - defer response.Body.Close() - - selectedCredential.updateStateFromHeaders(response.Header) - - if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { - body, _ := io.ReadAll(response.Body) - s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) - go selectedCredential.pollUsage(s.ctx) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", - "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) - return - } - - // Rewrite response headers for external users - if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) - } - - for key, values := range response.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { - w.Header()[key] = values - } - } - w.WriteHeader(response.StatusCode) - - usageTracker := selectedCredential.usageTrackerOrNil() - if usageTracker != nil && response.StatusCode == http.StatusOK { - s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) - } else { - mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) - if err == nil && mediaType != "text/event-stream" { - _, _ = io.Copy(w, response.Body) - return - } - flusher, ok := w.(http.Flusher) - if !ok { - s.logger.ErrorContext(ctx, "streaming not supported") - return - } - buffer := make([]byte, buf.BufferSize) - for { - n, err := response.Body.Read(buffer) - if n > 0 { - _, writeError := w.Write(buffer[:n]) - if writeError != nil { - s.logger.ErrorContext(ctx, "write streaming response: ", writeError) - return - } - flusher.Flush() - } - if err != nil { - return - } - } - } -} - -func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { - weeklyCycleHint := extractWeeklyCycleHint(response.Header) - mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) - isStreaming := err == nil && mediaType == "text/event-stream" - - if !isStreaming { - bodyBytes, err := io.ReadAll(response.Body) - if err != nil { - s.logger.ErrorContext(ctx, "read response body: ", err) - return - } - - var message anthropic.Message - var usage anthropic.Usage - var responseModel string - err = json.Unmarshal(bodyBytes, &message) - if err == nil { - responseModel = string(message.Model) - usage = message.Usage - } - if responseModel == "" { - responseModel = requestModel - } - - if usage.InputTokens > 0 || usage.OutputTokens > 0 { - if responseModel != "" { - totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens - contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) - usageTracker.AddUsageWithCycleHint( - responseModel, - contextWindow, - messagesCount, - usage.InputTokens, - usage.OutputTokens, - usage.CacheReadInputTokens, - usage.CacheCreationInputTokens, - usage.CacheCreation.Ephemeral5mInputTokens, - usage.CacheCreation.Ephemeral1hInputTokens, - username, - time.Now(), - weeklyCycleHint, - ) - } - } - - _, _ = writer.Write(bodyBytes) - return - } - - flusher, ok := writer.(http.Flusher) - if !ok { - s.logger.ErrorContext(ctx, "streaming not supported") - return - } - - var accumulatedUsage anthropic.Usage - var responseModel string - buffer := make([]byte, buf.BufferSize) - var leftover []byte - - for { - n, err := response.Body.Read(buffer) - if n > 0 { - data := append(leftover, buffer[:n]...) - lines := bytes.Split(data, []byte("\n")) - - if err == nil { - leftover = lines[len(lines)-1] - lines = lines[:len(lines)-1] - } else { - leftover = nil - } - - for _, line := range lines { - line = bytes.TrimSpace(line) - if len(line) == 0 { - continue - } - - if bytes.HasPrefix(line, []byte("data: ")) { - eventData := bytes.TrimPrefix(line, []byte("data: ")) - if bytes.Equal(eventData, []byte("[DONE]")) { - continue - } - - var event anthropic.MessageStreamEventUnion - err := json.Unmarshal(eventData, &event) - if err != nil { - continue - } - switch event.Type { - case "message_start": - messageStart := event.AsMessageStart() - if messageStart.Message.Model != "" { - responseModel = string(messageStart.Message.Model) - } - if messageStart.Message.Usage.InputTokens > 0 { - accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens - accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens - accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens - accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens - accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens - } - case "message_delta": - messageDelta := event.AsMessageDelta() - if messageDelta.Usage.OutputTokens > 0 { - accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens - } - } - } - } - - _, writeError := writer.Write(buffer[:n]) - if writeError != nil { - s.logger.ErrorContext(ctx, "write streaming response: ", writeError) - return - } - flusher.Flush() - } - - if err != nil { - if responseModel == "" { - responseModel = requestModel - } - - if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 { - if responseModel != "" { - totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens - contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) - usageTracker.AddUsageWithCycleHint( - responseModel, - contextWindow, - messagesCount, - accumulatedUsage.InputTokens, - accumulatedUsage.OutputTokens, - accumulatedUsage.CacheReadInputTokens, - accumulatedUsage.CacheCreationInputTokens, - accumulatedUsage.CacheCreation.Ephemeral5mInputTokens, - accumulatedUsage.CacheCreation.Ephemeral1hInputTokens, - username, - time.Now(), - weeklyCycleHint, - ) - } - } - return - } - } -} - -func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") - return - } - - if len(s.options.Users) == 0 { - writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") - return - } - - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") - return - } - clientToken := strings.TrimPrefix(authHeader, "Bearer ") - if clientToken == authHeader { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") - return - } - username, ok := s.userManager.Authenticate(clientToken) - if !ok { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") - return - } - - userConfig := s.userConfigMap[username] - if userConfig == nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") - return - } - - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - if err != nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) - return - } - - provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]float64{ - "five_hour_utilization": avgFiveHour, - "weekly_utilization": avgWeekly, - "plan_weight": totalWeight, - }) -} - -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) { - var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 - for _, cred := range provider.allCredentials() { - if !cred.isAvailable() { - continue - } - if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { - continue - } - if !userConfig.AllowExternalUsage && cred.isExternal() { - continue - } - weight := cred.planWeight() - remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() - if remaining5h < 0 { - remaining5h = 0 - } - remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() - if remainingWeekly < 0 { - remainingWeekly = 0 - } - totalWeightedRemaining5h += remaining5h * weight - totalWeightedRemainingWeekly += remainingWeekly * weight - totalWeight += weight - } - if totalWeight == 0 { - return 100, 100, 0 - } - return 100 - totalWeightedRemaining5h/totalWeight, - 100 - totalWeightedRemainingWeekly/totalWeight, - totalWeight -} - -func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) { - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) - if err != nil { - return - } - - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - - // Rewrite utilization headers to aggregated average (convert back to 0.0-1.0 range) - headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64)) - headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) - if totalWeight > 0 { - headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) - } -} - func (s *Service) InterfaceUpdated() { for _, cred := range s.allCredentials { extCred, ok := cred.(*externalCredential) diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go new file mode 100644 index 0000000000..7dd0c64115 --- /dev/null +++ b/service/ccm/service_handler.go @@ -0,0 +1,499 @@ +package ccm + +import ( + "bytes" + "context" + "encoding/json" + "io" + "mime" + "net/http" + "strconv" + "strings" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/buf" + + "github.com/anthropics/anthropic-sdk-go" +) + +const ( + contextWindowStandard = 200000 + contextWindowPremium = 1000000 + premiumContextThreshold = 200000 +) + +const ( + weeklyWindowSeconds = 604800 + weeklyWindowMinutes = weeklyWindowSeconds / 60 +) + +func isExtendedContextRequest(betaHeader string) bool { + for _, feature := range strings.Split(betaHeader, ",") { + if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { + return true + } + } + return false +} + +func isFastModeRequest(betaHeader string) bool { + for _, feature := range strings.Split(betaHeader, ",") { + if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") { + return true + } + } + return false +} + +func detectContextWindow(betaHeader string, totalInputTokens int64) int { + if totalInputTokens > premiumContextThreshold { + if isExtendedContextRequest(betaHeader) { + return contextWindowPremium + } + } + return contextWindowStandard +} + +func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { + resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset") + if !exists { + return nil + } + + return &WeeklyCycleHint{ + WindowMinutes: weeklyWindowMinutes, + ResetAt: resetAt.UTC(), + } +} + +func extractCCMSessionID(bodyBytes []byte) string { + var body struct { + Metadata struct { + UserID string `json:"user_id"` + } `json:"metadata"` + } + err := json.Unmarshal(bodyBytes, &body) + if err != nil { + return "" + } + userID := body.Metadata.UserID + sessionIndex := strings.LastIndex(userID, "_session_") + if sessionIndex < 0 { + return "" + } + return userID[sessionIndex+len("_session_"):] +} + +func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := log.ContextWithNewID(r.Context()) + if r.URL.Path == "/ccm/v1/status" { + s.handleStatusEndpoint(w, r) + return + } + + if r.URL.Path == "/ccm/v1/reverse" { + s.handleReverseConnect(ctx, w, r) + return + } + + if !strings.HasPrefix(r.URL.Path, "/v1/") { + writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found") + return + } + + var username string + if len(s.options.Users) > 0 { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + var ok bool + username, ok = s.userManager.Authenticate(clientToken) + if !ok { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + } + + // Always read body to extract model and session ID + var bodyBytes []byte + var requestModel string + var messagesCount int + var sessionID string + + if r.Body != nil { + var err error + bodyBytes, err = io.ReadAll(r.Body) + if err != nil { + s.logger.ErrorContext(ctx, "read request body: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") + return + } + + var request struct { + Model string `json:"model"` + Messages []anthropic.MessageParam `json:"messages"` + } + err = json.Unmarshal(bodyBytes, &request) + if err == nil { + requestModel = request.Model + messagesCount = len(request.Messages) + } + + sessionID = extractCCMSessionID(bodyBytes) + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + // Resolve credential provider and user config + var provider credentialProvider + var userConfig *option.CCMUser + if len(s.options.Users) > 0 { + userConfig = s.userConfigMap[username] + var err error + provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + s.logger.ErrorContext(ctx, "resolve credential: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + } else { + provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + } + if provider == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") + return + } + + provider.pollIfStale(s.ctx) + + anthropicBetaHeader := r.Header.Get("anthropic-beta") + if isFastModeRequest(anthropicBetaHeader) { + if _, isSingle := provider.(*singleCredentialProvider); !isSingle { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "fast mode requests will consume Extra usage, please use a default credential directly") + return + } + } + + selection := credentialSelectionForUser(userConfig) + + selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) + if err != nil { + writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) + return + } + if isNew { + logParts := []any{"assigned credential ", selectedCredential.tagName()} + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if username != "" { + logParts = append(logParts, " by user ", username) + } + if requestModel != "" { + modelDisplay := requestModel + if isExtendedContextRequest(anthropicBetaHeader) { + modelDisplay += "[1m]" + } + logParts = append(logParts, ", model=", modelDisplay) + } + s.logger.DebugContext(ctx, logParts...) + } + + if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "fast mode requests cannot be proxied through external credentials") + return + } + + requestContext := selectedCredential.wrapRequestContext(ctx) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } + defer func() { + requestContext.cancelRequest() + }() + proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if err != nil { + s.logger.ErrorContext(ctx, "create proxy request: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") + return + } + + response, err := selectedCredential.httpClient().Do(proxyRequest) + if err != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") + return + } + writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) + return + } + requestContext.releaseCredentialInterrupt() + + // Transparent 429 retry + for response.StatusCode == http.StatusTooManyRequests { + resetAt := parseRateLimitResetFromHeaders(response.Header) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) + selectedCredential.updateStateFromHeaders(response.Header) + if bodyBytes == nil || nextCredential == nil { + response.Body.Close() + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") + return + } + response.Body.Close() + s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + requestContext.cancelRequest() + requestContext = nextCredential.wrapRequestContext(ctx) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } + retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if buildErr != nil { + s.logger.ErrorContext(ctx, "retry request: ", buildErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) + return + } + retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest) + if retryErr != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") + return + } + s.logger.ErrorContext(ctx, "retry request: ", retryErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) + return + } + requestContext.releaseCredentialInterrupt() + response = retryResponse + selectedCredential = nextCredential + } + defer response.Body.Close() + + selectedCredential.updateStateFromHeaders(response.Header) + + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) + go selectedCredential.pollUsage(s.ctx) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", + "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) + return + } + + // Rewrite response headers for external users + if userConfig != nil && userConfig.ExternalCredential != "" { + s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) + } + + for key, values := range response.Header { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { + w.Header()[key] = values + } + } + w.WriteHeader(response.StatusCode) + + usageTracker := selectedCredential.usageTrackerOrNil() + if usageTracker != nil && response.StatusCode == http.StatusOK { + s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) + } else { + mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) + if err == nil && mediaType != "text/event-stream" { + _, _ = io.Copy(w, response.Body) + return + } + flusher, ok := w.(http.Flusher) + if !ok { + s.logger.ErrorContext(ctx, "streaming not supported") + return + } + buffer := make([]byte, buf.BufferSize) + for { + n, err := response.Body.Read(buffer) + if n > 0 { + _, writeError := w.Write(buffer[:n]) + if writeError != nil { + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) + return + } + flusher.Flush() + } + if err != nil { + return + } + } + } +} + +func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { + weeklyCycleHint := extractWeeklyCycleHint(response.Header) + mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) + isStreaming := err == nil && mediaType == "text/event-stream" + + if !isStreaming { + bodyBytes, err := io.ReadAll(response.Body) + if err != nil { + s.logger.ErrorContext(ctx, "read response body: ", err) + return + } + + var message anthropic.Message + var usage anthropic.Usage + var responseModel string + err = json.Unmarshal(bodyBytes, &message) + if err == nil { + responseModel = string(message.Model) + usage = message.Usage + } + if responseModel == "" { + responseModel = requestModel + } + + if usage.InputTokens > 0 || usage.OutputTokens > 0 { + if responseModel != "" { + totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens + contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) + usageTracker.AddUsageWithCycleHint( + responseModel, + contextWindow, + messagesCount, + usage.InputTokens, + usage.OutputTokens, + usage.CacheReadInputTokens, + usage.CacheCreationInputTokens, + usage.CacheCreation.Ephemeral5mInputTokens, + usage.CacheCreation.Ephemeral1hInputTokens, + username, + time.Now(), + weeklyCycleHint, + ) + } + } + + _, _ = writer.Write(bodyBytes) + return + } + + flusher, ok := writer.(http.Flusher) + if !ok { + s.logger.ErrorContext(ctx, "streaming not supported") + return + } + + var accumulatedUsage anthropic.Usage + var responseModel string + buffer := make([]byte, buf.BufferSize) + var leftover []byte + + for { + n, err := response.Body.Read(buffer) + if n > 0 { + data := append(leftover, buffer[:n]...) + lines := bytes.Split(data, []byte("\n")) + + if err == nil { + leftover = lines[len(lines)-1] + lines = lines[:len(lines)-1] + } else { + leftover = nil + } + + for _, line := range lines { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + + if bytes.HasPrefix(line, []byte("data: ")) { + eventData := bytes.TrimPrefix(line, []byte("data: ")) + if bytes.Equal(eventData, []byte("[DONE]")) { + continue + } + + var event anthropic.MessageStreamEventUnion + err := json.Unmarshal(eventData, &event) + if err != nil { + continue + } + switch event.Type { + case "message_start": + messageStart := event.AsMessageStart() + if messageStart.Message.Model != "" { + responseModel = string(messageStart.Message.Model) + } + if messageStart.Message.Usage.InputTokens > 0 { + accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens + accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens + accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens + accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens + accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens + } + case "message_delta": + messageDelta := event.AsMessageDelta() + if messageDelta.Usage.OutputTokens > 0 { + accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens + } + } + } + } + + _, writeError := writer.Write(buffer[:n]) + if writeError != nil { + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) + return + } + flusher.Flush() + } + + if err != nil { + if responseModel == "" { + responseModel = requestModel + } + + if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 { + if responseModel != "" { + totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens + contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) + usageTracker.AddUsageWithCycleHint( + responseModel, + contextWindow, + messagesCount, + accumulatedUsage.InputTokens, + accumulatedUsage.OutputTokens, + accumulatedUsage.CacheReadInputTokens, + accumulatedUsage.CacheCreationInputTokens, + accumulatedUsage.CacheCreation.Ephemeral5mInputTokens, + accumulatedUsage.CacheCreation.Ephemeral1hInputTokens, + username, + time.Now(), + weeklyCycleHint, + ) + } + } + return + } + } +} diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go new file mode 100644 index 0000000000..3f91b46149 --- /dev/null +++ b/service/ccm/service_status.go @@ -0,0 +1,109 @@ +package ccm + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "github.com/sagernet/sing-box/option" +) + +func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") + return + } + + if len(s.options.Users) == 0 { + writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + username, ok := s.userManager.Authenticate(clientToken) + if !ok { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + + userConfig := s.userConfigMap[username] + if userConfig == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") + return + } + + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + + provider.pollIfStale(r.Context()) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]float64{ + "five_hour_utilization": avgFiveHour, + "weekly_utilization": avgWeekly, + "plan_weight": totalWeight, + }) +} + +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) { + var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 + for _, cred := range provider.allCredentials() { + if !cred.isAvailable() { + continue + } + if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { + continue + } + if !userConfig.AllowExternalUsage && cred.isExternal() { + continue + } + weight := cred.planWeight() + remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() + if remaining5h < 0 { + remaining5h = 0 + } + remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() + if remainingWeekly < 0 { + remainingWeekly = 0 + } + totalWeightedRemaining5h += remaining5h * weight + totalWeightedRemainingWeekly += remainingWeekly * weight + totalWeight += weight + } + if totalWeight == 0 { + return 100, 100, 0 + } + return 100 - totalWeightedRemaining5h/totalWeight, + 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight +} + +func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) { + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) + if err != nil { + return + } + + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + + headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64)) + headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) + if totalWeight > 0 { + headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) + } +} diff --git a/service/ccm/service_user.go b/service/ccm/service_user.go index 94637ed814..149894c048 100644 --- a/service/ccm/service_user.go +++ b/service/ccm/service_user.go @@ -7,13 +7,13 @@ import ( ) type UserManager struct { - accessMutex sync.RWMutex + access sync.RWMutex tokenMap map[string]string } func (m *UserManager) UpdateUsers(users []option.CCMUser) { - m.accessMutex.Lock() - defer m.accessMutex.Unlock() + m.access.Lock() + defer m.access.Unlock() tokenMap := make(map[string]string, len(users)) for _, user := range users { tokenMap[user.Token] = user.Name @@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.CCMUser) { } func (m *UserManager) Authenticate(token string) (string, bool) { - m.accessMutex.RLock() + m.access.RLock() username, found := m.tokenMap[token] - m.accessMutex.RUnlock() + m.access.RUnlock() return username, found } diff --git a/service/ocm/credential.go b/service/ocm/credential.go index bb240b5aba..27a8894705 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -1,225 +1,194 @@ package ocm import ( - "bytes" "context" - "encoding/json" - "io" "net/http" - "os" - "os/user" - "path/filepath" + "strconv" + "strings" + "sync" "time" - E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" ) const ( - oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" - oauth2TokenURL = "https://auth.openai.com/oauth/token" - openaiAPIBaseURL = "https://api.openai.com" - chatGPTBackendURL = "https://chatgpt.com/backend-api/codex" - tokenRefreshIntervalDays = 8 + defaultPollInterval = 60 * time.Minute + failedPollRetryInterval = time.Minute + httpRetryMaxBackoff = 5 * time.Minute ) -func getRealUser() (*user.User, error) { - if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { - sudoUserInfo, err := user.Lookup(sudoUser) +const ( + httpRetryMaxAttempts = 3 + httpRetryInitialDelay = 200 * time.Millisecond +) + +const sessionExpiry = 24 * time.Hour + +func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { + var lastError error + for attempt := range httpRetryMaxAttempts { + if attempt > 0 { + delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) + select { + case <-ctx.Done(): + return nil, lastError + case <-time.After(delay): + } + } + request, err := buildRequest() + if err != nil { + return nil, err + } + response, err := client.Do(request) if err == nil { - return sudoUserInfo, nil + return response, nil + } + lastError = err + if ctx.Err() != nil { + return nil, lastError } } - return user.Current() + return nil, lastError } -func getDefaultCredentialsPath() (string, error) { - if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" { - return filepath.Join(codexHome, "auth.json"), nil - } - userInfo, err := getRealUser() - if err != nil { - return "", err - } - return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil +type credentialState struct { + fiveHourUtilization float64 + fiveHourReset time.Time + weeklyUtilization float64 + weeklyReset time.Time + hardRateLimited bool + rateLimitResetAt time.Time + accountType string + remotePlanWeight float64 + lastUpdated time.Time + consecutivePollFailures int + unavailable bool + lastCredentialLoadAttempt time.Time + lastCredentialLoadError string } -func readCredentialsFromFile(path string) (*oauthCredentials, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - var credentials oauthCredentials - err = json.Unmarshal(data, &credentials) - if err != nil { - return nil, err - } - return &credentials, nil +type credentialRequestContext struct { + context.Context + releaseOnce sync.Once + cancelOnce sync.Once + releaseFuncs []func() bool + cancelFunc context.CancelFunc } -func checkCredentialFileWritable(path string) error { - file, err := os.OpenFile(path, os.O_WRONLY, 0) - if err != nil { - return err - } - return file.Close() +func (c *credentialRequestContext) addInterruptLink(stop func() bool) { + c.releaseFuncs = append(c.releaseFuncs, stop) } -func writeCredentialsToFile(credentials *oauthCredentials, path string) error { - data, err := json.MarshalIndent(credentials, "", " ") - if err != nil { - return err - } - return os.WriteFile(path, data, 0o600) +func (c *credentialRequestContext) releaseCredentialInterrupt() { + c.releaseOnce.Do(func() { + for _, f := range c.releaseFuncs { + f() + } + }) } -type oauthCredentials struct { - APIKey string `json:"OPENAI_API_KEY,omitempty"` - Tokens *tokenData `json:"tokens,omitempty"` - LastRefresh *time.Time `json:"last_refresh,omitempty"` +func (c *credentialRequestContext) cancelRequest() { + c.releaseCredentialInterrupt() + c.cancelOnce.Do(c.cancelFunc) } -type tokenData struct { - IDToken string `json:"id_token,omitempty"` - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - AccountID string `json:"account_id,omitempty"` +type credential interface { + tagName() string + isAvailable() bool + isUsable() bool + isExternal() bool + fiveHourUtilization() float64 + weeklyUtilization() float64 + fiveHourCap() float64 + weeklyCap() float64 + planWeight() float64 + weeklyResetTime() time.Time + markRateLimited(resetAt time.Time) + earliestReset() time.Time + unavailableError() error + + getAccessToken() (string, error) + buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) + updateStateFromHeaders(header http.Header) + + wrapRequestContext(ctx context.Context) *credentialRequestContext + interruptConnections() + + setOnBecameUnusable(fn func()) + start() error + pollUsage(ctx context.Context) + lastUpdatedTime() time.Time + pollBackoff(base time.Duration) time.Duration + usageTrackerOrNil() *AggregatedUsage + httpClient() *http.Client + close() + + // OCM-specific + ocmDialer() N.Dialer + ocmIsAPIKeyMode() bool + ocmGetAccountID() string + ocmGetBaseURL() string } -func (c *oauthCredentials) isAPIKeyMode() bool { - return c.APIKey != "" -} +type credentialSelectionScope string -func (c *oauthCredentials) getAccessToken() string { - if c.APIKey != "" { - return c.APIKey - } - if c.Tokens != nil { - return c.Tokens.AccessToken - } - return "" -} +const ( + credentialSelectionScopeAll credentialSelectionScope = "all" + credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" +) -func (c *oauthCredentials) getAccountID() string { - if c.Tokens != nil { - return c.Tokens.AccountID - } - return "" +type credentialSelection struct { + scope credentialSelectionScope + filter func(credential) bool } -func (c *oauthCredentials) needsRefresh() bool { - if c.APIKey != "" { - return false - } - if c.Tokens == nil || c.Tokens.RefreshToken == "" { - return false - } - if c.LastRefresh == nil { - return true - } - return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour +func (s credentialSelection) allows(cred credential) bool { + return s.filter == nil || s.filter(cred) } -func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { - if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" { - return nil, E.New("refresh token is empty") - } - - requestBody, err := json.Marshal(map[string]string{ - "grant_type": "refresh_token", - "refresh_token": credentials.Tokens.RefreshToken, - "client_id": oauth2ClientID, - "scope": "openid profile email", - }) - if err != nil { - return nil, E.Cause(err, "marshal request") - } - - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) - if err != nil { - return nil, err - } - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Accept", "application/json") - return request, nil - }) - if err != nil { - return nil, err - } - defer response.Body.Close() - - if response.StatusCode == http.StatusTooManyRequests { - body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) - } - if response.StatusCode != http.StatusOK { - body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh failed: ", response.Status, " ", string(body)) - } - - var tokenResponse struct { - IDToken string `json:"id_token"` - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - } - err = json.NewDecoder(response.Body).Decode(&tokenResponse) - if err != nil { - return nil, E.Cause(err, "decode response") +func (s credentialSelection) scopeOrDefault() credentialSelectionScope { + if s.scope == "" { + return credentialSelectionScopeAll } + return s.scope +} - newCredentials := *credentials - if newCredentials.Tokens == nil { - newCredentials.Tokens = &tokenData{} - } - if tokenResponse.IDToken != "" { - newCredentials.Tokens.IDToken = tokenResponse.IDToken - } - if tokenResponse.AccessToken != "" { - newCredentials.Tokens.AccessToken = tokenResponse.AccessToken - } - if tokenResponse.RefreshToken != "" { - newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken +func normalizeRateLimitIdentifier(limitIdentifier string) string { + trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier)) + if trimmedIdentifier == "" { + return "" } - now := time.Now() - newCredentials.LastRefresh = &now - - return &newCredentials, nil + return strings.ReplaceAll(trimmedIdentifier, "_", "-") } -func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { - if credentials == nil { - return nil - } - cloned := *credentials - if credentials.Tokens != nil { - clonedTokens := *credentials.Tokens - cloned.Tokens = &clonedTokens +func parseInt64Header(headers http.Header, headerName string) (int64, bool) { + headerValue := strings.TrimSpace(headers.Get(headerName)) + if headerValue == "" { + return 0, false } - if credentials.LastRefresh != nil { - lastRefresh := *credentials.LastRefresh - cloned.LastRefresh = &lastRefresh + parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64) + if parseError != nil { + return 0, false } - return &cloned + return parsedValue, true } -func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { - if left == nil || right == nil { - return left == right - } - if left.APIKey != right.APIKey { - return false - } - if (left.Tokens == nil) != (right.Tokens == nil) { - return false - } - if left.Tokens != nil && *left.Tokens != *right.Tokens { - return false - } - if (left.LastRefresh == nil) != (right.LastRefresh == nil) { - return false +func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time { + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier != "" { + resetHeader := "x-" + activeLimitIdentifier + "-primary-reset-at" + if resetStr := headers.Get(resetHeader); resetStr != "" { + value, err := strconv.ParseInt(resetStr, 10, 64) + if err == nil { + return time.Unix(value, 0) + } + } } - if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) { - return false + if retryAfter := headers.Get("Retry-After"); retryAfter != "" { + seconds, err := strconv.ParseInt(retryAfter, 10, 64) + if err == nil { + return time.Now().Add(time.Duration(seconds) * time.Second) + } } - return true + return time.Now().Add(5 * time.Minute) } diff --git a/service/ocm/credential_builder.go b/service/ocm/credential_builder.go new file mode 100644 index 0000000000..5faaf67c67 --- /dev/null +++ b/service/ocm/credential_builder.go @@ -0,0 +1,223 @@ +package ocm + +import ( + "context" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func buildOCMCredentialProviders( + ctx context.Context, + options option.OCMServiceOptions, + logger log.ContextLogger, +) (map[string]credentialProvider, []credential, error) { + allCredentialMap := make(map[string]credential) + var allCreds []credential + providers := make(map[string]credentialProvider) + + // Pass 1: create default and external credentials + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "default": + cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + if err != nil { + return nil, nil, err + } + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + case "external": + cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) + if err != nil { + return nil, nil, err + } + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + } + } + + // Pass 2: create balancer providers + for _, credOpt := range options.Credentials { + if credOpt.Type == "balancer" { + subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) + } + } + + return providers, allCreds, nil +} + +func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { + credentials := make([]credential, 0, len(tags)) + for _, tag := range tags { + cred, exists := allCredentials[tag] + if !exists { + return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) + } + credentials = append(credentials, cred) + } + if len(credentials) == 0 { + return nil, E.New("credential ", parentTag, " has no sub-credentials") + } + return credentials, nil +} + +func validateOCMOptions(options option.OCMServiceOptions) error { + hasCredentials := len(options.Credentials) > 0 + hasLegacyPath := options.CredentialPath != "" + hasLegacyUsages := options.UsagesPath != "" + hasLegacyDetour := options.Detour != "" + + if hasCredentials && hasLegacyPath { + return E.New("credential_path and credentials are mutually exclusive") + } + if hasCredentials && hasLegacyUsages { + return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") + } + if hasCredentials && hasLegacyDetour { + return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") + } + + if hasCredentials { + tags := make(map[string]bool) + credentialTypes := make(map[string]string) + for _, cred := range options.Credentials { + if tags[cred.Tag] { + return E.New("duplicate credential tag: ", cred.Tag) + } + tags[cred.Tag] = true + credentialTypes[cred.Tag] = cred.Type + if cred.Type == "default" || cred.Type == "" { + if cred.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") + } + if cred.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") + } + if cred.DefaultOptions.Limit5h > 100 { + return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") + } + if cred.DefaultOptions.LimitWeekly > 100 { + return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") + } + if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { + return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") + } + if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { + return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") + } + } + if cred.Type == "external" { + if cred.ExternalOptions.Token == "" { + return E.New("credential ", cred.Tag, ": external credential requires token") + } + if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { + return E.New("credential ", cred.Tag, ": reverse external credential requires url") + } + } + if cred.Type == "balancer" { + switch cred.BalancerOptions.Strategy { + case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: + default: + return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) + } + if cred.BalancerOptions.RebalanceThreshold < 0 { + return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") + } + } + } + + for _, user := range options.Users { + if user.Credential == "" { + return E.New("user ", user.Name, " must specify credential in multi-credential mode") + } + if !tags[user.Credential] { + return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + } + if user.ExternalCredential != "" { + if !tags[user.ExternalCredential] { + return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) + } + if credentialTypes[user.ExternalCredential] != "external" { + return E.New("user ", user.Name, ": external_credential must reference an external type credential") + } + } + } + } + + return nil +} + +func validateOCMCompositeCredentialModes( + options option.OCMServiceOptions, + providers map[string]credentialProvider, +) error { + for _, credOpt := range options.Credentials { + if credOpt.Type != "balancer" { + continue + } + + provider, exists := providers[credOpt.Tag] + if !exists { + return E.New("unknown credential: ", credOpt.Tag) + } + + for _, subCred := range provider.allCredentials() { + if !subCred.isAvailable() { + continue + } + if subCred.ocmIsAPIKeyMode() { + return E.New( + "credential ", credOpt.Tag, + " references API key default credential ", subCred.tagName(), + "; balancer and fallback only support OAuth default credentials", + ) + } + } + } + + return nil +} + +func credentialForUser( + userConfigMap map[string]*option.OCMUser, + providers map[string]credentialProvider, + legacyProvider credentialProvider, + username string, +) (credentialProvider, error) { + if legacyProvider != nil { + return legacyProvider, nil + } + userConfig, exists := userConfigMap[username] + if !exists { + return nil, E.New("no credential mapping for user: ", username) + } + provider, exists := providers[userConfig.Credential] + if !exists { + return nil, E.New("unknown credential: ", userConfig.Credential) + } + return provider, nil +} + +func noUserCredentialProvider( + providers map[string]credentialProvider, + legacyProvider credentialProvider, + options option.OCMServiceOptions, +) credentialProvider { + if legacyProvider != nil { + return legacyProvider + } + if len(options.Credentials) > 0 { + tag := options.Credentials[0].Tag + return providers[tag] + } + return nil +} diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go new file mode 100644 index 0000000000..b82af9d20f --- /dev/null +++ b/service/ocm/credential_default.go @@ -0,0 +1,749 @@ +package ocm + +import ( + "bytes" + "context" + stdTLS "crypto/tls" + "encoding/json" + "io" + "net" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/sagernet/fswatch" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" +) + +type defaultCredential struct { + tag string + serviceContext context.Context + credentialPath string + credentialFilePath string + credentials *oauthCredentials + access sync.RWMutex + state credentialState + stateAccess sync.RWMutex + pollAccess sync.Mutex + reloadAccess sync.Mutex + watcherAccess sync.Mutex + cap5h float64 + capWeekly float64 + usageTracker *AggregatedUsage + dialer N.Dialer + forwardHTTPClient *http.Client + logger log.ContextLogger + watcher *fswatch.Watcher + watcherRetryAt time.Time + + // Connection interruption + onBecameUnusable func() + interrupted bool + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex +} + +func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + reserve5h := options.Reserve5h + if reserve5h == 0 { + reserve5h = 1 + } + reserveWeekly := options.ReserveWeekly + if reserveWeekly == 0 { + reserveWeekly = 1 + } + var cap5h float64 + if options.Limit5h > 0 { + cap5h = float64(options.Limit5h) + } else { + cap5h = float64(100 - reserve5h) + } + var capWeekly float64 + if options.LimitWeekly > 0 { + capWeekly = float64(options.LimitWeekly) + } else { + capWeekly = float64(100 - reserveWeekly) + } + requestContext, cancelRequests := context.WithCancel(context.Background()) + credential := &defaultCredential{ + tag: tag, + serviceContext: ctx, + credentialPath: options.CredentialPath, + cap5h: cap5h, + capWeekly: capWeekly, + dialer: credentialDialer, + forwardHTTPClient: httpClient, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + } + if options.UsagesPath != "" { + credential.usageTracker = &AggregatedUsage{ + LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + return credential, nil +} + +func (c *defaultCredential) start() error { + credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) + if err != nil { + return E.Cause(err, "resolve credential path for ", c.tag) + } + c.credentialFilePath = credentialFilePath + err = c.ensureCredentialWatcher() + if err != nil { + c.logger.Debug("start credential watcher for ", c.tag, ": ", err) + } + err = c.reloadCredentials(true) + if err != nil { + c.logger.Warn("initial credential load for ", c.tag, ": ", err) + } + if c.usageTracker != nil { + err = c.usageTracker.Load() + if err != nil { + c.logger.Warn("load usage statistics for ", c.tag, ": ", err) + } + } + return nil +} + +func (c *defaultCredential) setOnBecameUnusable(fn func()) { + c.onBecameUnusable = fn +} + +func (c *defaultCredential) tagName() string { + return c.tag +} + +func (c *defaultCredential) isExternal() bool { + return false +} + +func (c *defaultCredential) getAccessToken() (string, error) { + c.retryCredentialReloadIfNeeded() + + c.access.RLock() + if c.credentials != nil && !c.credentials.needsRefresh() { + token := c.credentials.getAccessToken() + c.access.RUnlock() + return token, nil + } + c.access.RUnlock() + + err := c.reloadCredentials(true) + if err == nil { + c.access.RLock() + if c.credentials != nil && !c.credentials.needsRefresh() { + token := c.credentials.getAccessToken() + c.access.RUnlock() + return token, nil + } + c.access.RUnlock() + } + + c.access.Lock() + defer c.access.Unlock() + + if c.credentials == nil { + return "", c.unavailableError() + } + if !c.credentials.needsRefresh() { + return c.credentials.getAccessToken(), nil + } + + err = platformCanWriteCredentials(c.credentialPath) + if err != nil { + return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") + } + + baseCredentials := cloneCredentials(c.credentials) + newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials) + if err != nil { + return "", err + } + + latestCredentials, latestErr := platformReadCredentials(c.credentialPath) + if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { + c.credentials = latestCredentials + c.stateAccess.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.checkTransitionLocked() + c.stateAccess.Unlock() + if !latestCredentials.needsRefresh() { + return latestCredentials.getAccessToken(), nil + } + return "", E.New("credential ", c.tag, " changed while refreshing") + } + + c.credentials = newCredentials + c.stateAccess.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.checkTransitionLocked() + c.stateAccess.Unlock() + + err = platformWriteCredentials(newCredentials, c.credentialPath) + if err != nil { + c.logger.Error("persist refreshed token for ", c.tag, ": ", err) + } + + return newCredentials.getAccessToken(), nil +} + +func (c *defaultCredential) getAccountID() string { + c.access.RLock() + defer c.access.RUnlock() + if c.credentials == nil { + return "" + } + return c.credentials.getAccountID() +} + +func (c *defaultCredential) isAPIKeyMode() bool { + c.access.RLock() + defer c.access.RUnlock() + if c.credentials == nil { + return false + } + return c.credentials.isAPIKeyMode() +} + +func (c *defaultCredential) getBaseURL() string { + if c.isAPIKeyMode() { + return openaiAPIBaseURL + } + return chatGPTBackendURL +} + +func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { + c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + hadData := false + + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier == "" { + activeLimitIdentifier = "codex" + } + + fiveHourResetChanged := false + fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at") + if fiveHourResetAt != "" { + value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) + if err == nil { + hadData = true + newReset := time.Unix(value, 0) + if newReset.After(c.state.fiveHourReset) { + fiveHourResetChanged = true + c.state.fiveHourReset = newReset + } + } + } + fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") + if fiveHourPercent != "" { + value, err := strconv.ParseFloat(fiveHourPercent, 64) + if err == nil { + hadData = true + if value >= c.state.fiveHourUtilization || fiveHourResetChanged { + c.state.fiveHourUtilization = value + } + } + } + + weeklyResetChanged := false + weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at") + if weeklyResetAt != "" { + value, err := strconv.ParseInt(weeklyResetAt, 10, 64) + if err == nil { + hadData = true + newReset := time.Unix(value, 0) + if newReset.After(c.state.weeklyReset) { + weeklyResetChanged = true + c.state.weeklyReset = newReset + } + } + } + weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") + if weeklyPercent != "" { + value, err := strconv.ParseFloat(weeklyPercent, 64) + if err == nil { + hadData = true + if value >= c.state.weeklyUtilization || weeklyResetChanged { + c.state.weeklyUtilization = value + } + } + } + if hadData { + c.state.consecutivePollFailures = 0 + c.state.lastUpdated = time.Now() + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) + } + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) markRateLimited(resetAt time.Time) { + c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) + c.stateAccess.Lock() + c.state.hardRateLimited = true + c.state.rateLimitResetAt = resetAt + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) isUsable() bool { + c.retryCredentialReloadIfNeeded() + + c.stateAccess.RLock() + if c.state.unavailable { + c.stateAccess.RUnlock() + return false + } + if c.state.consecutivePollFailures > 0 { + c.stateAccess.RUnlock() + return false + } + if c.state.hardRateLimited { + if time.Now().Before(c.state.rateLimitResetAt) { + c.stateAccess.RUnlock() + return false + } + c.stateAccess.RUnlock() + c.stateAccess.Lock() + if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + usable := c.checkReservesLocked() + c.stateAccess.Unlock() + return usable + } + usable := c.checkReservesLocked() + c.stateAccess.RUnlock() + return usable +} + +func (c *defaultCredential) checkReservesLocked() bool { + if c.state.fiveHourUtilization >= c.cap5h { + return false + } + if c.state.weeklyUtilization >= c.capWeekly { + return false + } + return true +} + +// checkTransitionLocked detects usable->unusable transition. +// Must be called with stateAccess write lock held. +func (c *defaultCredential) checkTransitionLocked() bool { + unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0 + if unusable && !c.interrupted { + c.interrupted = true + return true + } + if !unusable && c.interrupted { + c.interrupted = false + } + return false +} + +func (c *defaultCredential) interruptConnections() { + c.logger.Warn("interrupting connections for ", c.tag) + c.requestAccess.Lock() + c.cancelRequests() + c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) + c.requestAccess.Unlock() + if c.onBecameUnusable != nil { + c.onBecameUnusable() + } +} + +func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { + c.requestAccess.Lock() + credentialContext := c.requestContext + c.requestAccess.Unlock() + derived, cancel := context.WithCancel(parent) + stop := context.AfterFunc(credentialContext, func() { + cancel() + }) + return &credentialRequestContext{ + Context: derived, + releaseFuncs: []func() bool{stop}, + cancelFunc: cancel, + } +} + +func (c *defaultCredential) fiveHourUtilization() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourUtilization +} + +func (c *defaultCredential) weeklyUtilization() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.weeklyUtilization +} + +func (c *defaultCredential) planWeight() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return ocmPlanWeight(c.state.accountType) +} + +func (c *defaultCredential) weeklyResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.weeklyReset +} + +func (c *defaultCredential) isAvailable() bool { + c.retryCredentialReloadIfNeeded() + + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return !c.state.unavailable +} + +func (c *defaultCredential) unavailableError() error { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if !c.state.unavailable { + return nil + } + if c.state.lastCredentialLoadError == "" { + return E.New("credential ", c.tag, " is unavailable") + } + return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError) +} + +func (c *defaultCredential) lastUpdatedTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.lastUpdated +} + +func (c *defaultCredential) markUsagePollAttempted() { + c.stateAccess.Lock() + defer c.stateAccess.Unlock() + c.state.lastUpdated = time.Now() +} + +func (c *defaultCredential) incrementPollFailures() { + c.stateAccess.Lock() + c.state.consecutivePollFailures++ + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { + c.stateAccess.RLock() + failures := c.state.consecutivePollFailures + c.stateAccess.RUnlock() + if failures <= 0 { + return baseInterval + } + backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) + if backoff > httpRetryMaxBackoff { + return httpRetryMaxBackoff + } + return backoff +} + +func (c *defaultCredential) isPollBackoffAtCap() bool { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + failures := c.state.consecutivePollFailures + return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff +} + +func (c *defaultCredential) earliestReset() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if c.state.unavailable { + return time.Time{} + } + if c.state.hardRateLimited { + return c.state.rateLimitResetAt + } + earliest := c.state.fiveHourReset + if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) { + earliest = c.state.weeklyReset + } + return earliest +} + +func (c *defaultCredential) fiveHourCap() float64 { + return c.cap5h +} + +func (c *defaultCredential) weeklyCap() float64 { + return c.capWeekly +} + +func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage { + return c.usageTracker +} + +func (c *defaultCredential) httpClient() *http.Client { + return c.forwardHTTPClient +} + +func (c *defaultCredential) ocmDialer() N.Dialer { + return c.dialer +} + +func (c *defaultCredential) ocmIsAPIKeyMode() bool { + return c.isAPIKeyMode() +} + +func (c *defaultCredential) ocmGetAccountID() string { + return c.getAccountID() +} + +func (c *defaultCredential) ocmGetBaseURL() string { + return c.getBaseURL() +} + +func (c *defaultCredential) pollUsage(ctx context.Context) { + if !c.pollAccess.TryLock() { + return + } + defer c.pollAccess.Unlock() + defer c.markUsagePollAttempted() + + c.retryCredentialReloadIfNeeded() + if !c.isAvailable() { + return + } + if c.isAPIKeyMode() { + return + } + + accessToken, err := c.getAccessToken() + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": get token: ", err) + } + c.incrementPollFailures() + return + } + + var usageURL string + if c.isAPIKeyMode() { + usageURL = openaiAPIBaseURL + "/api/codex/usage" + } else { + usageURL = strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage" + } + + accountID := c.getAccountID() + pollClient := &http.Client{ + Transport: c.forwardHTTPClient.Transport, + Timeout: 5 * time.Second, + } + + response, err := doHTTPWithRetry(ctx, pollClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+accessToken) + if accountID != "" { + request.Header.Set("ChatGPT-Account-Id", accountID) + } + return request, nil + }) + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": ", err) + } + c.incrementPollFailures() + return + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + if response.StatusCode == http.StatusTooManyRequests { + c.logger.Warn("poll usage for ", c.tag, ": rate limited") + } + body, _ := io.ReadAll(response.Body) + c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + c.incrementPollFailures() + return + } + + type usageWindow struct { + UsedPercent float64 `json:"used_percent"` + ResetAt int64 `json:"reset_at"` + } + var usageResponse struct { + PlanType string `json:"plan_type"` + RateLimit *struct { + PrimaryWindow *usageWindow `json:"primary_window"` + SecondaryWindow *usageWindow `json:"secondary_window"` + } `json:"rate_limit"` + } + err = json.NewDecoder(response.Body).Decode(&usageResponse) + if err != nil { + c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + c.incrementPollFailures() + return + } + + c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + c.state.consecutivePollFailures = 0 + if usageResponse.RateLimit != nil { + if w := usageResponse.RateLimit.PrimaryWindow; w != nil { + c.state.fiveHourUtilization = w.UsedPercent + if w.ResetAt > 0 { + c.state.fiveHourReset = time.Unix(w.ResetAt, 0) + } + } + if w := usageResponse.RateLimit.SecondaryWindow; w != nil { + c.state.weeklyUtilization = w.UsedPercent + if w.ResetAt > 0 { + c.state.weeklyReset = time.Unix(w.ResetAt, 0) + } + } + } + if usageResponse.PlanType != "" { + c.state.accountType = usageResponse.PlanType + } + if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) + } + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) { + accessToken, err := c.getAccessToken() + if err != nil { + return nil, E.Cause(err, "get access token for ", c.tag) + } + + path := original.URL.Path + var proxyPath string + if c.isAPIKeyMode() { + proxyPath = path + } else { + proxyPath = strings.TrimPrefix(path, "/v1") + } + + proxyURL := c.getBaseURL() + proxyPath + if original.URL.RawQuery != "" { + proxyURL += "?" + original.URL.RawQuery + } + + var body io.Reader + if bodyBytes != nil { + body = bytes.NewReader(bodyBytes) + } else { + body = original.Body + } + proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body) + if err != nil { + return nil, err + } + + for key, values := range original.Header { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { + proxyRequest.Header[key] = values + } + } + + for key, values := range serviceHeaders { + proxyRequest.Header.Del(key) + proxyRequest.Header[key] = values + } + proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) + + if accountID := c.getAccountID(); accountID != "" { + proxyRequest.Header.Set("ChatGPT-Account-Id", accountID) + } + + return proxyRequest, nil +} + +func (c *defaultCredential) close() { + if c.watcher != nil { + err := c.watcher.Close() + if err != nil { + c.logger.Error("close credential watcher for ", c.tag, ": ", err) + } + } + if c.usageTracker != nil { + c.usageTracker.cancelPendingSave() + err := c.usageTracker.Save() + if err != nil { + c.logger.Error("save usage statistics for ", c.tag, ": ", err) + } + } +} diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 0e0556be71..968bf904d7 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -30,17 +30,17 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - credDialer N.Dialer - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + credDialer N.Dialer + forwardHTTPClient *http.Client + state credentialState + stateAccess sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -147,7 +147,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx // Receiver mode: no URL, wait for reverse connection cred.baseURL = reverseProxyBaseURL cred.credDialer = reverseSessionDialer{credential: cred} - cred.httpClient = &http.Client{ + cred.forwardHTTPClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -211,11 +211,11 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx Time: ntp.TimeFuncFromContext(ctx), } } - cred.httpClient = &http.Client{Transport: transport} + cred.forwardHTTPClient = &http.Client{Transport: transport} } else { // Normal mode: standard HTTP client for proxying cred.credDialer = credentialDialer - cred.httpClient = &http.Client{Transport: transport} + cred.forwardHTTPClient = &http.Client{Transport: transport} cred.reverseCredDialer = reverseSessionDialer{credential: cred} cred.reverseHttpClient = &http.Client{ Transport: &http.Transport{ @@ -273,39 +273,39 @@ func (c *externalCredential) isUsable() bool { if !c.isAvailable() { return false } - c.stateMutex.RLock() + c.stateAccess.RLock() if c.state.consecutivePollFailures > 0 { - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return false } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return false } - c.stateMutex.RUnlock() - c.stateMutex.Lock() + c.stateAccess.RUnlock() + c.stateAccess.Lock() if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 - c.stateMutex.Unlock() + c.stateAccess.Unlock() return usable } usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return usable } func (c *externalCredential) fiveHourUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.fiveHourUtilization } func (c *externalCredential) weeklyUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.weeklyUtilization } @@ -318,8 +318,8 @@ func (c *externalCredential) weeklyCap() float64 { } func (c *externalCredential) planWeight() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() if c.state.remotePlanWeight > 0 { return c.state.remotePlanWeight } @@ -327,26 +327,26 @@ func (c *externalCredential) planWeight() float64 { } func (c *externalCredential) weeklyResetTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.weeklyReset } func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.hardRateLimited = true c.state.rateLimitResetAt = resetAt shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } } func (c *externalCredential) earliestReset() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() if c.state.hardRateLimited { return c.state.rateLimitResetAt } @@ -432,7 +432,7 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con } func (c *externalCredential) updateStateFromHeaders(headers http.Header) { - c.stateMutex.Lock() + c.stateAccess.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization @@ -494,7 +494,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } @@ -569,9 +569,9 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp } } // Forward transport with retries - if c.httpClient != nil { + if c.forwardHTTPClient != nil { forwardClient := &http.Client{ - Transport: c.httpClient.Transport, + Transport: c.forwardHTTPClient.Transport, Timeout: 5 * time.Second, } return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL)) @@ -602,10 +602,10 @@ func (c *externalCredential) pollUsage(ctx context.Context) { // 404 means the remote does not have a status endpoint yet; // usage will be updated passively from response headers. if response.StatusCode == http.StatusNotFound { - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.consecutivePollFailures = 0 c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() } else { c.incrementPollFailures() } @@ -624,7 +624,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { return } - c.stateMutex.Lock() + c.stateAccess.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization @@ -645,28 +645,28 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } } func (c *externalCredential) lastUpdatedTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.lastUpdated } func (c *externalCredential) markUsagePollAttempted() { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() + c.stateAccess.Lock() + defer c.stateAccess.Unlock() c.state.lastUpdated = time.Now() } func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration { - c.stateMutex.RLock() + c.stateAccess.RLock() failures := c.state.consecutivePollFailures - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if failures <= 0 { return baseInterval } @@ -678,17 +678,17 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati } func (c *externalCredential) isPollBackoffAtCap() bool { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() failures := c.state.consecutivePollFailures return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff } func (c *externalCredential) incrementPollFailures() { - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.consecutivePollFailures++ shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } @@ -698,14 +698,14 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { return c.usageTracker } -func (c *externalCredential) httpTransport() *http.Client { +func (c *externalCredential) httpClient() *http.Client { if c.reverseHttpClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { return c.reverseHttpClient } } - return c.httpClient + return c.forwardHTTPClient } func (c *externalCredential) ocmDialer() N.Dialer { diff --git a/service/ocm/credential_file.go b/service/ocm/credential_file.go index b8252904ea..861dbdb864 100644 --- a/service/ocm/credential_file.go +++ b/service/ocm/credential_file.go @@ -62,10 +62,10 @@ func (c *defaultCredential) ensureCredentialWatcher() error { } func (c *defaultCredential) retryCredentialReloadIfNeeded() { - c.stateMutex.RLock() + c.stateAccess.RLock() unavailable := c.state.unavailable lastAttempt := c.state.lastCredentialLoadAttempt - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if !unavailable { return } @@ -84,10 +84,10 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.reloadAccess.Lock() defer c.reloadAccess.Unlock() - c.stateMutex.RLock() + c.stateAccess.RLock() unavailable := c.state.unavailable lastAttempt := c.state.lastCredentialLoadAttempt - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if !force { if !unavailable { return nil @@ -97,39 +97,39 @@ func (c *defaultCredential) reloadCredentials(force bool) error { } } - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.lastCredentialLoadAttempt = time.Now() - c.stateMutex.Unlock() + c.stateAccess.Unlock() credentials, err := platformReadCredentials(c.credentialPath) if err != nil { return c.markCredentialsUnavailable(E.Cause(err, "read credentials")) } - c.accessMutex.Lock() + c.access.Lock() c.credentials = credentials - c.accessMutex.Unlock() + c.access.Unlock() - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.unavailable = false c.state.lastCredentialLoadError = "" c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() return nil } func (c *defaultCredential) markCredentialsUnavailable(err error) error { - c.accessMutex.Lock() + c.access.Lock() hadCredentials := c.credentials != nil c.credentials = nil - c.accessMutex.Unlock() + c.access.Unlock() - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.unavailable = true c.state.lastCredentialLoadError = err.Error() shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt && hadCredentials { c.interruptConnections() diff --git a/service/ocm/credential_oauth.go b/service/ocm/credential_oauth.go new file mode 100644 index 0000000000..bb240b5aba --- /dev/null +++ b/service/ocm/credential_oauth.go @@ -0,0 +1,225 @@ +package ocm + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "os" + "os/user" + "path/filepath" + "time" + + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + oauth2TokenURL = "https://auth.openai.com/oauth/token" + openaiAPIBaseURL = "https://api.openai.com" + chatGPTBackendURL = "https://chatgpt.com/backend-api/codex" + tokenRefreshIntervalDays = 8 +) + +func getRealUser() (*user.User, error) { + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + sudoUserInfo, err := user.Lookup(sudoUser) + if err == nil { + return sudoUserInfo, nil + } + } + return user.Current() +} + +func getDefaultCredentialsPath() (string, error) { + if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" { + return filepath.Join(codexHome, "auth.json"), nil + } + userInfo, err := getRealUser() + if err != nil { + return "", err + } + return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil +} + +func readCredentialsFromFile(path string) (*oauthCredentials, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var credentials oauthCredentials + err = json.Unmarshal(data, &credentials) + if err != nil { + return nil, err + } + return &credentials, nil +} + +func checkCredentialFileWritable(path string) error { + file, err := os.OpenFile(path, os.O_WRONLY, 0) + if err != nil { + return err + } + return file.Close() +} + +func writeCredentialsToFile(credentials *oauthCredentials, path string) error { + data, err := json.MarshalIndent(credentials, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o600) +} + +type oauthCredentials struct { + APIKey string `json:"OPENAI_API_KEY,omitempty"` + Tokens *tokenData `json:"tokens,omitempty"` + LastRefresh *time.Time `json:"last_refresh,omitempty"` +} + +type tokenData struct { + IDToken string `json:"id_token,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + AccountID string `json:"account_id,omitempty"` +} + +func (c *oauthCredentials) isAPIKeyMode() bool { + return c.APIKey != "" +} + +func (c *oauthCredentials) getAccessToken() string { + if c.APIKey != "" { + return c.APIKey + } + if c.Tokens != nil { + return c.Tokens.AccessToken + } + return "" +} + +func (c *oauthCredentials) getAccountID() string { + if c.Tokens != nil { + return c.Tokens.AccountID + } + return "" +} + +func (c *oauthCredentials) needsRefresh() bool { + if c.APIKey != "" { + return false + } + if c.Tokens == nil || c.Tokens.RefreshToken == "" { + return false + } + if c.LastRefresh == nil { + return true + } + return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour +} + +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { + if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" { + return nil, E.New("refresh token is empty") + } + + requestBody, err := json.Marshal(map[string]string{ + "grant_type": "refresh_token", + "refresh_token": credentials.Tokens.RefreshToken, + "client_id": oauth2ClientID, + "scope": "openid profile email", + }) + if err != nil { + return nil, E.Cause(err, "marshal request") + } + + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") + return request, nil + }) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode == http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + } + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh failed: ", response.Status, " ", string(body)) + } + + var tokenResponse struct { + IDToken string `json:"id_token"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + } + err = json.NewDecoder(response.Body).Decode(&tokenResponse) + if err != nil { + return nil, E.Cause(err, "decode response") + } + + newCredentials := *credentials + if newCredentials.Tokens == nil { + newCredentials.Tokens = &tokenData{} + } + if tokenResponse.IDToken != "" { + newCredentials.Tokens.IDToken = tokenResponse.IDToken + } + if tokenResponse.AccessToken != "" { + newCredentials.Tokens.AccessToken = tokenResponse.AccessToken + } + if tokenResponse.RefreshToken != "" { + newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken + } + now := time.Now() + newCredentials.LastRefresh = &now + + return &newCredentials, nil +} + +func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { + if credentials == nil { + return nil + } + cloned := *credentials + if credentials.Tokens != nil { + clonedTokens := *credentials.Tokens + cloned.Tokens = &clonedTokens + } + if credentials.LastRefresh != nil { + lastRefresh := *credentials.LastRefresh + cloned.LastRefresh = &lastRefresh + } + return &cloned +} + +func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { + if left == nil || right == nil { + return left == right + } + if left.APIKey != right.APIKey { + return false + } + if (left.Tokens == nil) != (right.Tokens == nil) { + return false + } + if left.Tokens != nil && *left.Tokens != *right.Tokens { + return false + } + if (left.LastRefresh == nil) != (right.LastRefresh == nil) { + return false + } + if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) { + return false + } + return true +} diff --git a/service/ocm/credential_provider.go b/service/ocm/credential_provider.go new file mode 100644 index 0000000000..53383e3686 --- /dev/null +++ b/service/ocm/credential_provider.go @@ -0,0 +1,411 @@ +package ocm + +import ( + "context" + "math/rand/v2" + "sync" + "sync/atomic" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" +) + +type credentialProvider interface { + selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) + onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential + linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool + pollIfStale(ctx context.Context) + allCredentials() []credential + close() +} + +type singleCredentialProvider struct { + cred credential + sessionAccess sync.RWMutex + sessions map[string]time.Time +} + +func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if !selection.allows(p.cred) { + return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") + } + if !p.cred.isAvailable() { + return nil, false, p.cred.unavailableError() + } + if !p.cred.isUsable() { + return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") + } + var isNew bool + if sessionID != "" { + p.sessionAccess.Lock() + if p.sessions == nil { + p.sessions = make(map[string]time.Time) + } + _, exists := p.sessions[sessionID] + if !exists { + p.sessions[sessionID] = time.Now() + isNew = true + } + p.sessionAccess.Unlock() + } + return p.cred, isNew, nil +} + +func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { + cred.markRateLimited(resetAt) + return nil +} + +func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionAccess.Lock() + for id, createdAt := range p.sessions { + if now.Sub(createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() + + if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { + p.cred.pollUsage(ctx) + } +} + +func (p *singleCredentialProvider) allCredentials() []credential { + return []credential{p.cred} +} + +func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} + +func (p *singleCredentialProvider) close() {} + +type sessionEntry struct { + tag string + selectionScope credentialSelectionScope + createdAt time.Time +} + +type credentialInterruptKey struct { + tag string + selectionScope credentialSelectionScope +} + +type credentialInterruptEntry struct { + context context.Context + cancel context.CancelFunc +} + +type balancerProvider struct { + credentials []credential + strategy string + roundRobinIndex atomic.Uint64 + pollInterval time.Duration + rebalanceThreshold float64 + sessionAccess sync.RWMutex + sessions map[string]sessionEntry + interruptAccess sync.Mutex + credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry + logger log.ContextLogger +} + +func compositeCredentialSelectable(cred credential) bool { + return !cred.ocmIsAPIKeyMode() +} + +func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &balancerProvider{ + credentials: credentials, + strategy: strategy, + pollInterval: pollInterval, + rebalanceThreshold: rebalanceThreshold, + sessions: make(map[string]sessionEntry), + credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), + logger: logger, + } +} + +func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if p.strategy == C.BalancerStrategyFallback { + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allRateLimitedError(p.credentials) + } + return best, false, nil + } + + selectionScope := selection.scopeOrDefault() + if sessionID != "" { + p.sessionAccess.RLock() + entry, exists := p.sessions[sessionID] + p.sessionAccess.RUnlock() + if exists { + if entry.selectionScope == selectionScope { + for _, cred := range p.credentials { + if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { + better := p.pickLeastUsed(selection.filter) + if better != nil && better.tagName() != cred.tagName() { + effectiveThreshold := p.rebalanceThreshold / cred.planWeight() + delta := cred.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", cred.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", cred.planWeight(), ")") + p.rebalanceCredential(cred.tagName(), selectionScope) + break + } + } + } + return cred, false, nil + } + } + } + p.sessionAccess.Lock() + delete(p.sessions, sessionID) + p.sessionAccess.Unlock() + } + } + + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allRateLimitedError(p.credentials) + } + + isNew := sessionID != "" + if isNew { + p.sessionAccess.Lock() + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selectionScope, + createdAt: time.Now(), + } + p.sessionAccess.Unlock() + } + return best, isNew, nil +} + +func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { + key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} + p.interruptAccess.Lock() + if entry, loaded := p.credentialInterrupts[key]; loaded { + entry.cancel() + } + ctx, cancel := context.WithCancel(context.Background()) + p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} + p.interruptAccess.Unlock() + + p.sessionAccess.Lock() + for id, entry := range p.sessions { + if entry.tag == tag && entry.selectionScope == selectionScope { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() +} + +func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { + if p.strategy == C.BalancerStrategyFallback { + return func() bool { return false } + } + key := credentialInterruptKey{ + tag: cred.tagName(), + selectionScope: selection.scopeOrDefault(), + } + p.interruptAccess.Lock() + entry, loaded := p.credentialInterrupts[key] + if !loaded { + ctx, cancel := context.WithCancel(context.Background()) + entry = credentialInterruptEntry{context: ctx, cancel: cancel} + p.credentialInterrupts[key] = entry + } + p.interruptAccess.Unlock() + return context.AfterFunc(entry.context, onInterrupt) +} + +func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { + cred.markRateLimited(resetAt) + if p.strategy == C.BalancerStrategyFallback { + return p.pickCredential(selection.filter) + } + if sessionID != "" { + p.sessionAccess.Lock() + delete(p.sessions, sessionID) + p.sessionAccess.Unlock() + } + + best := p.pickCredential(selection.filter) + if best != nil && sessionID != "" { + p.sessionAccess.Lock() + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selection.scopeOrDefault(), + createdAt: time.Now(), + } + p.sessionAccess.Unlock() + } + return best +} + +func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { + switch p.strategy { + case C.BalancerStrategyRoundRobin: + return p.pickRoundRobin(filter) + case C.BalancerStrategyRandom: + return p.pickRandom(filter) + case C.BalancerStrategyFallback: + return p.pickFallback(filter) + default: + return p.pickLeastUsed(filter) + } +} + +func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if !compositeCredentialSelectable(cred) { + continue + } + if cred.isUsable() { + return cred + } + } + return nil +} + +const weeklyWindowHours = 7 * 24 + +func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { + var best credential + bestScore := float64(-1) + now := time.Now() + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if !compositeCredentialSelectable(cred) { + continue + } + if !cred.isUsable() { + continue + } + remaining := cred.weeklyCap() - cred.weeklyUtilization() + score := remaining * cred.planWeight() + resetTime := cred.weeklyResetTime() + if !resetTime.IsZero() { + timeUntilReset := resetTime.Sub(now) + if timeUntilReset < time.Hour { + timeUntilReset = time.Hour + } + score *= weeklyWindowHours / timeUntilReset.Hours() + } + if score > bestScore { + bestScore = score + best = cred + } + } + return best +} + +func ocmPlanWeight(accountType string) float64 { + switch accountType { + case "pro": + return 10 + case "plus": + return 1 + default: + return 1 + } +} + +func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { + start := int(p.roundRobinIndex.Add(1) - 1) + count := len(p.credentials) + for offset := range count { + candidate := p.credentials[(start+offset)%count] + if filter != nil && !filter(candidate) { + continue + } + if !compositeCredentialSelectable(candidate) { + continue + } + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { + var usable []credential + for _, candidate := range p.credentials { + if filter != nil && !filter(candidate) { + continue + } + if !compositeCredentialSelectable(candidate) { + continue + } + if candidate.isUsable() { + usable = append(usable, candidate) + } + } + if len(usable) == 0 { + return nil + } + return usable[rand.IntN(len(usable))] +} + +func (p *balancerProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionAccess.Lock() + for id, entry := range p.sessions { + if now.Sub(entry.createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() + + for _, cred := range p.credentials { + if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { + cred.pollUsage(ctx) + } + } +} + +func (p *balancerProvider) allCredentials() []credential { + return p.credentials +} + +func (p *balancerProvider) close() {} + +func allRateLimitedError(credentials []credential) error { + var hasUnavailable bool + var earliest time.Time + for _, cred := range credentials { + if cred.unavailableError() != nil { + hasUnavailable = true + continue + } + resetAt := cred.earliestReset() + if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { + earliest = resetAt + } + } + if hasUnavailable { + return E.New("all credentials unavailable") + } + if earliest.IsZero() { + return E.New("all credentials rate-limited") + } + return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) +} diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go deleted file mode 100644 index 181132f09d..0000000000 --- a/service/ocm/credential_state.go +++ /dev/null @@ -1,1524 +0,0 @@ -package ocm - -import ( - "bytes" - "context" - stdTLS "crypto/tls" - "encoding/json" - "io" - "math/rand/v2" - "net" - "net/http" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/sagernet/fswatch" - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/dialer" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/ntp" -) - -const ( - defaultPollInterval = 60 * time.Minute - failedPollRetryInterval = time.Minute - httpRetryMaxBackoff = 5 * time.Minute -) - -const ( - httpRetryMaxAttempts = 3 - httpRetryInitialDelay = 200 * time.Millisecond -) - -func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { - var lastError error - for attempt := range httpRetryMaxAttempts { - if attempt > 0 { - delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) - select { - case <-ctx.Done(): - return nil, lastError - case <-time.After(delay): - } - } - request, err := buildRequest() - if err != nil { - return nil, err - } - response, err := client.Do(request) - if err == nil { - return response, nil - } - lastError = err - if ctx.Err() != nil { - return nil, lastError - } - } - return nil, lastError -} - -type credentialState struct { - fiveHourUtilization float64 - fiveHourReset time.Time - weeklyUtilization float64 - weeklyReset time.Time - hardRateLimited bool - rateLimitResetAt time.Time - accountType string - remotePlanWeight float64 - lastUpdated time.Time - consecutivePollFailures int - unavailable bool - lastCredentialLoadAttempt time.Time - lastCredentialLoadError string -} - -type defaultCredential struct { - tag string - serviceContext context.Context - credentialPath string - credentialFilePath string - credentials *oauthCredentials - accessMutex sync.RWMutex - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - reloadAccess sync.Mutex - watcherAccess sync.Mutex - cap5h float64 - capWeekly float64 - usageTracker *AggregatedUsage - dialer N.Dialer - httpClient *http.Client - logger log.ContextLogger - watcher *fswatch.Watcher - watcherRetryAt time.Time - - // Connection interruption - onBecameUnusable func() - interrupted bool - requestContext context.Context - cancelRequests context.CancelFunc - requestAccess sync.Mutex -} - -type credentialRequestContext struct { - context.Context - releaseOnce sync.Once - cancelOnce sync.Once - releaseFuncs []func() bool - cancelFunc context.CancelFunc -} - -func (c *credentialRequestContext) addInterruptLink(stop func() bool) { - c.releaseFuncs = append(c.releaseFuncs, stop) -} - -func (c *credentialRequestContext) releaseCredentialInterrupt() { - c.releaseOnce.Do(func() { - for _, f := range c.releaseFuncs { - f() - } - }) -} - -func (c *credentialRequestContext) cancelRequest() { - c.releaseCredentialInterrupt() - c.cancelOnce.Do(c.cancelFunc) -} - -type credential interface { - tagName() string - isAvailable() bool - isUsable() bool - isExternal() bool - fiveHourUtilization() float64 - weeklyUtilization() float64 - fiveHourCap() float64 - weeklyCap() float64 - planWeight() float64 - weeklyResetTime() time.Time - markRateLimited(resetAt time.Time) - earliestReset() time.Time - unavailableError() error - - getAccessToken() (string, error) - buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) - updateStateFromHeaders(header http.Header) - - wrapRequestContext(ctx context.Context) *credentialRequestContext - interruptConnections() - - setOnBecameUnusable(fn func()) - start() error - pollUsage(ctx context.Context) - lastUpdatedTime() time.Time - pollBackoff(base time.Duration) time.Duration - usageTrackerOrNil() *AggregatedUsage - httpTransport() *http.Client - close() - - // OCM-specific - ocmDialer() N.Dialer - ocmIsAPIKeyMode() bool - ocmGetAccountID() string - ocmGetBaseURL() string -} - -func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { - credentialDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) - if err != nil { - return nil, E.Cause(err, "create dialer for credential ", tag) - } - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSClientConfig: &stdTLS.Config{ - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - reserve5h := options.Reserve5h - if reserve5h == 0 { - reserve5h = 1 - } - reserveWeekly := options.ReserveWeekly - if reserveWeekly == 0 { - reserveWeekly = 1 - } - var cap5h float64 - if options.Limit5h > 0 { - cap5h = float64(options.Limit5h) - } else { - cap5h = float64(100 - reserve5h) - } - var capWeekly float64 - if options.LimitWeekly > 0 { - capWeekly = float64(options.LimitWeekly) - } else { - capWeekly = float64(100 - reserveWeekly) - } - requestContext, cancelRequests := context.WithCancel(context.Background()) - credential := &defaultCredential{ - tag: tag, - serviceContext: ctx, - credentialPath: options.CredentialPath, - cap5h: cap5h, - capWeekly: capWeekly, - dialer: credentialDialer, - httpClient: httpClient, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - } - if options.UsagesPath != "" { - credential.usageTracker = &AggregatedUsage{ - LastUpdated: time.Now(), - Combinations: make([]CostCombination, 0), - filePath: options.UsagesPath, - logger: logger, - } - } - return credential, nil -} - -func (c *defaultCredential) start() error { - credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) - if err != nil { - return E.Cause(err, "resolve credential path for ", c.tag) - } - c.credentialFilePath = credentialFilePath - err = c.ensureCredentialWatcher() - if err != nil { - c.logger.Debug("start credential watcher for ", c.tag, ": ", err) - } - err = c.reloadCredentials(true) - if err != nil { - c.logger.Warn("initial credential load for ", c.tag, ": ", err) - } - if c.usageTracker != nil { - err = c.usageTracker.Load() - if err != nil { - c.logger.Warn("load usage statistics for ", c.tag, ": ", err) - } - } - return nil -} - -func (c *defaultCredential) getAccessToken() (string, error) { - c.retryCredentialReloadIfNeeded() - - c.accessMutex.RLock() - if c.credentials != nil && !c.credentials.needsRefresh() { - token := c.credentials.getAccessToken() - c.accessMutex.RUnlock() - return token, nil - } - c.accessMutex.RUnlock() - - err := c.reloadCredentials(true) - if err == nil { - c.accessMutex.RLock() - if c.credentials != nil && !c.credentials.needsRefresh() { - token := c.credentials.getAccessToken() - c.accessMutex.RUnlock() - return token, nil - } - c.accessMutex.RUnlock() - } - - c.accessMutex.Lock() - defer c.accessMutex.Unlock() - - if c.credentials == nil { - return "", c.unavailableError() - } - if !c.credentials.needsRefresh() { - return c.credentials.getAccessToken(), nil - } - - err = platformCanWriteCredentials(c.credentialPath) - if err != nil { - return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") - } - - baseCredentials := cloneCredentials(c.credentials) - newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials) - if err != nil { - return "", err - } - - latestCredentials, latestErr := platformReadCredentials(c.credentialPath) - if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { - c.credentials = latestCredentials - c.stateMutex.Lock() - c.state.unavailable = false - c.state.lastCredentialLoadAttempt = time.Now() - c.state.lastCredentialLoadError = "" - c.checkTransitionLocked() - c.stateMutex.Unlock() - if !latestCredentials.needsRefresh() { - return latestCredentials.getAccessToken(), nil - } - return "", E.New("credential ", c.tag, " changed while refreshing") - } - - c.credentials = newCredentials - c.stateMutex.Lock() - c.state.unavailable = false - c.state.lastCredentialLoadAttempt = time.Now() - c.state.lastCredentialLoadError = "" - c.checkTransitionLocked() - c.stateMutex.Unlock() - - err = platformWriteCredentials(newCredentials, c.credentialPath) - if err != nil { - c.logger.Error("persist refreshed token for ", c.tag, ": ", err) - } - - return newCredentials.getAccessToken(), nil -} - -func (c *defaultCredential) getAccountID() string { - c.accessMutex.RLock() - defer c.accessMutex.RUnlock() - if c.credentials == nil { - return "" - } - return c.credentials.getAccountID() -} - -func (c *defaultCredential) isAPIKeyMode() bool { - c.accessMutex.RLock() - defer c.accessMutex.RUnlock() - if c.credentials == nil { - return false - } - return c.credentials.isAPIKeyMode() -} - -func (c *defaultCredential) getBaseURL() string { - if c.isAPIKeyMode() { - return openaiAPIBaseURL - } - return chatGPTBackendURL -} - -func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { - c.stateMutex.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() - oldFiveHour := c.state.fiveHourUtilization - oldWeekly := c.state.weeklyUtilization - hadData := false - - activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) - if activeLimitIdentifier == "" { - activeLimitIdentifier = "codex" - } - - fiveHourResetChanged := false - fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at") - if fiveHourResetAt != "" { - value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) - if err == nil { - hadData = true - newReset := time.Unix(value, 0) - if newReset.After(c.state.fiveHourReset) { - fiveHourResetChanged = true - c.state.fiveHourReset = newReset - } - } - } - fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") - if fiveHourPercent != "" { - value, err := strconv.ParseFloat(fiveHourPercent, 64) - if err == nil { - hadData = true - if value >= c.state.fiveHourUtilization || fiveHourResetChanged { - c.state.fiveHourUtilization = value - } - } - } - - weeklyResetChanged := false - weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at") - if weeklyResetAt != "" { - value, err := strconv.ParseInt(weeklyResetAt, 10, 64) - if err == nil { - hadData = true - newReset := time.Unix(value, 0) - if newReset.After(c.state.weeklyReset) { - weeklyResetChanged = true - c.state.weeklyReset = newReset - } - } - } - weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") - if weeklyPercent != "" { - value, err := strconv.ParseFloat(weeklyPercent, 64) - if err == nil { - hadData = true - if value >= c.state.weeklyUtilization || weeklyResetChanged { - c.state.weeklyUtilization = value - } - } - } - if hadData { - c.state.consecutivePollFailures = 0 - c.state.lastUpdated = time.Now() - } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - resetSuffix := "" - if !c.state.weeklyReset.IsZero() { - resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) - } - c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) - } - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) markRateLimited(resetAt time.Time) { - c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) - c.stateMutex.Lock() - c.state.hardRateLimited = true - c.state.rateLimitResetAt = resetAt - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) isUsable() bool { - c.retryCredentialReloadIfNeeded() - - c.stateMutex.RLock() - if c.state.unavailable { - c.stateMutex.RUnlock() - return false - } - if c.state.consecutivePollFailures > 0 { - c.stateMutex.RUnlock() - return false - } - if c.state.hardRateLimited { - if time.Now().Before(c.state.rateLimitResetAt) { - c.stateMutex.RUnlock() - return false - } - c.stateMutex.RUnlock() - c.stateMutex.Lock() - if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { - c.state.hardRateLimited = false - } - usable := c.checkReservesLocked() - c.stateMutex.Unlock() - return usable - } - usable := c.checkReservesLocked() - c.stateMutex.RUnlock() - return usable -} - -func (c *defaultCredential) checkReservesLocked() bool { - if c.state.fiveHourUtilization >= c.cap5h { - return false - } - if c.state.weeklyUtilization >= c.capWeekly { - return false - } - return true -} - -// checkTransitionLocked detects usable→unusable transition. -// Must be called with stateMutex write lock held. -func (c *defaultCredential) checkTransitionLocked() bool { - unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0 - if unusable && !c.interrupted { - c.interrupted = true - return true - } - if !unusable && c.interrupted { - c.interrupted = false - } - return false -} - -func (c *defaultCredential) interruptConnections() { - c.logger.Warn("interrupting connections for ", c.tag) - c.requestAccess.Lock() - c.cancelRequests() - c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) - c.requestAccess.Unlock() - if c.onBecameUnusable != nil { - c.onBecameUnusable() - } -} - -func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { - c.requestAccess.Lock() - credentialContext := c.requestContext - c.requestAccess.Unlock() - derived, cancel := context.WithCancel(parent) - stop := context.AfterFunc(credentialContext, func() { - cancel() - }) - return &credentialRequestContext{ - Context: derived, - releaseFuncs: []func() bool{stop}, - cancelFunc: cancel, - } -} - -func (c *defaultCredential) weeklyUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.weeklyUtilization -} - -func (c *defaultCredential) planWeight() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return ocmPlanWeight(c.state.accountType) -} - -func (c *defaultCredential) weeklyResetTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.weeklyReset -} - -func (c *defaultCredential) isAvailable() bool { - c.retryCredentialReloadIfNeeded() - - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return !c.state.unavailable -} - -func (c *defaultCredential) unavailableError() error { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - if !c.state.unavailable { - return nil - } - if c.state.lastCredentialLoadError == "" { - return E.New("credential ", c.tag, " is unavailable") - } - return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError) -} - -func (c *defaultCredential) lastUpdatedTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.lastUpdated -} - -func (c *defaultCredential) markUsagePollAttempted() { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - c.state.lastUpdated = time.Now() -} - -func (c *defaultCredential) incrementPollFailures() { - c.stateMutex.Lock() - c.state.consecutivePollFailures++ - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { - c.stateMutex.RLock() - failures := c.state.consecutivePollFailures - c.stateMutex.RUnlock() - if failures <= 0 { - return baseInterval - } - backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) - if backoff > httpRetryMaxBackoff { - return httpRetryMaxBackoff - } - return backoff -} - -func (c *defaultCredential) isPollBackoffAtCap() bool { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - failures := c.state.consecutivePollFailures - return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff -} - -func (c *defaultCredential) earliestReset() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - if c.state.unavailable { - return time.Time{} - } - if c.state.hardRateLimited { - return c.state.rateLimitResetAt - } - earliest := c.state.fiveHourReset - if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) { - earliest = c.state.weeklyReset - } - return earliest -} - -func (c *defaultCredential) pollUsage(ctx context.Context) { - if !c.pollAccess.TryLock() { - return - } - defer c.pollAccess.Unlock() - defer c.markUsagePollAttempted() - - c.retryCredentialReloadIfNeeded() - if !c.isAvailable() { - return - } - if c.isAPIKeyMode() { - return - } - - accessToken, err := c.getAccessToken() - if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": get token: ", err) - } - c.incrementPollFailures() - return - } - - var usageURL string - if c.isAPIKeyMode() { - usageURL = openaiAPIBaseURL + "/api/codex/usage" - } else { - usageURL = strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage" - } - - accountID := c.getAccountID() - httpClient := &http.Client{ - Transport: c.httpClient.Transport, - Timeout: 5 * time.Second, - } - - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil) - if err != nil { - return nil, err - } - request.Header.Set("Authorization", "Bearer "+accessToken) - if accountID != "" { - request.Header.Set("ChatGPT-Account-Id", accountID) - } - return request, nil - }) - if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": ", err) - } - c.incrementPollFailures() - return - } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - if response.StatusCode == http.StatusTooManyRequests { - c.logger.Warn("poll usage for ", c.tag, ": rate limited") - } - body, _ := io.ReadAll(response.Body) - c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - c.incrementPollFailures() - return - } - - type usageWindow struct { - UsedPercent float64 `json:"used_percent"` - ResetAt int64 `json:"reset_at"` - } - var usageResponse struct { - PlanType string `json:"plan_type"` - RateLimit *struct { - PrimaryWindow *usageWindow `json:"primary_window"` - SecondaryWindow *usageWindow `json:"secondary_window"` - } `json:"rate_limit"` - } - err = json.NewDecoder(response.Body).Decode(&usageResponse) - if err != nil { - c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.incrementPollFailures() - return - } - - c.stateMutex.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() - oldFiveHour := c.state.fiveHourUtilization - oldWeekly := c.state.weeklyUtilization - c.state.consecutivePollFailures = 0 - if usageResponse.RateLimit != nil { - if w := usageResponse.RateLimit.PrimaryWindow; w != nil { - c.state.fiveHourUtilization = w.UsedPercent - if w.ResetAt > 0 { - c.state.fiveHourReset = time.Unix(w.ResetAt, 0) - } - } - if w := usageResponse.RateLimit.SecondaryWindow; w != nil { - c.state.weeklyUtilization = w.UsedPercent - if w.ResetAt > 0 { - c.state.weeklyReset = time.Unix(w.ResetAt, 0) - } - } - } - if usageResponse.PlanType != "" { - c.state.accountType = usageResponse.PlanType - } - if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { - c.state.hardRateLimited = false - } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - resetSuffix := "" - if !c.state.weeklyReset.IsZero() { - resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) - } - c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) - } - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) close() { - if c.watcher != nil { - err := c.watcher.Close() - if err != nil { - c.logger.Error("close credential watcher for ", c.tag, ": ", err) - } - } - if c.usageTracker != nil { - c.usageTracker.cancelPendingSave() - err := c.usageTracker.Save() - if err != nil { - c.logger.Error("save usage statistics for ", c.tag, ": ", err) - } - } -} - -func (c *defaultCredential) setOnBecameUnusable(fn func()) { - c.onBecameUnusable = fn -} - -func (c *defaultCredential) tagName() string { - return c.tag -} - -func (c *defaultCredential) isExternal() bool { - return false -} - -func (c *defaultCredential) fiveHourUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.fiveHourUtilization -} - -func (c *defaultCredential) fiveHourCap() float64 { - return c.cap5h -} - -func (c *defaultCredential) weeklyCap() float64 { - return c.capWeekly -} - -func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage { - return c.usageTracker -} - -func (c *defaultCredential) httpTransport() *http.Client { - return c.httpClient -} - -func (c *defaultCredential) ocmDialer() N.Dialer { - return c.dialer -} - -func (c *defaultCredential) ocmIsAPIKeyMode() bool { - return c.isAPIKeyMode() -} - -func (c *defaultCredential) ocmGetAccountID() string { - return c.getAccountID() -} - -func (c *defaultCredential) ocmGetBaseURL() string { - return c.getBaseURL() -} - -func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) { - accessToken, err := c.getAccessToken() - if err != nil { - return nil, E.Cause(err, "get access token for ", c.tag) - } - - path := original.URL.Path - var proxyPath string - if c.isAPIKeyMode() { - proxyPath = path - } else { - proxyPath = strings.TrimPrefix(path, "/v1") - } - - proxyURL := c.getBaseURL() + proxyPath - if original.URL.RawQuery != "" { - proxyURL += "?" + original.URL.RawQuery - } - - var body io.Reader - if bodyBytes != nil { - body = bytes.NewReader(bodyBytes) - } else { - body = original.Body - } - proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body) - if err != nil { - return nil, err - } - - for key, values := range original.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { - proxyRequest.Header[key] = values - } - } - - for key, values := range serviceHeaders { - proxyRequest.Header.Del(key) - proxyRequest.Header[key] = values - } - proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) - - if accountID := c.getAccountID(); accountID != "" { - proxyRequest.Header.Set("ChatGPT-Account-Id", accountID) - } - - return proxyRequest, nil -} - -type credentialProvider interface { - selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) - onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential - linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool - pollIfStale(ctx context.Context) - allCredentials() []credential - close() -} - -type credentialSelectionScope string - -const ( - credentialSelectionScopeAll credentialSelectionScope = "all" - credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" -) - -type credentialSelection struct { - scope credentialSelectionScope - filter func(credential) bool -} - -func (s credentialSelection) allows(cred credential) bool { - return s.filter == nil || s.filter(cred) -} - -func (s credentialSelection) scopeOrDefault() credentialSelectionScope { - if s.scope == "" { - return credentialSelectionScopeAll - } - return s.scope -} - -type singleCredentialProvider struct { - cred credential - sessionAccess sync.RWMutex - sessions map[string]time.Time -} - -func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { - if !selection.allows(p.cred) { - return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") - } - if !p.cred.isAvailable() { - return nil, false, p.cred.unavailableError() - } - if !p.cred.isUsable() { - return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") - } - var isNew bool - if sessionID != "" { - p.sessionAccess.Lock() - if p.sessions == nil { - p.sessions = make(map[string]time.Time) - } - _, exists := p.sessions[sessionID] - if !exists { - p.sessions[sessionID] = time.Now() - isNew = true - } - p.sessionAccess.Unlock() - } - return p.cred, isNew, nil -} - -func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { - cred.markRateLimited(resetAt) - return nil -} - -func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { - now := time.Now() - p.sessionAccess.Lock() - for id, createdAt := range p.sessions { - if now.Sub(createdAt) > sessionExpiry { - delete(p.sessions, id) - } - } - p.sessionAccess.Unlock() - - if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { - p.cred.pollUsage(ctx) - } -} - -func (p *singleCredentialProvider) allCredentials() []credential { - return []credential{p.cred} -} - -func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { - return func() bool { - return false - } -} - -func (p *singleCredentialProvider) close() {} - -const sessionExpiry = 24 * time.Hour - -type sessionEntry struct { - tag string - selectionScope credentialSelectionScope - createdAt time.Time -} - -type credentialInterruptKey struct { - tag string - selectionScope credentialSelectionScope -} - -type credentialInterruptEntry struct { - context context.Context - cancel context.CancelFunc -} - -type balancerProvider struct { - credentials []credential - strategy string - roundRobinIndex atomic.Uint64 - pollInterval time.Duration - rebalanceThreshold float64 - sessionMutex sync.RWMutex - sessions map[string]sessionEntry - interruptAccess sync.Mutex - credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry - logger log.ContextLogger -} - -func compositeCredentialSelectable(cred credential) bool { - return !cred.ocmIsAPIKeyMode() -} - -func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { - if pollInterval <= 0 { - pollInterval = defaultPollInterval - } - return &balancerProvider{ - credentials: credentials, - strategy: strategy, - pollInterval: pollInterval, - rebalanceThreshold: rebalanceThreshold, - sessions: make(map[string]sessionEntry), - credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), - logger: logger, - } -} - -func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { - if p.strategy == C.BalancerStrategyFallback { - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allRateLimitedError(p.credentials) - } - return best, false, nil - } - - selectionScope := selection.scopeOrDefault() - if sessionID != "" { - p.sessionMutex.RLock() - entry, exists := p.sessions[sessionID] - p.sessionMutex.RUnlock() - if exists { - if entry.selectionScope == selectionScope { - for _, cred := range p.credentials { - if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { - better := p.pickLeastUsed(selection.filter) - if better != nil && better.tagName() != cred.tagName() { - effectiveThreshold := p.rebalanceThreshold / cred.planWeight() - delta := cred.weeklyUtilization() - better.weeklyUtilization() - if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", cred.tagName(), - ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", cred.planWeight(), ")") - p.rebalanceCredential(cred.tagName(), selectionScope) - break - } - } - } - return cred, false, nil - } - } - } - p.sessionMutex.Lock() - delete(p.sessions, sessionID) - p.sessionMutex.Unlock() - } - } - - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allRateLimitedError(p.credentials) - } - - isNew := sessionID != "" - if isNew { - p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{ - tag: best.tagName(), - selectionScope: selectionScope, - createdAt: time.Now(), - } - p.sessionMutex.Unlock() - } - return best, isNew, nil -} - -func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { - key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} - p.interruptAccess.Lock() - if entry, loaded := p.credentialInterrupts[key]; loaded { - entry.cancel() - } - ctx, cancel := context.WithCancel(context.Background()) - p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} - p.interruptAccess.Unlock() - - p.sessionMutex.Lock() - for id, entry := range p.sessions { - if entry.tag == tag && entry.selectionScope == selectionScope { - delete(p.sessions, id) - } - } - p.sessionMutex.Unlock() -} - -func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { - if p.strategy == C.BalancerStrategyFallback { - return func() bool { return false } - } - key := credentialInterruptKey{ - tag: cred.tagName(), - selectionScope: selection.scopeOrDefault(), - } - p.interruptAccess.Lock() - entry, loaded := p.credentialInterrupts[key] - if !loaded { - ctx, cancel := context.WithCancel(context.Background()) - entry = credentialInterruptEntry{context: ctx, cancel: cancel} - p.credentialInterrupts[key] = entry - } - p.interruptAccess.Unlock() - return context.AfterFunc(entry.context, onInterrupt) -} - -func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { - cred.markRateLimited(resetAt) - if p.strategy == C.BalancerStrategyFallback { - return p.pickCredential(selection.filter) - } - if sessionID != "" { - p.sessionMutex.Lock() - delete(p.sessions, sessionID) - p.sessionMutex.Unlock() - } - - best := p.pickCredential(selection.filter) - if best != nil && sessionID != "" { - p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{ - tag: best.tagName(), - selectionScope: selection.scopeOrDefault(), - createdAt: time.Now(), - } - p.sessionMutex.Unlock() - } - return best -} - -func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { - switch p.strategy { - case C.BalancerStrategyRoundRobin: - return p.pickRoundRobin(filter) - case C.BalancerStrategyRandom: - return p.pickRandom(filter) - case C.BalancerStrategyFallback: - return p.pickFallback(filter) - default: - return p.pickLeastUsed(filter) - } -} - -func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { - continue - } - if !compositeCredentialSelectable(cred) { - continue - } - if cred.isUsable() { - return cred - } - } - return nil -} - -func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { - var best credential - bestScore := float64(-1) - now := time.Now() - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { - continue - } - if !compositeCredentialSelectable(cred) { - continue - } - if !cred.isUsable() { - continue - } - remaining := cred.weeklyCap() - cred.weeklyUtilization() - score := remaining * cred.planWeight() - resetTime := cred.weeklyResetTime() - if !resetTime.IsZero() { - timeUntilReset := resetTime.Sub(now) - if timeUntilReset < time.Hour { - timeUntilReset = time.Hour - } - score *= weeklyWindowDuration / timeUntilReset.Hours() - } - if score > bestScore { - bestScore = score - best = cred - } - } - return best -} - -const weeklyWindowDuration = 7 * 24 // hours - -func ocmPlanWeight(accountType string) float64 { - switch accountType { - case "pro": - return 10 - case "plus": - return 1 - default: - return 1 - } -} - -func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { - start := int(p.roundRobinIndex.Add(1) - 1) - count := len(p.credentials) - for offset := range count { - candidate := p.credentials[(start+offset)%count] - if filter != nil && !filter(candidate) { - continue - } - if !compositeCredentialSelectable(candidate) { - continue - } - if candidate.isUsable() { - return candidate - } - } - return nil -} - -func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { - var usable []credential - for _, candidate := range p.credentials { - if filter != nil && !filter(candidate) { - continue - } - if !compositeCredentialSelectable(candidate) { - continue - } - if candidate.isUsable() { - usable = append(usable, candidate) - } - } - if len(usable) == 0 { - return nil - } - return usable[rand.IntN(len(usable))] -} - -func (p *balancerProvider) pollIfStale(ctx context.Context) { - now := time.Now() - p.sessionMutex.Lock() - for id, entry := range p.sessions { - if now.Sub(entry.createdAt) > sessionExpiry { - delete(p.sessions, id) - } - } - p.sessionMutex.Unlock() - - for _, cred := range p.credentials { - if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { - cred.pollUsage(ctx) - } - } -} - -func (p *balancerProvider) allCredentials() []credential { - return p.credentials -} - -func (p *balancerProvider) close() {} - -func allRateLimitedError(credentials []credential) error { - var hasUnavailable bool - var earliest time.Time - for _, cred := range credentials { - if cred.unavailableError() != nil { - hasUnavailable = true - continue - } - resetAt := cred.earliestReset() - if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { - earliest = resetAt - } - } - if hasUnavailable { - return E.New("all credentials unavailable") - } - if earliest.IsZero() { - return E.New("all credentials rate-limited") - } - return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) -} - -func buildOCMCredentialProviders( - ctx context.Context, - options option.OCMServiceOptions, - logger log.ContextLogger, -) (map[string]credentialProvider, []credential, error) { - allCredentialMap := make(map[string]credential) - var allCreds []credential - providers := make(map[string]credentialProvider) - - // Pass 1: create default and external credentials - for _, credOpt := range options.Credentials { - switch credOpt.Type { - case "default": - cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) - if err != nil { - return nil, nil, err - } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} - case "external": - cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) - if err != nil { - return nil, nil, err - } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} - } - } - - // Pass 2: create balancer providers - for _, credOpt := range options.Credentials { - if credOpt.Type == "balancer" { - subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) - if err != nil { - return nil, nil, err - } - providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) - } - } - - return providers, allCreds, nil -} - -func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { - credentials := make([]credential, 0, len(tags)) - for _, tag := range tags { - cred, exists := allCredentials[tag] - if !exists { - return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) - } - credentials = append(credentials, cred) - } - if len(credentials) == 0 { - return nil, E.New("credential ", parentTag, " has no sub-credentials") - } - return credentials, nil -} - -func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time { - activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) - if activeLimitIdentifier != "" { - resetHeader := "x-" + activeLimitIdentifier + "-primary-reset-at" - if resetStr := headers.Get(resetHeader); resetStr != "" { - value, err := strconv.ParseInt(resetStr, 10, 64) - if err == nil { - return time.Unix(value, 0) - } - } - } - if retryAfter := headers.Get("Retry-After"); retryAfter != "" { - seconds, err := strconv.ParseInt(retryAfter, 10, 64) - if err == nil { - return time.Now().Add(time.Duration(seconds) * time.Second) - } - } - return time.Now().Add(5 * time.Minute) -} - -func validateOCMOptions(options option.OCMServiceOptions) error { - hasCredentials := len(options.Credentials) > 0 - hasLegacyPath := options.CredentialPath != "" - hasLegacyUsages := options.UsagesPath != "" - hasLegacyDetour := options.Detour != "" - - if hasCredentials && hasLegacyPath { - return E.New("credential_path and credentials are mutually exclusive") - } - if hasCredentials && hasLegacyUsages { - return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") - } - if hasCredentials && hasLegacyDetour { - return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") - } - - if hasCredentials { - tags := make(map[string]bool) - credentialTypes := make(map[string]string) - for _, cred := range options.Credentials { - if tags[cred.Tag] { - return E.New("duplicate credential tag: ", cred.Tag) - } - tags[cred.Tag] = true - credentialTypes[cred.Tag] = cred.Type - if cred.Type == "default" || cred.Type == "" { - if cred.DefaultOptions.Reserve5h > 99 { - return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") - } - if cred.DefaultOptions.ReserveWeekly > 99 { - return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") - } - if cred.DefaultOptions.Limit5h > 100 { - return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") - } - if cred.DefaultOptions.LimitWeekly > 100 { - return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") - } - if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { - return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") - } - if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { - return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") - } - } - if cred.Type == "external" { - if cred.ExternalOptions.Token == "" { - return E.New("credential ", cred.Tag, ": external credential requires token") - } - if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { - return E.New("credential ", cred.Tag, ": reverse external credential requires url") - } - } - if cred.Type == "balancer" { - switch cred.BalancerOptions.Strategy { - case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: - default: - return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) - } - if cred.BalancerOptions.RebalanceThreshold < 0 { - return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") - } - } - } - - for _, user := range options.Users { - if user.Credential == "" { - return E.New("user ", user.Name, " must specify credential in multi-credential mode") - } - if !tags[user.Credential] { - return E.New("user ", user.Name, " references unknown credential: ", user.Credential) - } - if user.ExternalCredential != "" { - if !tags[user.ExternalCredential] { - return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) - } - if credentialTypes[user.ExternalCredential] != "external" { - return E.New("user ", user.Name, ": external_credential must reference an external type credential") - } - } - } - } - - return nil -} - -func validateOCMCompositeCredentialModes( - options option.OCMServiceOptions, - providers map[string]credentialProvider, -) error { - for _, credOpt := range options.Credentials { - if credOpt.Type != "balancer" { - continue - } - - provider, exists := providers[credOpt.Tag] - if !exists { - return E.New("unknown credential: ", credOpt.Tag) - } - - for _, subCred := range provider.allCredentials() { - if !subCred.isAvailable() { - continue - } - if subCred.ocmIsAPIKeyMode() { - return E.New( - "credential ", credOpt.Tag, - " references API key default credential ", subCred.tagName(), - "; balancer and fallback only support OAuth default credentials", - ) - } - } - } - - return nil -} - -func credentialForUser( - userConfigMap map[string]*option.OCMUser, - providers map[string]credentialProvider, - legacyProvider credentialProvider, - username string, -) (credentialProvider, error) { - if legacyProvider != nil { - return legacyProvider, nil - } - userConfig, exists := userConfigMap[username] - if !exists { - return nil, E.New("no credential mapping for user: ", username) - } - provider, exists := providers[userConfig.Credential] - if !exists { - return nil, E.New("unknown credential: ", userConfig.Credential) - } - return provider, nil -} - -func noUserCredentialProvider( - providers map[string]credentialProvider, - legacyProvider credentialProvider, - options option.OCMServiceOptions, -) credentialProvider { - if legacyProvider != nil { - return legacyProvider - } - if len(options.Credentials) > 0 { - tag := options.Credentials[0].Tag - return providers[tag] - } - return nil -} diff --git a/service/ocm/service.go b/service/ocm/service.go index 071cec8ccb..101f904926 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -1,17 +1,13 @@ package ocm import ( - "bytes" "context" "encoding/json" "errors" "io" - "mime" "net/http" - "strconv" "strings" "sync" - "time" "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" @@ -21,14 +17,11 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" aTLS "github.com/sagernet/sing/common/tls" "github.com/go-chi/chi/v5" - "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/responses" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) @@ -160,71 +153,20 @@ func isReverseProxyHeader(header string) bool { } } -func normalizeRateLimitIdentifier(limitIdentifier string) string { - trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier)) - if trimmedIdentifier == "" { - return "" - } - return strings.ReplaceAll(trimmedIdentifier, "_", "-") -} - -func parseInt64Header(headers http.Header, headerName string) (int64, bool) { - headerValue := strings.TrimSpace(headers.Get(headerName)) - if headerValue == "" { - return 0, false - } - parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64) - if parseError != nil { - return 0, false - } - return parsedValue, true -} - -func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint { - normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier) - if normalizedLimitIdentifier == "" { - return nil - } - - windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes" - resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at" - - windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader) - resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader) - if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 { - return nil - } - - return &WeeklyCycleHint{ - WindowMinutes: windowMinutes, - ResetAt: time.Unix(resetAtUnix, 0).UTC(), - } -} - -func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { - activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) - if activeLimitIdentifier != "" { - if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil { - return activeHint - } - } - return weeklyCycleHintForLimit(headers, "codex") -} - type Service struct { boxService.Adapter - ctx context.Context - logger log.ContextLogger - options option.OCMServiceOptions - httpHeaders http.Header - listener *listener.Listener - tlsConfig tls.ServerConfig - httpServer *http.Server - userManager *UserManager - webSocketMutex sync.Mutex - webSocketGroup sync.WaitGroup - webSocketConns map[*webSocketSession]struct{} - shuttingDown bool + ctx context.Context + logger log.ContextLogger + options option.OCMServiceOptions + httpHeaders http.Header + listener *listener.Listener + tlsConfig tls.ServerConfig + httpServer *http.Server + userManager *UserManager + webSocketAccess sync.Mutex + webSocketGroup sync.WaitGroup + webSocketConns map[*webSocketSession]struct{} + shuttingDown bool // Legacy mode legacyCredential *defaultCredential @@ -361,562 +303,6 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } -func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) { - if len(s.options.Users) > 0 { - return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - } - provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options) - if provider == nil { - return nil, E.New("no credential available") - } - return provider, nil -} - -func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := log.ContextWithNewID(r.Context()) - if r.URL.Path == "/ocm/v1/status" { - s.handleStatusEndpoint(w, r) - return - } - - if r.URL.Path == "/ocm/v1/reverse" { - s.handleReverseConnect(ctx, w, r) - return - } - - path := r.URL.Path - if !strings.HasPrefix(path, "/v1/") { - writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/") - return - } - - var username string - if len(s.options.Users) > 0 { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") - return - } - clientToken := strings.TrimPrefix(authHeader, "Bearer ") - if clientToken == authHeader { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") - return - } - var ok bool - username, ok = s.userManager.Authenticate(clientToken) - if !ok { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") - return - } - } - - sessionID := r.Header.Get("session_id") - - // Resolve credential provider and user config - var provider credentialProvider - var userConfig *option.OCMUser - if len(s.options.Users) > 0 { - userConfig = s.userConfigMap[username] - var err error - provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - if err != nil { - s.logger.ErrorContext(ctx, "resolve credential: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) - return - } - } else { - provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) - } - if provider == nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") - return - } - - provider.pollIfStale(s.ctx) - - selection := credentialSelectionForUser(userConfig) - - selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) - if err != nil { - writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error())) - return - } - - if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { - s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew) - return - } - - if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() { - // API key mode path handling - } else if !selectedCredential.isExternal() { - if path == "/v1/chat/completions" { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "chat completions endpoint is only available in API key mode") - return - } - } - - shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil && - (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) - canRetryRequest := len(provider.allCredentials()) > 1 - - // Read body for model extraction and retry buffer when JSON replay is useful. - var bodyBytes []byte - var requestModel string - var requestServiceTier string - if r.Body != nil && (shouldTrackUsage || canRetryRequest) { - mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type")) - isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")) - if isJSONRequest { - bodyBytes, err = io.ReadAll(r.Body) - if err != nil { - s.logger.ErrorContext(ctx, "read request body: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") - return - } - var request struct { - Model string `json:"model"` - ServiceTier string `json:"service_tier"` - } - if json.Unmarshal(bodyBytes, &request) == nil { - requestModel = request.Model - requestServiceTier = request.ServiceTier - } - r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - } - - if isNew { - logParts := []any{"assigned credential ", selectedCredential.tagName()} - if sessionID != "" { - logParts = append(logParts, " for session ", sessionID) - } - if username != "" { - logParts = append(logParts, " by user ", username) - } - if requestModel != "" { - logParts = append(logParts, ", model=", requestModel) - } - if requestServiceTier == "priority" { - logParts = append(logParts, ", fast") - } - s.logger.DebugContext(ctx, logParts...) - } - - requestContext := selectedCredential.wrapRequestContext(ctx) - { - currentRequestContext := requestContext - requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { - currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) - })) - } - defer func() { - requestContext.cancelRequest() - }() - proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) - if err != nil { - s.logger.ErrorContext(ctx, "create proxy request: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") - return - } - - response, err := selectedCredential.httpTransport().Do(proxyRequest) - if err != nil { - if r.Context().Err() != nil { - return - } - if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") - return - } - writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) - return - } - requestContext.releaseCredentialInterrupt() - - // Transparent 429 retry - for response.StatusCode == http.StatusTooManyRequests { - resetAt := parseOCMRateLimitResetFromHeaders(response.Header) - nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) - needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete - selectedCredential.updateStateFromHeaders(response.Header) - if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil { - response.Body.Close() - writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") - return - } - response.Body.Close() - s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) - requestContext.cancelRequest() - requestContext = nextCredential.wrapRequestContext(ctx) - { - currentRequestContext := requestContext - requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { - currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) - })) - } - retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) - if buildErr != nil { - s.logger.ErrorContext(ctx, "retry request: ", buildErr) - writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) - return - } - retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest) - if retryErr != nil { - if r.Context().Err() != nil { - return - } - if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") - return - } - s.logger.ErrorContext(ctx, "retry request: ", retryErr) - writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) - return - } - requestContext.releaseCredentialInterrupt() - response = retryResponse - selectedCredential = nextCredential - } - defer response.Body.Close() - - selectedCredential.updateStateFromHeaders(response.Header) - - if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { - body, _ := io.ReadAll(response.Body) - s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) - go selectedCredential.pollUsage(s.ctx) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", - "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) - return - } - - // Rewrite response headers for external users - if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) - } - - for key, values := range response.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { - w.Header()[key] = values - } - } - w.WriteHeader(response.StatusCode) - - usageTracker := selectedCredential.usageTrackerOrNil() - if usageTracker != nil && response.StatusCode == http.StatusOK && - (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) { - s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username) - } else { - mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) - if err == nil && mediaType != "text/event-stream" { - _, _ = io.Copy(w, response.Body) - return - } - flusher, ok := w.(http.Flusher) - if !ok { - s.logger.ErrorContext(ctx, "streaming not supported") - return - } - buffer := make([]byte, buf.BufferSize) - for { - n, err := response.Body.Read(buffer) - if n > 0 { - _, writeError := w.Write(buffer[:n]) - if writeError != nil { - s.logger.ErrorContext(ctx, "write streaming response: ", writeError) - return - } - flusher.Flush() - } - if err != nil { - return - } - } - } -} - -func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { - isChatCompletions := path == "/v1/chat/completions" - weeklyCycleHint := extractWeeklyCycleHint(response.Header) - mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) - isStreaming := err == nil && mediaType == "text/event-stream" - if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" { - isStreaming = true - } - if !isStreaming { - bodyBytes, err := io.ReadAll(response.Body) - if err != nil { - s.logger.ErrorContext(ctx, "read response body: ", err) - return - } - - var responseModel, serviceTier string - var inputTokens, outputTokens, cachedTokens int64 - - if isChatCompletions { - var chatCompletion openai.ChatCompletion - if json.Unmarshal(bodyBytes, &chatCompletion) == nil { - responseModel = chatCompletion.Model - serviceTier = string(chatCompletion.ServiceTier) - inputTokens = chatCompletion.Usage.PromptTokens - outputTokens = chatCompletion.Usage.CompletionTokens - cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens - } - } else { - var responsesResponse responses.Response - if json.Unmarshal(bodyBytes, &responsesResponse) == nil { - responseModel = string(responsesResponse.Model) - serviceTier = string(responsesResponse.ServiceTier) - inputTokens = responsesResponse.Usage.InputTokens - outputTokens = responsesResponse.Usage.OutputTokens - cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens - } - } - - if inputTokens > 0 || outputTokens > 0 { - if responseModel == "" { - responseModel = requestModel - } - if responseModel != "" { - contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - usageTracker.AddUsageWithCycleHint( - responseModel, - contextWindow, - inputTokens, - outputTokens, - cachedTokens, - serviceTier, - username, - time.Now(), - weeklyCycleHint, - ) - } - } - - _, _ = writer.Write(bodyBytes) - return - } - - flusher, ok := writer.(http.Flusher) - if !ok { - s.logger.ErrorContext(ctx, "streaming not supported") - return - } - - var inputTokens, outputTokens, cachedTokens int64 - var responseModel, serviceTier string - buffer := make([]byte, buf.BufferSize) - var leftover []byte - - for { - n, err := response.Body.Read(buffer) - if n > 0 { - data := append(leftover, buffer[:n]...) - lines := bytes.Split(data, []byte("\n")) - - if err == nil { - leftover = lines[len(lines)-1] - lines = lines[:len(lines)-1] - } else { - leftover = nil - } - - for _, line := range lines { - line = bytes.TrimSpace(line) - if len(line) == 0 { - continue - } - - if bytes.HasPrefix(line, []byte("data: ")) { - eventData := bytes.TrimPrefix(line, []byte("data: ")) - if bytes.Equal(eventData, []byte("[DONE]")) { - continue - } - - if isChatCompletions { - var chatChunk openai.ChatCompletionChunk - if json.Unmarshal(eventData, &chatChunk) == nil { - if chatChunk.Model != "" { - responseModel = chatChunk.Model - } - if chatChunk.ServiceTier != "" { - serviceTier = string(chatChunk.ServiceTier) - } - if chatChunk.Usage.PromptTokens > 0 { - inputTokens = chatChunk.Usage.PromptTokens - cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens - } - if chatChunk.Usage.CompletionTokens > 0 { - outputTokens = chatChunk.Usage.CompletionTokens - } - } - } else { - var streamEvent responses.ResponseStreamEventUnion - if json.Unmarshal(eventData, &streamEvent) == nil { - if streamEvent.Type == "response.completed" { - completedEvent := streamEvent.AsResponseCompleted() - if string(completedEvent.Response.Model) != "" { - responseModel = string(completedEvent.Response.Model) - } - if completedEvent.Response.ServiceTier != "" { - serviceTier = string(completedEvent.Response.ServiceTier) - } - if completedEvent.Response.Usage.InputTokens > 0 { - inputTokens = completedEvent.Response.Usage.InputTokens - cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens - } - if completedEvent.Response.Usage.OutputTokens > 0 { - outputTokens = completedEvent.Response.Usage.OutputTokens - } - } - } - } - } - } - - _, writeError := writer.Write(buffer[:n]) - if writeError != nil { - s.logger.ErrorContext(ctx, "write streaming response: ", writeError) - return - } - flusher.Flush() - } - - if err != nil { - if responseModel == "" { - responseModel = requestModel - } - - if inputTokens > 0 || outputTokens > 0 { - if responseModel != "" { - contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - usageTracker.AddUsageWithCycleHint( - responseModel, - contextWindow, - inputTokens, - outputTokens, - cachedTokens, - serviceTier, - username, - time.Now(), - weeklyCycleHint, - ) - } - } - return - } - } -} - -func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") - return - } - - if len(s.options.Users) == 0 { - writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") - return - } - - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") - return - } - clientToken := strings.TrimPrefix(authHeader, "Bearer ") - if clientToken == authHeader { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") - return - } - username, ok := s.userManager.Authenticate(clientToken) - if !ok { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") - return - } - - userConfig := s.userConfigMap[username] - if userConfig == nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") - return - } - - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - if err != nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) - return - } - - provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]float64{ - "five_hour_utilization": avgFiveHour, - "weekly_utilization": avgWeekly, - "plan_weight": totalWeight, - }) -} - -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) { - var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 - for _, cred := range provider.allCredentials() { - if !cred.isAvailable() { - continue - } - if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { - continue - } - if !userConfig.AllowExternalUsage && cred.isExternal() { - continue - } - weight := cred.planWeight() - remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() - if remaining5h < 0 { - remaining5h = 0 - } - remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() - if remainingWeekly < 0 { - remainingWeekly = 0 - } - totalWeightedRemaining5h += remaining5h * weight - totalWeightedRemainingWeekly += remainingWeekly * weight - totalWeight += weight - } - if totalWeight == 0 { - return 100, 100, 0 - } - return 100 - totalWeightedRemaining5h/totalWeight, - 100 - totalWeightedRemainingWeekly/totalWeight, - totalWeight -} - -func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) { - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) - if err != nil { - return - } - - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - - activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) - if activeLimitIdentifier == "" { - activeLimitIdentifier = "codex" - } - - headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64)) - headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) - if totalWeight > 0 { - headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) - } -} - func (s *Service) InterfaceUpdated() { for _, cred := range s.allCredentials { extCred, ok := cred.(*externalCredential) @@ -952,8 +338,8 @@ func (s *Service) Close() error { } func (s *Service) registerWebSocketSession(session *webSocketSession) bool { - s.webSocketMutex.Lock() - defer s.webSocketMutex.Unlock() + s.webSocketAccess.Lock() + defer s.webSocketAccess.Unlock() if s.shuttingDown { return false @@ -965,12 +351,12 @@ func (s *Service) registerWebSocketSession(session *webSocketSession) bool { } func (s *Service) unregisterWebSocketSession(session *webSocketSession) { - s.webSocketMutex.Lock() + s.webSocketAccess.Lock() _, loaded := s.webSocketConns[session] if loaded { delete(s.webSocketConns, session) } - s.webSocketMutex.Unlock() + s.webSocketAccess.Unlock() if loaded { s.webSocketGroup.Done() @@ -978,28 +364,28 @@ func (s *Service) unregisterWebSocketSession(session *webSocketSession) { } func (s *Service) isShuttingDown() bool { - s.webSocketMutex.Lock() - defer s.webSocketMutex.Unlock() + s.webSocketAccess.Lock() + defer s.webSocketAccess.Unlock() return s.shuttingDown } func (s *Service) interruptWebSocketSessionsForCredential(tag string) { - s.webSocketMutex.Lock() + s.webSocketAccess.Lock() var toClose []*webSocketSession for session := range s.webSocketConns { if session.credentialTag == tag { toClose = append(toClose, session) } } - s.webSocketMutex.Unlock() + s.webSocketAccess.Unlock() for _, session := range toClose { session.Close() } } func (s *Service) startWebSocketShutdown() []*webSocketSession { - s.webSocketMutex.Lock() - defer s.webSocketMutex.Unlock() + s.webSocketAccess.Lock() + defer s.webSocketAccess.Unlock() s.shuttingDown = true diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go new file mode 100644 index 0000000000..9fb9c96d77 --- /dev/null +++ b/service/ocm/service_handler.go @@ -0,0 +1,504 @@ +package ocm + +import ( + "bytes" + "context" + "encoding/json" + "io" + "mime" + "net/http" + "strconv" + "strings" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/responses" +) + +func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint { + normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier) + if normalizedLimitIdentifier == "" { + return nil + } + + windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes" + resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at" + + windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader) + resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader) + if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 { + return nil + } + + return &WeeklyCycleHint{ + WindowMinutes: windowMinutes, + ResetAt: time.Unix(resetAtUnix, 0).UTC(), + } +} + +func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier != "" { + if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil { + return activeHint + } + } + return weeklyCycleHintForLimit(headers, "codex") +} + +func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) { + if len(s.options.Users) > 0 { + return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + } + provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + if provider == nil { + return nil, E.New("no credential available") + } + return provider, nil +} + +func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := log.ContextWithNewID(r.Context()) + if r.URL.Path == "/ocm/v1/status" { + s.handleStatusEndpoint(w, r) + return + } + + if r.URL.Path == "/ocm/v1/reverse" { + s.handleReverseConnect(ctx, w, r) + return + } + + path := r.URL.Path + if !strings.HasPrefix(path, "/v1/") { + writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/") + return + } + + var username string + if len(s.options.Users) > 0 { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + var ok bool + username, ok = s.userManager.Authenticate(clientToken) + if !ok { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + } + + sessionID := r.Header.Get("session_id") + + // Resolve credential provider and user config + var provider credentialProvider + var userConfig *option.OCMUser + if len(s.options.Users) > 0 { + userConfig = s.userConfigMap[username] + var err error + provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + s.logger.ErrorContext(ctx, "resolve credential: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + } else { + provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + } + if provider == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") + return + } + + provider.pollIfStale(s.ctx) + + selection := credentialSelectionForUser(userConfig) + + selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) + if err != nil { + writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error())) + return + } + + if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { + s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew) + return + } + + if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() { + // API key mode path handling + } else if !selectedCredential.isExternal() { + if path == "/v1/chat/completions" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "chat completions endpoint is only available in API key mode") + return + } + } + + shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil && + (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) + canRetryRequest := len(provider.allCredentials()) > 1 + + // Read body for model extraction and retry buffer when JSON replay is useful. + var bodyBytes []byte + var requestModel string + var requestServiceTier string + if r.Body != nil && (shouldTrackUsage || canRetryRequest) { + mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type")) + isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")) + if isJSONRequest { + bodyBytes, err = io.ReadAll(r.Body) + if err != nil { + s.logger.ErrorContext(ctx, "read request body: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") + return + } + var request struct { + Model string `json:"model"` + ServiceTier string `json:"service_tier"` + } + if json.Unmarshal(bodyBytes, &request) == nil { + requestModel = request.Model + requestServiceTier = request.ServiceTier + } + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + } + + if isNew { + logParts := []any{"assigned credential ", selectedCredential.tagName()} + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if username != "" { + logParts = append(logParts, " by user ", username) + } + if requestModel != "" { + logParts = append(logParts, ", model=", requestModel) + } + if requestServiceTier == "priority" { + logParts = append(logParts, ", fast") + } + s.logger.DebugContext(ctx, logParts...) + } + + requestContext := selectedCredential.wrapRequestContext(ctx) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } + defer func() { + requestContext.cancelRequest() + }() + proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if err != nil { + s.logger.ErrorContext(ctx, "create proxy request: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") + return + } + + response, err := selectedCredential.httpClient().Do(proxyRequest) + if err != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") + return + } + writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) + return + } + requestContext.releaseCredentialInterrupt() + + // Transparent 429 retry + for response.StatusCode == http.StatusTooManyRequests { + resetAt := parseOCMRateLimitResetFromHeaders(response.Header) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) + needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete + selectedCredential.updateStateFromHeaders(response.Header) + if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil { + response.Body.Close() + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") + return + } + response.Body.Close() + s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + requestContext.cancelRequest() + requestContext = nextCredential.wrapRequestContext(ctx) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } + retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if buildErr != nil { + s.logger.ErrorContext(ctx, "retry request: ", buildErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) + return + } + retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest) + if retryErr != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") + return + } + s.logger.ErrorContext(ctx, "retry request: ", retryErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) + return + } + requestContext.releaseCredentialInterrupt() + response = retryResponse + selectedCredential = nextCredential + } + defer response.Body.Close() + + selectedCredential.updateStateFromHeaders(response.Header) + + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) + go selectedCredential.pollUsage(s.ctx) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", + "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) + return + } + + // Rewrite response headers for external users + if userConfig != nil && userConfig.ExternalCredential != "" { + s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) + } + + for key, values := range response.Header { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { + w.Header()[key] = values + } + } + w.WriteHeader(response.StatusCode) + + usageTracker := selectedCredential.usageTrackerOrNil() + if usageTracker != nil && response.StatusCode == http.StatusOK && + (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) { + s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username) + } else { + mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) + if err == nil && mediaType != "text/event-stream" { + _, _ = io.Copy(w, response.Body) + return + } + flusher, ok := w.(http.Flusher) + if !ok { + s.logger.ErrorContext(ctx, "streaming not supported") + return + } + buffer := make([]byte, buf.BufferSize) + for { + n, err := response.Body.Read(buffer) + if n > 0 { + _, writeError := w.Write(buffer[:n]) + if writeError != nil { + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) + return + } + flusher.Flush() + } + if err != nil { + return + } + } + } +} + +func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { + isChatCompletions := path == "/v1/chat/completions" + weeklyCycleHint := extractWeeklyCycleHint(response.Header) + mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) + isStreaming := err == nil && mediaType == "text/event-stream" + if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" { + isStreaming = true + } + if !isStreaming { + bodyBytes, err := io.ReadAll(response.Body) + if err != nil { + s.logger.ErrorContext(ctx, "read response body: ", err) + return + } + + var responseModel, serviceTier string + var inputTokens, outputTokens, cachedTokens int64 + + if isChatCompletions { + var chatCompletion openai.ChatCompletion + if json.Unmarshal(bodyBytes, &chatCompletion) == nil { + responseModel = chatCompletion.Model + serviceTier = string(chatCompletion.ServiceTier) + inputTokens = chatCompletion.Usage.PromptTokens + outputTokens = chatCompletion.Usage.CompletionTokens + cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens + } + } else { + var responsesResponse responses.Response + if json.Unmarshal(bodyBytes, &responsesResponse) == nil { + responseModel = string(responsesResponse.Model) + serviceTier = string(responsesResponse.ServiceTier) + inputTokens = responsesResponse.Usage.InputTokens + outputTokens = responsesResponse.Usage.OutputTokens + cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens + } + } + + if inputTokens > 0 || outputTokens > 0 { + if responseModel == "" { + responseModel = requestModel + } + if responseModel != "" { + contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) + usageTracker.AddUsageWithCycleHint( + responseModel, + contextWindow, + inputTokens, + outputTokens, + cachedTokens, + serviceTier, + username, + time.Now(), + weeklyCycleHint, + ) + } + } + + _, _ = writer.Write(bodyBytes) + return + } + + flusher, ok := writer.(http.Flusher) + if !ok { + s.logger.ErrorContext(ctx, "streaming not supported") + return + } + + var inputTokens, outputTokens, cachedTokens int64 + var responseModel, serviceTier string + buffer := make([]byte, buf.BufferSize) + var leftover []byte + + for { + n, err := response.Body.Read(buffer) + if n > 0 { + data := append(leftover, buffer[:n]...) + lines := bytes.Split(data, []byte("\n")) + + if err == nil { + leftover = lines[len(lines)-1] + lines = lines[:len(lines)-1] + } else { + leftover = nil + } + + for _, line := range lines { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + + if bytes.HasPrefix(line, []byte("data: ")) { + eventData := bytes.TrimPrefix(line, []byte("data: ")) + if bytes.Equal(eventData, []byte("[DONE]")) { + continue + } + + if isChatCompletions { + var chatChunk openai.ChatCompletionChunk + if json.Unmarshal(eventData, &chatChunk) == nil { + if chatChunk.Model != "" { + responseModel = chatChunk.Model + } + if chatChunk.ServiceTier != "" { + serviceTier = string(chatChunk.ServiceTier) + } + if chatChunk.Usage.PromptTokens > 0 { + inputTokens = chatChunk.Usage.PromptTokens + cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens + } + if chatChunk.Usage.CompletionTokens > 0 { + outputTokens = chatChunk.Usage.CompletionTokens + } + } + } else { + var streamEvent responses.ResponseStreamEventUnion + if json.Unmarshal(eventData, &streamEvent) == nil { + if streamEvent.Type == "response.completed" { + completedEvent := streamEvent.AsResponseCompleted() + if string(completedEvent.Response.Model) != "" { + responseModel = string(completedEvent.Response.Model) + } + if completedEvent.Response.ServiceTier != "" { + serviceTier = string(completedEvent.Response.ServiceTier) + } + if completedEvent.Response.Usage.InputTokens > 0 { + inputTokens = completedEvent.Response.Usage.InputTokens + cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens + } + if completedEvent.Response.Usage.OutputTokens > 0 { + outputTokens = completedEvent.Response.Usage.OutputTokens + } + } + } + } + } + } + + _, writeError := writer.Write(buffer[:n]) + if writeError != nil { + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) + return + } + flusher.Flush() + } + + if err != nil { + if responseModel == "" { + responseModel = requestModel + } + + if inputTokens > 0 || outputTokens > 0 { + if responseModel != "" { + contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) + usageTracker.AddUsageWithCycleHint( + responseModel, + contextWindow, + inputTokens, + outputTokens, + cachedTokens, + serviceTier, + username, + time.Now(), + weeklyCycleHint, + ) + } + } + return + } + } +} diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go new file mode 100644 index 0000000000..29b95d063b --- /dev/null +++ b/service/ocm/service_status.go @@ -0,0 +1,114 @@ +package ocm + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "github.com/sagernet/sing-box/option" +) + +func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") + return + } + + if len(s.options.Users) == 0 { + writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + username, ok := s.userManager.Authenticate(clientToken) + if !ok { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + + userConfig := s.userConfigMap[username] + if userConfig == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") + return + } + + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + + provider.pollIfStale(r.Context()) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]float64{ + "five_hour_utilization": avgFiveHour, + "weekly_utilization": avgWeekly, + "plan_weight": totalWeight, + }) +} + +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) { + var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 + for _, cred := range provider.allCredentials() { + if !cred.isAvailable() { + continue + } + if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { + continue + } + if !userConfig.AllowExternalUsage && cred.isExternal() { + continue + } + weight := cred.planWeight() + remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() + if remaining5h < 0 { + remaining5h = 0 + } + remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() + if remainingWeekly < 0 { + remainingWeekly = 0 + } + totalWeightedRemaining5h += remaining5h * weight + totalWeightedRemainingWeekly += remainingWeekly * weight + totalWeight += weight + } + if totalWeight == 0 { + return 100, 100, 0 + } + return 100 - totalWeightedRemaining5h/totalWeight, + 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight +} + +func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) { + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) + if err != nil { + return + } + + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier == "" { + activeLimitIdentifier = "codex" + } + + headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64)) + headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) + if totalWeight > 0 { + headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) + } +} diff --git a/service/ocm/service_user.go b/service/ocm/service_user.go index 494b981b9b..b69655e9ac 100644 --- a/service/ocm/service_user.go +++ b/service/ocm/service_user.go @@ -7,13 +7,13 @@ import ( ) type UserManager struct { - accessMutex sync.RWMutex + access sync.RWMutex tokenMap map[string]string } func (m *UserManager) UpdateUsers(users []option.OCMUser) { - m.accessMutex.Lock() - defer m.accessMutex.Unlock() + m.access.Lock() + defer m.access.Unlock() tokenMap := make(map[string]string, len(users)) for _, user := range users { tokenMap[user.Token] = user.Name @@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.OCMUser) { } func (m *UserManager) Authenticate(token string) (string, bool) { - m.accessMutex.RLock() + m.access.RLock() username, found := m.tokenMap[token] - m.accessMutex.RUnlock() + m.access.RUnlock() return username, found } From 6878ad0d358327633ccf457c2ba0dced0724908b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 21:06:25 +0800 Subject: [PATCH 46/96] ccm,ocm: fix naming and error-handling convention violations - Rename credential interface to Credential (exported), cred to credential - Rename mutex/saveMutex to access/saveAccess per go-syntax.md - Fix abbreviations: reverseHttpClient, allCreds, credOpt, extCred, credDialer, reverseCredDialer, portStr - Replace errors.Is(http.ErrServerClosed) with E.IsClosed - Add E.IsClosedOrCanceled guard before streaming write error logs --- service/ccm/credential.go | 8 +- service/ccm/credential_builder.go | 100 +++++++++++----------- service/ccm/credential_external.go | 52 ++++++------ service/ccm/credential_provider.go | 124 +++++++++++++-------------- service/ccm/reverse.go | 13 ++- service/ccm/service.go | 53 ++++++------ service/ccm/service_handler.go | 7 ++ service/ccm/service_status.go | 14 +-- service/ccm/service_usage.go | 32 +++---- service/ccm/service_user.go | 4 +- service/ocm/credential.go | 8 +- service/ocm/credential_builder.go | 110 ++++++++++++------------ service/ocm/credential_external.go | 94 ++++++++++---------- service/ocm/credential_provider.go | 132 ++++++++++++++--------------- service/ocm/reverse.go | 13 ++- service/ocm/service.go | 57 ++++++------- service/ocm/service_handler.go | 6 ++ service/ocm/service_status.go | 14 +-- service/ocm/service_usage.go | 32 +++---- service/ocm/service_user.go | 4 +- service/ocm/service_websocket.go | 10 +-- 21 files changed, 448 insertions(+), 439 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 8589676a8c..d5cae9e1e3 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -90,7 +90,7 @@ func (c *credentialRequestContext) cancelRequest() { c.cancelOnce.Do(c.cancelFunc) } -type credential interface { +type Credential interface { tagName() string isAvailable() bool isUsable() bool @@ -130,11 +130,11 @@ const ( type credentialSelection struct { scope credentialSelectionScope - filter func(credential) bool + filter func(Credential) bool } -func (s credentialSelection) allows(cred credential) bool { - return s.filter == nil || s.filter(cred) +func (s credentialSelection) allows(credential Credential) bool { + return s.filter == nil || s.filter(credential) } func (s credentialSelection) scopeOrDefault() credentialSelectionScope { diff --git a/service/ccm/credential_builder.go b/service/ccm/credential_builder.go index c49a201950..63bfd03951 100644 --- a/service/ccm/credential_builder.go +++ b/service/ccm/credential_builder.go @@ -14,55 +14,55 @@ func buildCredentialProviders( ctx context.Context, options option.CCMServiceOptions, logger log.ContextLogger, -) (map[string]credentialProvider, []credential, error) { - allCredentialMap := make(map[string]credential) - var allCreds []credential +) (map[string]credentialProvider, []Credential, error) { + allCredentialMap := make(map[string]Credential) + var allCredentials []Credential providers := make(map[string]credentialProvider) // Pass 1: create default and external credentials - for _, credOpt := range options.Credentials { - switch credOpt.Type { + for _, credentialOption := range options.Credentials { + switch credentialOption.Type { case "default": - cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + credential, err := newDefaultCredential(ctx, credentialOption.Tag, credentialOption.DefaultOptions, logger) if err != nil { return nil, nil, err } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + allCredentialMap[credentialOption.Tag] = credential + allCredentials = append(allCredentials, credential) + providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential} case "external": - cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) + credential, err := newExternalCredential(ctx, credentialOption.Tag, credentialOption.ExternalOptions, logger) if err != nil { return nil, nil, err } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + allCredentialMap[credentialOption.Tag] = credential + allCredentials = append(allCredentials, credential) + providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential} } } // Pass 2: create balancer providers - for _, credOpt := range options.Credentials { - if credOpt.Type == "balancer" { - subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) + for _, credentialOption := range options.Credentials { + if credentialOption.Type == "balancer" { + subCredentials, err := resolveCredentialTags(credentialOption.BalancerOptions.Credentials, allCredentialMap, credentialOption.Tag) if err != nil { return nil, nil, err } - providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) + providers[credentialOption.Tag] = newBalancerProvider(subCredentials, credentialOption.BalancerOptions.Strategy, time.Duration(credentialOption.BalancerOptions.PollInterval), credentialOption.BalancerOptions.RebalanceThreshold, logger) } } - return providers, allCreds, nil + return providers, allCredentials, nil } -func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { - credentials := make([]credential, 0, len(tags)) +func resolveCredentialTags(tags []string, allCredentials map[string]Credential, parentTag string) ([]Credential, error) { + credentials := make([]Credential, 0, len(tags)) for _, tag := range tags { - cred, exists := allCredentials[tag] + credential, exists := allCredentials[tag] if !exists { return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) } - credentials = append(credentials, cred) + credentials = append(credentials, credential) } if len(credentials) == 0 { return nil, E.New("credential ", parentTag, " has no sub-credentials") @@ -89,48 +89,48 @@ func validateCCMOptions(options option.CCMServiceOptions) error { if hasCredentials { tags := make(map[string]bool) credentialTypes := make(map[string]string) - for _, cred := range options.Credentials { - if tags[cred.Tag] { - return E.New("duplicate credential tag: ", cred.Tag) + for _, credential := range options.Credentials { + if tags[credential.Tag] { + return E.New("duplicate credential tag: ", credential.Tag) } - tags[cred.Tag] = true - credentialTypes[cred.Tag] = cred.Type - if cred.Type == "default" || cred.Type == "" { - if cred.DefaultOptions.Reserve5h > 99 { - return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") + tags[credential.Tag] = true + credentialTypes[credential.Tag] = credential.Type + if credential.Type == "default" || credential.Type == "" { + if credential.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") } - if cred.DefaultOptions.ReserveWeekly > 99 { - return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") + if credential.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") } - if cred.DefaultOptions.Limit5h > 100 { - return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") + if credential.DefaultOptions.Limit5h > 100 { + return E.New("credential ", credential.Tag, ": limit_5h must be at most 100") } - if cred.DefaultOptions.LimitWeekly > 100 { - return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") + if credential.DefaultOptions.LimitWeekly > 100 { + return E.New("credential ", credential.Tag, ": limit_weekly must be at most 100") } - if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { - return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") + if credential.DefaultOptions.Reserve5h > 0 && credential.DefaultOptions.Limit5h > 0 { + return E.New("credential ", credential.Tag, ": reserve_5h and limit_5h are mutually exclusive") } - if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { - return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") + if credential.DefaultOptions.ReserveWeekly > 0 && credential.DefaultOptions.LimitWeekly > 0 { + return E.New("credential ", credential.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") } } - if cred.Type == "external" { - if cred.ExternalOptions.Token == "" { - return E.New("credential ", cred.Tag, ": external credential requires token") + if credential.Type == "external" { + if credential.ExternalOptions.Token == "" { + return E.New("credential ", credential.Tag, ": external credential requires token") } - if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { - return E.New("credential ", cred.Tag, ": reverse external credential requires url") + if credential.ExternalOptions.Reverse && credential.ExternalOptions.URL == "" { + return E.New("credential ", credential.Tag, ": reverse external credential requires url") } } - if cred.Type == "balancer" { - switch cred.BalancerOptions.Strategy { + if credential.Type == "balancer" { + switch credential.BalancerOptions.Strategy { case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: default: - return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) + return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) } - if cred.BalancerOptions.RebalanceThreshold < 0 { - return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") + if credential.BalancerOptions.RebalanceThreshold < 0 { + return E.New("credential ", credential.Tag, ": rebalance_threshold must not be negative") } } } diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 24ddf6c4a2..eb75c5b082 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -48,7 +48,7 @@ type externalCredential struct { // Reverse proxy fields reverse bool - reverseHttpClient *http.Client + reverseHTTPClient *http.Client reverseSession *yamux.Session reverseAccess sync.RWMutex closed bool @@ -63,9 +63,9 @@ type externalCredential struct { } func externalCredentialURLPort(parsedURL *url.URL) uint16 { - portStr := parsedURL.Port() - if portStr != "" { - port, err := strconv.ParseUint(portStr, 10, 16) + portString := parsedURL.Port() + if portString != "" { + port, err := strconv.ParseUint(portString, 10, 16) if err == nil { return uint16(port) } @@ -113,7 +113,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) - cred := &externalCredential{ + credential := &externalCredential{ tag: tag, token: options.Token, pollInterval: pollInterval, @@ -127,12 +127,12 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx if options.URL == "" { // Receiver mode: no URL, wait for reverse connection - cred.baseURL = reverseProxyBaseURL - cred.forwardHTTPClient = &http.Client{ + credential.baseURL = reverseProxyBaseURL + credential.forwardHTTPClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return cred.openReverseConnection(ctx) + return credential.openReverseConnection(ctx) }, }, } @@ -173,34 +173,34 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx } } - cred.baseURL = externalCredentialBaseURL(parsedURL) + credential.baseURL = externalCredentialBaseURL(parsedURL) if options.Reverse { // Connector mode: we dial out to serve, not to proxy - cred.connectorDialer = credentialDialer + credential.connectorDialer = credentialDialer if options.Server != "" { - cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) + credential.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) } else { - cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL)) + credential.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL)) } - cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ccm/v1/reverse") - cred.connectorURL = parsedURL + credential.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ccm/v1/reverse") + credential.connectorURL = parsedURL if parsedURL.Scheme == "https" { - cred.connectorTLS = &stdTLS.Config{ + credential.connectorTLS = &stdTLS.Config{ ServerName: parsedURL.Hostname(), RootCAs: adapter.RootPoolFromContext(ctx), Time: ntp.TimeFuncFromContext(ctx), } } - cred.forwardHTTPClient = &http.Client{Transport: transport} + credential.forwardHTTPClient = &http.Client{Transport: transport} } else { // Normal mode: standard HTTP client for proxying - cred.forwardHTTPClient = &http.Client{Transport: transport} - cred.reverseHttpClient = &http.Client{ + credential.forwardHTTPClient = &http.Client{Transport: transport} + credential.reverseHTTPClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return cred.openReverseConnection(ctx) + return credential.openReverseConnection(ctx) }, }, } @@ -208,7 +208,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx } if options.UsagesPath != "" { - cred.usageTracker = &AggregatedUsage{ + credential.usageTracker = &AggregatedUsage{ LastUpdated: time.Now(), Combinations: make([]CostCombination, 0), filePath: options.UsagesPath, @@ -216,7 +216,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx } } - return cred, nil + return credential, nil } func (c *externalCredential) start() error { @@ -352,7 +352,7 @@ func (c *externalCredential) getAccessToken() (string, error) { func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) { baseURL := c.baseURL - if c.reverseHttpClient != nil { + if c.reverseHTTPClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { baseURL = reverseProxyBaseURL @@ -511,7 +511,7 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp } } // Try reverse transport first (single attempt, no retry) - if c.reverseHttpClient != nil { + if c.reverseHTTPClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { request, err := buildRequest(reverseProxyBaseURL)() @@ -519,7 +519,7 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp return nil, err } reverseClient := &http.Client{ - Transport: c.reverseHttpClient.Transport, + Transport: c.reverseHTTPClient.Transport, Timeout: 5 * time.Second, } response, err := reverseClient.Do(request) @@ -660,10 +660,10 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { } func (c *externalCredential) httpClient() *http.Client { - if c.reverseHttpClient != nil { + if c.reverseHTTPClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { - return c.reverseHttpClient + return c.reverseHTTPClient } } return c.forwardHTTPClient diff --git a/service/ccm/credential_provider.go b/service/ccm/credential_provider.go index cd77bfcdc1..5500df6a14 100644 --- a/service/ccm/credential_provider.go +++ b/service/ccm/credential_provider.go @@ -13,29 +13,29 @@ import ( ) type credentialProvider interface { - selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) - onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential - linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool + selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) + onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential + linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool pollIfStale(ctx context.Context) - allCredentials() []credential + allCredentials() []Credential close() } type singleCredentialProvider struct { - cred credential + credential Credential sessionAccess sync.RWMutex sessions map[string]time.Time } -func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { - if !selection.allows(p.cred) { - return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") +func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) { + if !selection.allows(p.credential) { + return nil, false, E.New("credential ", p.credential.tagName(), " is filtered out") } - if !p.cred.isAvailable() { - return nil, false, p.cred.unavailableError() + if !p.credential.isAvailable() { + return nil, false, p.credential.unavailableError() } - if !p.cred.isUsable() { - return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") + if !p.credential.isUsable() { + return nil, false, E.New("credential ", p.credential.tagName(), " is rate-limited") } var isNew bool if sessionID != "" { @@ -50,11 +50,11 @@ func (p *singleCredentialProvider) selectCredential(sessionID string, selection } p.sessionAccess.Unlock() } - return p.cred, isNew, nil + return p.credential, isNew, nil } -func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { - cred.markRateLimited(resetAt) +func (p *singleCredentialProvider) onRateLimited(_ string, credential Credential, resetAt time.Time, _ credentialSelection) Credential { + credential.markRateLimited(resetAt) return nil } @@ -68,16 +68,16 @@ func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { } p.sessionAccess.Unlock() - if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { - p.cred.pollUsage(ctx) + if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) { + p.credential.pollUsage(ctx) } } -func (p *singleCredentialProvider) allCredentials() []credential { - return []credential{p.cred} +func (p *singleCredentialProvider) allCredentials() []Credential { + return []Credential{p.credential} } -func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { +func (p *singleCredentialProvider) linkProviderInterrupt(_ Credential, _ credentialSelection, _ func()) func() bool { return func() bool { return false } @@ -102,7 +102,7 @@ type credentialInterruptEntry struct { } type balancerProvider struct { - credentials []credential + credentials []Credential strategy string roundRobinIndex atomic.Uint64 pollInterval time.Duration @@ -114,7 +114,7 @@ type balancerProvider struct { logger log.ContextLogger } -func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { +func newBalancerProvider(credentials []Credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { if pollInterval <= 0 { pollInterval = defaultPollInterval } @@ -129,7 +129,7 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval } } -func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { +func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) { if p.strategy == C.BalancerStrategyFallback { best := p.pickCredential(selection.filter) if best == nil { @@ -145,23 +145,23 @@ func (p *balancerProvider) selectCredential(sessionID string, selection credenti p.sessionAccess.RUnlock() if exists { if entry.selectionScope == selectionScope { - for _, cred := range p.credentials { - if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() { + for _, credential := range p.credentials { + if credential.tagName() == entry.tag && selection.allows(credential) && credential.isUsable() { if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { better := p.pickLeastUsed(selection.filter) - if better != nil && better.tagName() != cred.tagName() { - effectiveThreshold := p.rebalanceThreshold / cred.planWeight() - delta := cred.weeklyUtilization() - better.weeklyUtilization() + if better != nil && better.tagName() != credential.tagName() { + effectiveThreshold := p.rebalanceThreshold / credential.planWeight() + delta := credential.weeklyUtilization() - better.weeklyUtilization() if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", cred.tagName(), + p.logger.Info("rebalancing away from ", credential.tagName(), ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", cred.planWeight(), ")") - p.rebalanceCredential(cred.tagName(), selectionScope) + effectiveThreshold, "% (weight ", credential.planWeight(), ")") + p.rebalanceCredential(credential.tagName(), selectionScope) break } } } - return cred, false, nil + return credential, false, nil } } } @@ -208,12 +208,12 @@ func (p *balancerProvider) rebalanceCredential(tag string, selectionScope creden p.sessionAccess.Unlock() } -func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { +func (p *balancerProvider) linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool { if p.strategy == C.BalancerStrategyFallback { return func() bool { return false } } key := credentialInterruptKey{ - tag: cred.tagName(), + tag: credential.tagName(), selectionScope: selection.scopeOrDefault(), } p.interruptAccess.Lock() @@ -227,8 +227,8 @@ func (p *balancerProvider) linkProviderInterrupt(cred credential, selection cred return context.AfterFunc(entry.context, onInterrupt) } -func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { - cred.markRateLimited(resetAt) +func (p *balancerProvider) onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential { + credential.markRateLimited(resetAt) if p.strategy == C.BalancerStrategyFallback { return p.pickCredential(selection.filter) } @@ -251,7 +251,7 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese return best } -func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { +func (p *balancerProvider) pickCredential(filter func(Credential) bool) Credential { switch p.strategy { case C.BalancerStrategyRoundRobin: return p.pickRoundRobin(filter) @@ -264,13 +264,13 @@ func (p *balancerProvider) pickCredential(filter func(credential) bool) credenti } } -func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { +func (p *balancerProvider) pickFallback(filter func(Credential) bool) Credential { + for _, credential := range p.credentials { + if filter != nil && !filter(credential) { continue } - if cred.isUsable() { - return cred + if credential.isUsable() { + return credential } } return nil @@ -278,20 +278,20 @@ func (p *balancerProvider) pickFallback(filter func(credential) bool) credential const weeklyWindowHours = 7 * 24 -func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { - var best credential +func (p *balancerProvider) pickLeastUsed(filter func(Credential) bool) Credential { + var best Credential bestScore := float64(-1) now := time.Now() - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { + for _, credential := range p.credentials { + if filter != nil && !filter(credential) { continue } - if !cred.isUsable() { + if !credential.isUsable() { continue } - remaining := cred.weeklyCap() - cred.weeklyUtilization() - score := remaining * cred.planWeight() - resetTime := cred.weeklyResetTime() + remaining := credential.weeklyCap() - credential.weeklyUtilization() + score := remaining * credential.planWeight() + resetTime := credential.weeklyResetTime() if !resetTime.IsZero() { timeUntilReset := resetTime.Sub(now) if timeUntilReset < time.Hour { @@ -301,13 +301,13 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia } if score > bestScore { bestScore = score - best = cred + best = credential } } return best } -func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { +func (p *balancerProvider) pickRoundRobin(filter func(Credential) bool) Credential { start := int(p.roundRobinIndex.Add(1) - 1) count := len(p.credentials) for offset := range count { @@ -322,8 +322,8 @@ func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credenti return nil } -func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { - var usable []credential +func (p *balancerProvider) pickRandom(filter func(Credential) bool) Credential { + var usable []Credential for _, candidate := range p.credentials { if filter != nil && !filter(candidate) { continue @@ -348,14 +348,14 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) { } p.sessionAccess.Unlock() - for _, cred := range p.credentials { - if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { - cred.pollUsage(ctx) + for _, credential := range p.credentials { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage(ctx) } } } -func (p *balancerProvider) allCredentials() []credential { +func (p *balancerProvider) allCredentials() []Credential { return p.credentials } @@ -382,15 +382,15 @@ func ccmPlanWeight(accountType string, rateLimitTier string) float64 { } } -func allCredentialsUnavailableError(credentials []credential) error { +func allCredentialsUnavailableError(credentials []Credential) error { var hasUnavailable bool var earliest time.Time - for _, cred := range credentials { - if cred.unavailableError() != nil { + for _, credential := range credentials { + if credential.unavailableError() != nil { hasUnavailable = true continue } - resetAt := cred.earliestReset() + resetAt := credential.earliestReset() if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { earliest = resetAt } diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index 6ecc224f9a..97ef1751c1 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -4,7 +4,6 @@ import ( "bufio" "context" stdTLS "crypto/tls" - "errors" "io" "math/rand/v2" "net" @@ -124,13 +123,13 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite } func (s *Service) findReceiverCredential(token string) *externalCredential { - for _, cred := range s.allCredentials { - extCred, ok := cred.(*externalCredential) - if !ok || extCred.connectorURL != nil { + for _, credential := range s.allCredentials { + external, ok := credential.(*externalCredential) + if !ok || external.connectorURL != nil { continue } - if extCred.token == token { - return extCred + if external.token == token { + return external } } return nil @@ -248,7 +247,7 @@ func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duratio } err = httpServer.Serve(&yamuxNetListener{session: session}) sessionLifetime := time.Since(serveStart) - if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil { + if err != nil && !E.IsClosed(err) && ctx.Err() == nil { return sessionLifetime, E.Cause(err, "serve") } return sessionLifetime, E.New("connection closed") diff --git a/service/ccm/service.go b/service/ccm/service.go index 6dce1931bd..69964c02c7 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -3,7 +3,6 @@ package ccm import ( "context" "encoding/json" - "errors" "net/http" "strings" "sync" @@ -55,18 +54,18 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro }) } -func hasAlternativeCredential(provider credentialProvider, currentCredential credential, selection credentialSelection) bool { +func hasAlternativeCredential(provider credentialProvider, currentCredential Credential, selection credentialSelection) bool { if provider == nil || currentCredential == nil { return false } - for _, cred := range provider.allCredentials() { - if cred == currentCredential { + for _, credential := range provider.allCredentials() { + if credential == currentCredential { continue } - if !selection.allows(cred) { + if !selection.allows(credential) { continue } - if cred.isUsable() { + if credential.isUsable() { return true } } @@ -96,7 +95,7 @@ func writeCredentialUnavailableError( w http.ResponseWriter, r *http.Request, provider credentialProvider, - currentCredential credential, + currentCredential Credential, selection credentialSelection, fallback string, ) { @@ -111,8 +110,8 @@ func credentialSelectionForUser(userConfig *option.CCMUser) credentialSelection selection := credentialSelection{scope: credentialSelectionScopeAll} if userConfig != nil && !userConfig.AllowExternalUsage { selection.scope = credentialSelectionScopeNonExternal - selection.filter = func(cred credential) bool { - return !cred.isExternal() + selection.filter = func(credential Credential) bool { + return !credential.isExternal() } } return selection @@ -159,7 +158,7 @@ type Service struct { // Multi-credential mode providers map[string]credentialProvider - allCredentials []credential + allCredentials []Credential userConfigMap map[string]*option.CCMUser } @@ -204,7 +203,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio } service.userConfigMap = userConfigMap } else { - cred, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{ + credential, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{ CredentialPath: options.CredentialPath, UsagesPath: options.UsagesPath, Detour: options.Detour, @@ -212,9 +211,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio if err != nil { return nil, err } - service.legacyCredential = cred - service.legacyProvider = &singleCredentialProvider{cred: cred} - service.allCredentials = []credential{cred} + service.legacyCredential = credential + service.legacyProvider = &singleCredentialProvider{credential: credential} + service.allCredentials = []Credential{credential} } if options.TLS != nil { @@ -235,11 +234,11 @@ func (s *Service) Start(stage adapter.StartStage) error { s.userManager.UpdateUsers(s.options.Users) - for _, cred := range s.allCredentials { - if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { - extCred.reverseService = s + for _, credential := range s.allCredentials { + if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil { + external.reverseService = s } - err := cred.start() + err := credential.start() if err != nil { return err } @@ -271,7 +270,7 @@ func (s *Service) Start(stage adapter.StartStage) error { go func() { serveErr := s.httpServer.Serve(tcpListener) - if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { + if serveErr != nil && !E.IsClosed(serveErr) { s.logger.Error("serve error: ", serveErr) } }() @@ -280,15 +279,15 @@ func (s *Service) Start(stage adapter.StartStage) error { } func (s *Service) InterfaceUpdated() { - for _, cred := range s.allCredentials { - extCred, ok := cred.(*externalCredential) + for _, credential := range s.allCredentials { + external, ok := credential.(*externalCredential) if !ok { continue } - if extCred.reverse && extCred.connectorURL != nil { - extCred.reverseService = s - extCred.resetReverseContext() - go extCred.connectorLoop() + if external.reverse && external.connectorURL != nil { + external.reverseService = s + external.resetReverseContext() + go external.connectorLoop() } } } @@ -300,8 +299,8 @@ func (s *Service) Close() error { s.tlsConfig, ) - for _, cred := range s.allCredentials { - cred.close() + for _, credential := range s.allCredentials { + credential.close() } return err diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 7dd0c64115..7a59cfe4a2 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -14,6 +14,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" "github.com/anthropics/anthropic-sdk-go" ) @@ -336,6 +337,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if n > 0 { _, writeError := w.Write(buffer[:n]) if writeError != nil { + if E.IsClosedOrCanceled(writeError) { + return + } s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } @@ -462,6 +466,9 @@ func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.Re _, writeError := writer.Write(buffer[:n]) if writeError != nil { + if E.IsClosedOrCanceled(writeError) { + return + } s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index 3f91b46149..75929c59f5 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -62,22 +62,22 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) { var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 - for _, cred := range provider.allCredentials() { - if !cred.isAvailable() { + for _, credential := range provider.allCredentials() { + if !credential.isAvailable() { continue } - if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { + if userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { continue } - if !userConfig.AllowExternalUsage && cred.isExternal() { + if !userConfig.AllowExternalUsage && credential.isExternal() { continue } - weight := cred.planWeight() - remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() + weight := credential.planWeight() + remaining5h := credential.fiveHourCap() - credential.fiveHourUtilization() if remaining5h < 0 { remaining5h = 0 } - remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() + remainingWeekly := credential.weeklyCap() - credential.weeklyUtilization() if remainingWeekly < 0 { remainingWeekly = 0 } diff --git a/service/ccm/service_usage.go b/service/ccm/service_usage.go index 36e9ee65df..e23db66542 100644 --- a/service/ccm/service_usage.go +++ b/service/ccm/service_usage.go @@ -35,13 +35,13 @@ type CostCombination struct { type AggregatedUsage struct { LastUpdated time.Time `json:"last_updated"` Combinations []CostCombination `json:"combinations"` - mutex sync.Mutex + access sync.Mutex filePath string logger log.ContextLogger lastSaveTime time.Time pendingSave bool saveTimer *time.Timer - saveMutex sync.Mutex + saveAccess sync.Mutex } type UsageStatsJSON struct { @@ -527,8 +527,8 @@ func deriveWeekStartUnix(cycleHint *WeeklyCycleHint) int64 { } func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON { - u.mutex.Lock() - defer u.mutex.Unlock() + u.access.Lock() + defer u.access.Unlock() result := &AggregatedUsageJSON{ LastUpdated: u.LastUpdated, @@ -561,8 +561,8 @@ func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON { } func (u *AggregatedUsage) Load() error { - u.mutex.Lock() - defer u.mutex.Unlock() + u.access.Lock() + defer u.access.Unlock() u.LastUpdated = time.Time{} u.Combinations = nil @@ -608,9 +608,9 @@ func (u *AggregatedUsage) Save() error { defer os.Remove(tmpFile) err = os.Rename(tmpFile, u.filePath) if err == nil { - u.saveMutex.Lock() + u.saveAccess.Lock() u.lastSaveTime = time.Now() - u.saveMutex.Unlock() + u.saveAccess.Unlock() } return err } @@ -644,8 +644,8 @@ func (u *AggregatedUsage) AddUsageWithCycleHint( observedAt = time.Now() } - u.mutex.Lock() - defer u.mutex.Unlock() + u.access.Lock() + defer u.access.Unlock() u.LastUpdated = observedAt weekStartUnix := deriveWeekStartUnix(cycleHint) @@ -660,8 +660,8 @@ func (u *AggregatedUsage) AddUsageWithCycleHint( func (u *AggregatedUsage) scheduleSave() { const saveInterval = time.Minute - u.saveMutex.Lock() - defer u.saveMutex.Unlock() + u.saveAccess.Lock() + defer u.saveAccess.Unlock() timeSinceLastSave := time.Since(u.lastSaveTime) @@ -678,9 +678,9 @@ func (u *AggregatedUsage) scheduleSave() { remainingTime := saveInterval - timeSinceLastSave u.saveTimer = time.AfterFunc(remainingTime, func() { - u.saveMutex.Lock() + u.saveAccess.Lock() u.pendingSave = false - u.saveMutex.Unlock() + u.saveAccess.Unlock() u.saveAsync() }) } @@ -695,8 +695,8 @@ func (u *AggregatedUsage) saveAsync() { } func (u *AggregatedUsage) cancelPendingSave() { - u.saveMutex.Lock() - defer u.saveMutex.Unlock() + u.saveAccess.Lock() + defer u.saveAccess.Unlock() if u.saveTimer != nil { u.saveTimer.Stop() diff --git a/service/ccm/service_user.go b/service/ccm/service_user.go index 149894c048..e3f52bdf08 100644 --- a/service/ccm/service_user.go +++ b/service/ccm/service_user.go @@ -7,8 +7,8 @@ import ( ) type UserManager struct { - access sync.RWMutex - tokenMap map[string]string + access sync.RWMutex + tokenMap map[string]string } func (m *UserManager) UpdateUsers(users []option.CCMUser) { diff --git a/service/ocm/credential.go b/service/ocm/credential.go index 27a8894705..e0ad9f5657 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -92,7 +92,7 @@ func (c *credentialRequestContext) cancelRequest() { c.cancelOnce.Do(c.cancelFunc) } -type credential interface { +type Credential interface { tagName() string isAvailable() bool isUsable() bool @@ -139,11 +139,11 @@ const ( type credentialSelection struct { scope credentialSelectionScope - filter func(credential) bool + filter func(Credential) bool } -func (s credentialSelection) allows(cred credential) bool { - return s.filter == nil || s.filter(cred) +func (s credentialSelection) allows(credential Credential) bool { + return s.filter == nil || s.filter(credential) } func (s credentialSelection) scopeOrDefault() credentialSelectionScope { diff --git a/service/ocm/credential_builder.go b/service/ocm/credential_builder.go index 5faaf67c67..e308d04d1b 100644 --- a/service/ocm/credential_builder.go +++ b/service/ocm/credential_builder.go @@ -14,55 +14,55 @@ func buildOCMCredentialProviders( ctx context.Context, options option.OCMServiceOptions, logger log.ContextLogger, -) (map[string]credentialProvider, []credential, error) { - allCredentialMap := make(map[string]credential) - var allCreds []credential +) (map[string]credentialProvider, []Credential, error) { + allCredentialMap := make(map[string]Credential) + var allCredentials []Credential providers := make(map[string]credentialProvider) // Pass 1: create default and external credentials - for _, credOpt := range options.Credentials { - switch credOpt.Type { + for _, credentialOption := range options.Credentials { + switch credentialOption.Type { case "default": - cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + credential, err := newDefaultCredential(ctx, credentialOption.Tag, credentialOption.DefaultOptions, logger) if err != nil { return nil, nil, err } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + allCredentialMap[credentialOption.Tag] = credential + allCredentials = append(allCredentials, credential) + providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential} case "external": - cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) + credential, err := newExternalCredential(ctx, credentialOption.Tag, credentialOption.ExternalOptions, logger) if err != nil { return nil, nil, err } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + allCredentialMap[credentialOption.Tag] = credential + allCredentials = append(allCredentials, credential) + providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential} } } // Pass 2: create balancer providers - for _, credOpt := range options.Credentials { - if credOpt.Type == "balancer" { - subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) + for _, credentialOption := range options.Credentials { + if credentialOption.Type == "balancer" { + subCredentials, err := resolveCredentialTags(credentialOption.BalancerOptions.Credentials, allCredentialMap, credentialOption.Tag) if err != nil { return nil, nil, err } - providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) + providers[credentialOption.Tag] = newBalancerProvider(subCredentials, credentialOption.BalancerOptions.Strategy, time.Duration(credentialOption.BalancerOptions.PollInterval), credentialOption.BalancerOptions.RebalanceThreshold, logger) } } - return providers, allCreds, nil + return providers, allCredentials, nil } -func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { - credentials := make([]credential, 0, len(tags)) +func resolveCredentialTags(tags []string, allCredentials map[string]Credential, parentTag string) ([]Credential, error) { + credentials := make([]Credential, 0, len(tags)) for _, tag := range tags { - cred, exists := allCredentials[tag] + credential, exists := allCredentials[tag] if !exists { return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) } - credentials = append(credentials, cred) + credentials = append(credentials, credential) } if len(credentials) == 0 { return nil, E.New("credential ", parentTag, " has no sub-credentials") @@ -89,48 +89,48 @@ func validateOCMOptions(options option.OCMServiceOptions) error { if hasCredentials { tags := make(map[string]bool) credentialTypes := make(map[string]string) - for _, cred := range options.Credentials { - if tags[cred.Tag] { - return E.New("duplicate credential tag: ", cred.Tag) + for _, credential := range options.Credentials { + if tags[credential.Tag] { + return E.New("duplicate credential tag: ", credential.Tag) } - tags[cred.Tag] = true - credentialTypes[cred.Tag] = cred.Type - if cred.Type == "default" || cred.Type == "" { - if cred.DefaultOptions.Reserve5h > 99 { - return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") + tags[credential.Tag] = true + credentialTypes[credential.Tag] = credential.Type + if credential.Type == "default" || credential.Type == "" { + if credential.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") } - if cred.DefaultOptions.ReserveWeekly > 99 { - return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") + if credential.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") } - if cred.DefaultOptions.Limit5h > 100 { - return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") + if credential.DefaultOptions.Limit5h > 100 { + return E.New("credential ", credential.Tag, ": limit_5h must be at most 100") } - if cred.DefaultOptions.LimitWeekly > 100 { - return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") + if credential.DefaultOptions.LimitWeekly > 100 { + return E.New("credential ", credential.Tag, ": limit_weekly must be at most 100") } - if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { - return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") + if credential.DefaultOptions.Reserve5h > 0 && credential.DefaultOptions.Limit5h > 0 { + return E.New("credential ", credential.Tag, ": reserve_5h and limit_5h are mutually exclusive") } - if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { - return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") + if credential.DefaultOptions.ReserveWeekly > 0 && credential.DefaultOptions.LimitWeekly > 0 { + return E.New("credential ", credential.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") } } - if cred.Type == "external" { - if cred.ExternalOptions.Token == "" { - return E.New("credential ", cred.Tag, ": external credential requires token") + if credential.Type == "external" { + if credential.ExternalOptions.Token == "" { + return E.New("credential ", credential.Tag, ": external credential requires token") } - if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { - return E.New("credential ", cred.Tag, ": reverse external credential requires url") + if credential.ExternalOptions.Reverse && credential.ExternalOptions.URL == "" { + return E.New("credential ", credential.Tag, ": reverse external credential requires url") } } - if cred.Type == "balancer" { - switch cred.BalancerOptions.Strategy { + if credential.Type == "balancer" { + switch credential.BalancerOptions.Strategy { case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: default: - return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) + return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) } - if cred.BalancerOptions.RebalanceThreshold < 0 { - return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") + if credential.BalancerOptions.RebalanceThreshold < 0 { + return E.New("credential ", credential.Tag, ": rebalance_threshold must not be negative") } } } @@ -160,14 +160,14 @@ func validateOCMCompositeCredentialModes( options option.OCMServiceOptions, providers map[string]credentialProvider, ) error { - for _, credOpt := range options.Credentials { - if credOpt.Type != "balancer" { + for _, credentialOption := range options.Credentials { + if credentialOption.Type != "balancer" { continue } - provider, exists := providers[credOpt.Tag] + provider, exists := providers[credentialOption.Tag] if !exists { - return E.New("unknown credential: ", credOpt.Tag) + return E.New("unknown credential: ", credentialOption.Tag) } for _, subCred := range provider.allCredentials() { @@ -176,7 +176,7 @@ func validateOCMCompositeCredentialModes( } if subCred.ocmIsAPIKeyMode() { return E.New( - "credential ", credOpt.Tag, + "credential ", credentialOption.Tag, " references API key default credential ", subCred.tagName(), "; balancer and fallback only support OAuth default credentials", ) diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 968bf904d7..02675780b0 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -33,7 +33,7 @@ type externalCredential struct { tag string baseURL string token string - credDialer N.Dialer + credentialDialer N.Dialer forwardHTTPClient *http.Client state credentialState stateAccess sync.RWMutex @@ -49,20 +49,20 @@ type externalCredential struct { requestAccess sync.Mutex // Reverse proxy fields - reverse bool - reverseHttpClient *http.Client - reverseCredDialer N.Dialer - reverseSession *yamux.Session - reverseAccess sync.RWMutex - closed bool - reverseContext context.Context - reverseCancel context.CancelFunc - connectorDialer N.Dialer - connectorDestination M.Socksaddr - connectorRequestPath string - connectorURL *url.URL - connectorTLS *stdTLS.Config - reverseService http.Handler + reverse bool + reverseHTTPClient *http.Client + reverseCredentialDialer N.Dialer + reverseSession *yamux.Session + reverseAccess sync.RWMutex + closed bool + reverseContext context.Context + reverseCancel context.CancelFunc + connectorDialer N.Dialer + connectorDestination M.Socksaddr + connectorRequestPath string + connectorURL *url.URL + connectorTLS *stdTLS.Config + reverseService http.Handler } type reverseSessionDialer struct { @@ -81,9 +81,9 @@ func (d reverseSessionDialer) ListenPacket(ctx context.Context, destination M.So } func externalCredentialURLPort(parsedURL *url.URL) uint16 { - portStr := parsedURL.Port() - if portStr != "" { - port, err := strconv.ParseUint(portStr, 10, 16) + portString := parsedURL.Port() + if portString != "" { + port, err := strconv.ParseUint(portString, 10, 16) if err == nil { return uint16(port) } @@ -131,7 +131,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) - cred := &externalCredential{ + credential := &externalCredential{ tag: tag, token: options.Token, pollInterval: pollInterval, @@ -145,13 +145,13 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx if options.URL == "" { // Receiver mode: no URL, wait for reverse connection - cred.baseURL = reverseProxyBaseURL - cred.credDialer = reverseSessionDialer{credential: cred} - cred.forwardHTTPClient = &http.Client{ + credential.baseURL = reverseProxyBaseURL + credential.credentialDialer = reverseSessionDialer{credential: credential} + credential.forwardHTTPClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return cred.openReverseConnection(ctx) + return credential.openReverseConnection(ctx) }, }, } @@ -192,36 +192,36 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx } } - cred.baseURL = externalCredentialBaseURL(parsedURL) + credential.baseURL = externalCredentialBaseURL(parsedURL) if options.Reverse { // Connector mode: we dial out to serve, not to proxy - cred.connectorDialer = credentialDialer + credential.connectorDialer = credentialDialer if options.Server != "" { - cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) + credential.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) } else { - cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL)) + credential.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL)) } - cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ocm/v1/reverse") - cred.connectorURL = parsedURL + credential.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ocm/v1/reverse") + credential.connectorURL = parsedURL if parsedURL.Scheme == "https" { - cred.connectorTLS = &stdTLS.Config{ + credential.connectorTLS = &stdTLS.Config{ ServerName: parsedURL.Hostname(), RootCAs: adapter.RootPoolFromContext(ctx), Time: ntp.TimeFuncFromContext(ctx), } } - cred.forwardHTTPClient = &http.Client{Transport: transport} + credential.forwardHTTPClient = &http.Client{Transport: transport} } else { // Normal mode: standard HTTP client for proxying - cred.credDialer = credentialDialer - cred.forwardHTTPClient = &http.Client{Transport: transport} - cred.reverseCredDialer = reverseSessionDialer{credential: cred} - cred.reverseHttpClient = &http.Client{ + credential.credentialDialer = credentialDialer + credential.forwardHTTPClient = &http.Client{Transport: transport} + credential.reverseCredentialDialer = reverseSessionDialer{credential: credential} + credential.reverseHTTPClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return cred.openReverseConnection(ctx) + return credential.openReverseConnection(ctx) }, }, } @@ -229,7 +229,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx } if options.UsagesPath != "" { - cred.usageTracker = &AggregatedUsage{ + credential.usageTracker = &AggregatedUsage{ LastUpdated: time.Now(), Combinations: make([]CostCombination, 0), filePath: options.UsagesPath, @@ -237,7 +237,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx } } - return cred, nil + return credential, nil } func (c *externalCredential) start() error { @@ -376,7 +376,7 @@ func (c *externalCredential) getAccessToken() (string, error) { func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) { baseURL := c.baseURL - if c.reverseHttpClient != nil { + if c.reverseHTTPClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { baseURL = reverseProxyBaseURL @@ -550,7 +550,7 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp } } // Try reverse transport first (single attempt, no retry) - if c.reverseHttpClient != nil { + if c.reverseHTTPClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { request, err := buildRequest(reverseProxyBaseURL)() @@ -558,7 +558,7 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp return nil, err } reverseClient := &http.Client{ - Transport: c.reverseHttpClient.Transport, + Transport: c.reverseHTTPClient.Transport, Timeout: 5 * time.Second, } response, err := reverseClient.Do(request) @@ -699,23 +699,23 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { } func (c *externalCredential) httpClient() *http.Client { - if c.reverseHttpClient != nil { + if c.reverseHTTPClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { - return c.reverseHttpClient + return c.reverseHTTPClient } } return c.forwardHTTPClient } func (c *externalCredential) ocmDialer() N.Dialer { - if c.reverseCredDialer != nil { + if c.reverseCredentialDialer != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { - return c.reverseCredDialer + return c.reverseCredentialDialer } } - return c.credDialer + return c.credentialDialer } func (c *externalCredential) ocmIsAPIKeyMode() bool { @@ -727,7 +727,7 @@ func (c *externalCredential) ocmGetAccountID() string { } func (c *externalCredential) ocmGetBaseURL() string { - if c.reverseHttpClient != nil { + if c.reverseHTTPClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { return reverseProxyBaseURL diff --git a/service/ocm/credential_provider.go b/service/ocm/credential_provider.go index 53383e3686..6f3da6b43c 100644 --- a/service/ocm/credential_provider.go +++ b/service/ocm/credential_provider.go @@ -13,29 +13,29 @@ import ( ) type credentialProvider interface { - selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) - onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential - linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool + selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) + onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential + linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool pollIfStale(ctx context.Context) - allCredentials() []credential + allCredentials() []Credential close() } type singleCredentialProvider struct { - cred credential + credential Credential sessionAccess sync.RWMutex sessions map[string]time.Time } -func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { - if !selection.allows(p.cred) { - return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") +func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) { + if !selection.allows(p.credential) { + return nil, false, E.New("credential ", p.credential.tagName(), " is filtered out") } - if !p.cred.isAvailable() { - return nil, false, p.cred.unavailableError() + if !p.credential.isAvailable() { + return nil, false, p.credential.unavailableError() } - if !p.cred.isUsable() { - return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") + if !p.credential.isUsable() { + return nil, false, E.New("credential ", p.credential.tagName(), " is rate-limited") } var isNew bool if sessionID != "" { @@ -50,11 +50,11 @@ func (p *singleCredentialProvider) selectCredential(sessionID string, selection } p.sessionAccess.Unlock() } - return p.cred, isNew, nil + return p.credential, isNew, nil } -func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { - cred.markRateLimited(resetAt) +func (p *singleCredentialProvider) onRateLimited(_ string, credential Credential, resetAt time.Time, _ credentialSelection) Credential { + credential.markRateLimited(resetAt) return nil } @@ -68,16 +68,16 @@ func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { } p.sessionAccess.Unlock() - if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { - p.cred.pollUsage(ctx) + if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) { + p.credential.pollUsage(ctx) } } -func (p *singleCredentialProvider) allCredentials() []credential { - return []credential{p.cred} +func (p *singleCredentialProvider) allCredentials() []Credential { + return []Credential{p.credential} } -func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { +func (p *singleCredentialProvider) linkProviderInterrupt(_ Credential, _ credentialSelection, _ func()) func() bool { return func() bool { return false } @@ -102,7 +102,7 @@ type credentialInterruptEntry struct { } type balancerProvider struct { - credentials []credential + credentials []Credential strategy string roundRobinIndex atomic.Uint64 pollInterval time.Duration @@ -114,11 +114,11 @@ type balancerProvider struct { logger log.ContextLogger } -func compositeCredentialSelectable(cred credential) bool { - return !cred.ocmIsAPIKeyMode() +func compositeCredentialSelectable(credential Credential) bool { + return !credential.ocmIsAPIKeyMode() } -func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { +func newBalancerProvider(credentials []Credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { if pollInterval <= 0 { pollInterval = defaultPollInterval } @@ -133,7 +133,7 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval } } -func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { +func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) { if p.strategy == C.BalancerStrategyFallback { best := p.pickCredential(selection.filter) if best == nil { @@ -149,23 +149,23 @@ func (p *balancerProvider) selectCredential(sessionID string, selection credenti p.sessionAccess.RUnlock() if exists { if entry.selectionScope == selectionScope { - for _, cred := range p.credentials { - if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() { + for _, credential := range p.credentials { + if credential.tagName() == entry.tag && compositeCredentialSelectable(credential) && selection.allows(credential) && credential.isUsable() { if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { better := p.pickLeastUsed(selection.filter) - if better != nil && better.tagName() != cred.tagName() { - effectiveThreshold := p.rebalanceThreshold / cred.planWeight() - delta := cred.weeklyUtilization() - better.weeklyUtilization() + if better != nil && better.tagName() != credential.tagName() { + effectiveThreshold := p.rebalanceThreshold / credential.planWeight() + delta := credential.weeklyUtilization() - better.weeklyUtilization() if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", cred.tagName(), + p.logger.Info("rebalancing away from ", credential.tagName(), ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", cred.planWeight(), ")") - p.rebalanceCredential(cred.tagName(), selectionScope) + effectiveThreshold, "% (weight ", credential.planWeight(), ")") + p.rebalanceCredential(credential.tagName(), selectionScope) break } } } - return cred, false, nil + return credential, false, nil } } } @@ -212,12 +212,12 @@ func (p *balancerProvider) rebalanceCredential(tag string, selectionScope creden p.sessionAccess.Unlock() } -func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { +func (p *balancerProvider) linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool { if p.strategy == C.BalancerStrategyFallback { return func() bool { return false } } key := credentialInterruptKey{ - tag: cred.tagName(), + tag: credential.tagName(), selectionScope: selection.scopeOrDefault(), } p.interruptAccess.Lock() @@ -231,8 +231,8 @@ func (p *balancerProvider) linkProviderInterrupt(cred credential, selection cred return context.AfterFunc(entry.context, onInterrupt) } -func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { - cred.markRateLimited(resetAt) +func (p *balancerProvider) onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential { + credential.markRateLimited(resetAt) if p.strategy == C.BalancerStrategyFallback { return p.pickCredential(selection.filter) } @@ -255,7 +255,7 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese return best } -func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { +func (p *balancerProvider) pickCredential(filter func(Credential) bool) Credential { switch p.strategy { case C.BalancerStrategyRoundRobin: return p.pickRoundRobin(filter) @@ -268,16 +268,16 @@ func (p *balancerProvider) pickCredential(filter func(credential) bool) credenti } } -func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { +func (p *balancerProvider) pickFallback(filter func(Credential) bool) Credential { + for _, credential := range p.credentials { + if filter != nil && !filter(credential) { continue } - if !compositeCredentialSelectable(cred) { + if !compositeCredentialSelectable(credential) { continue } - if cred.isUsable() { - return cred + if credential.isUsable() { + return credential } } return nil @@ -285,23 +285,23 @@ func (p *balancerProvider) pickFallback(filter func(credential) bool) credential const weeklyWindowHours = 7 * 24 -func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { - var best credential +func (p *balancerProvider) pickLeastUsed(filter func(Credential) bool) Credential { + var best Credential bestScore := float64(-1) now := time.Now() - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { + for _, credential := range p.credentials { + if filter != nil && !filter(credential) { continue } - if !compositeCredentialSelectable(cred) { + if !compositeCredentialSelectable(credential) { continue } - if !cred.isUsable() { + if !credential.isUsable() { continue } - remaining := cred.weeklyCap() - cred.weeklyUtilization() - score := remaining * cred.planWeight() - resetTime := cred.weeklyResetTime() + remaining := credential.weeklyCap() - credential.weeklyUtilization() + score := remaining * credential.planWeight() + resetTime := credential.weeklyResetTime() if !resetTime.IsZero() { timeUntilReset := resetTime.Sub(now) if timeUntilReset < time.Hour { @@ -311,7 +311,7 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia } if score > bestScore { bestScore = score - best = cred + best = credential } } return best @@ -328,7 +328,7 @@ func ocmPlanWeight(accountType string) float64 { } } -func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { +func (p *balancerProvider) pickRoundRobin(filter func(Credential) bool) Credential { start := int(p.roundRobinIndex.Add(1) - 1) count := len(p.credentials) for offset := range count { @@ -346,8 +346,8 @@ func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credenti return nil } -func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { - var usable []credential +func (p *balancerProvider) pickRandom(filter func(Credential) bool) Credential { + var usable []Credential for _, candidate := range p.credentials { if filter != nil && !filter(candidate) { continue @@ -375,28 +375,28 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) { } p.sessionAccess.Unlock() - for _, cred := range p.credentials { - if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { - cred.pollUsage(ctx) + for _, credential := range p.credentials { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage(ctx) } } } -func (p *balancerProvider) allCredentials() []credential { +func (p *balancerProvider) allCredentials() []Credential { return p.credentials } func (p *balancerProvider) close() {} -func allRateLimitedError(credentials []credential) error { +func allRateLimitedError(credentials []Credential) error { var hasUnavailable bool var earliest time.Time - for _, cred := range credentials { - if cred.unavailableError() != nil { + for _, credential := range credentials { + if credential.unavailableError() != nil { hasUnavailable = true continue } - resetAt := cred.earliestReset() + resetAt := credential.earliestReset() if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { earliest = resetAt } diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index ab99c77a6e..494cb47162 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -4,7 +4,6 @@ import ( "bufio" "context" stdTLS "crypto/tls" - "errors" "io" "math/rand/v2" "net" @@ -124,13 +123,13 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite } func (s *Service) findReceiverCredential(token string) *externalCredential { - for _, cred := range s.allCredentials { - extCred, ok := cred.(*externalCredential) - if !ok || extCred.connectorURL != nil { + for _, credential := range s.allCredentials { + external, ok := credential.(*externalCredential) + if !ok || external.connectorURL != nil { continue } - if extCred.token == token { - return extCred + if external.token == token { + return external } } return nil @@ -248,7 +247,7 @@ func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duratio } err = httpServer.Serve(&yamuxNetListener{session: session}) sessionLifetime := time.Since(serveStart) - if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil { + if err != nil && !E.IsClosed(err) && ctx.Err() == nil { return sessionLifetime, E.Cause(err, "serve") } return sessionLifetime, E.New("connection closed") diff --git a/service/ocm/service.go b/service/ocm/service.go index 101f904926..272bbb3a57 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -3,7 +3,6 @@ package ocm import ( "context" "encoding/json" - "errors" "io" "net/http" "strings" @@ -68,18 +67,18 @@ const ( retryableUsageCode = "credential_usage_exhausted" ) -func hasAlternativeCredential(provider credentialProvider, currentCredential credential, selection credentialSelection) bool { +func hasAlternativeCredential(provider credentialProvider, currentCredential Credential, selection credentialSelection) bool { if provider == nil || currentCredential == nil { return false } - for _, cred := range provider.allCredentials() { - if cred == currentCredential { + for _, credential := range provider.allCredentials() { + if credential == currentCredential { continue } - if !selection.allows(cred) { + if !selection.allows(credential) { continue } - if cred.isUsable() { + if credential.isUsable() { return true } } @@ -109,7 +108,7 @@ func writeCredentialUnavailableError( w http.ResponseWriter, r *http.Request, provider credentialProvider, - currentCredential credential, + currentCredential Credential, selection credentialSelection, fallback string, ) { @@ -124,8 +123,8 @@ func credentialSelectionForUser(userConfig *option.OCMUser) credentialSelection selection := credentialSelection{scope: credentialSelectionScopeAll} if userConfig != nil && !userConfig.AllowExternalUsage { selection.scope = credentialSelectionScopeNonExternal - selection.filter = func(cred credential) bool { - return !cred.isExternal() + selection.filter = func(credential Credential) bool { + return !credential.isExternal() } } return selection @@ -174,7 +173,7 @@ type Service struct { // Multi-credential mode providers map[string]credentialProvider - allCredentials []credential + allCredentials []Credential userConfigMap map[string]*option.OCMUser } @@ -218,7 +217,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio } service.userConfigMap = userConfigMap } else { - cred, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{ + credential, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{ CredentialPath: options.CredentialPath, UsagesPath: options.UsagesPath, Detour: options.Detour, @@ -226,9 +225,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio if err != nil { return nil, err } - service.legacyCredential = cred - service.legacyProvider = &singleCredentialProvider{cred: cred} - service.allCredentials = []credential{cred} + service.legacyCredential = credential + service.legacyProvider = &singleCredentialProvider{credential: credential} + service.allCredentials = []Credential{credential} } if options.TLS != nil { @@ -249,16 +248,16 @@ func (s *Service) Start(stage adapter.StartStage) error { s.userManager.UpdateUsers(s.options.Users) - for _, cred := range s.allCredentials { - if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { - extCred.reverseService = s + for _, credential := range s.allCredentials { + if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil { + external.reverseService = s } - err := cred.start() + err := credential.start() if err != nil { return err } - tag := cred.tagName() - cred.setOnBecameUnusable(func() { + tag := credential.tagName() + credential.setOnBecameUnusable(func() { s.interruptWebSocketSessionsForCredential(tag) }) } @@ -295,7 +294,7 @@ func (s *Service) Start(stage adapter.StartStage) error { go func() { serveErr := s.httpServer.Serve(tcpListener) - if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { + if serveErr != nil && !E.IsClosed(serveErr) { s.logger.Error("serve error: ", serveErr) } }() @@ -304,15 +303,15 @@ func (s *Service) Start(stage adapter.StartStage) error { } func (s *Service) InterfaceUpdated() { - for _, cred := range s.allCredentials { - extCred, ok := cred.(*externalCredential) + for _, credential := range s.allCredentials { + external, ok := credential.(*externalCredential) if !ok { continue } - if extCred.reverse && extCred.connectorURL != nil { - extCred.reverseService = s - extCred.resetReverseContext() - go extCred.connectorLoop() + if external.reverse && external.connectorURL != nil { + external.reverseService = s + external.resetReverseContext() + go external.connectorLoop() } } } @@ -330,8 +329,8 @@ func (s *Service) Close() error { } s.webSocketGroup.Wait() - for _, cred := range s.allCredentials { - cred.close() + for _, credential := range s.allCredentials { + credential.close() } return err diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index 9fb9c96d77..1a247d6cc4 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -318,6 +318,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if n > 0 { _, writeError := w.Write(buffer[:n]) if writeError != nil { + if E.IsClosedOrCanceled(writeError) { + return + } s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } @@ -471,6 +474,9 @@ func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.Re _, writeError := writer.Write(buffer[:n]) if writeError != nil { + if E.IsClosedOrCanceled(writeError) { + return + } s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index 29b95d063b..915fb837db 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -62,22 +62,22 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) { var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 - for _, cred := range provider.allCredentials() { - if !cred.isAvailable() { + for _, credential := range provider.allCredentials() { + if !credential.isAvailable() { continue } - if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { + if userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { continue } - if !userConfig.AllowExternalUsage && cred.isExternal() { + if !userConfig.AllowExternalUsage && credential.isExternal() { continue } - weight := cred.planWeight() - remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() + weight := credential.planWeight() + remaining5h := credential.fiveHourCap() - credential.fiveHourUtilization() if remaining5h < 0 { remaining5h = 0 } - remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() + remainingWeekly := credential.weeklyCap() - credential.weeklyUtilization() if remainingWeekly < 0 { remainingWeekly = 0 } diff --git a/service/ocm/service_usage.go b/service/ocm/service_usage.go index 589fd093a6..19a853a7cd 100644 --- a/service/ocm/service_usage.go +++ b/service/ocm/service_usage.go @@ -55,13 +55,13 @@ type CostCombination struct { type AggregatedUsage struct { LastUpdated time.Time `json:"last_updated"` Combinations []CostCombination `json:"combinations"` - mutex sync.Mutex + access sync.Mutex filePath string logger log.ContextLogger lastSaveTime time.Time pendingSave bool saveTimer *time.Timer - saveMutex sync.Mutex + saveAccess sync.Mutex } type UsageStatsJSON struct { @@ -1035,8 +1035,8 @@ func deriveWeekStartUnix(cycleHint *WeeklyCycleHint) int64 { } func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON { - u.mutex.Lock() - defer u.mutex.Unlock() + u.access.Lock() + defer u.access.Unlock() result := &AggregatedUsageJSON{ LastUpdated: u.LastUpdated, @@ -1069,8 +1069,8 @@ func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON { } func (u *AggregatedUsage) Load() error { - u.mutex.Lock() - defer u.mutex.Unlock() + u.access.Lock() + defer u.access.Unlock() u.LastUpdated = time.Time{} u.Combinations = nil @@ -1116,9 +1116,9 @@ func (u *AggregatedUsage) Save() error { defer os.Remove(tmpFile) err = os.Rename(tmpFile, u.filePath) if err == nil { - u.saveMutex.Lock() + u.saveAccess.Lock() u.lastSaveTime = time.Now() - u.saveMutex.Unlock() + u.saveAccess.Unlock() } return err } @@ -1140,8 +1140,8 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(model string, contextWindow int, observedAt = time.Now() } - u.mutex.Lock() - defer u.mutex.Unlock() + u.access.Lock() + defer u.access.Unlock() u.LastUpdated = observedAt weekStartUnix := deriveWeekStartUnix(cycleHint) @@ -1156,8 +1156,8 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(model string, contextWindow int, func (u *AggregatedUsage) scheduleSave() { const saveInterval = time.Minute - u.saveMutex.Lock() - defer u.saveMutex.Unlock() + u.saveAccess.Lock() + defer u.saveAccess.Unlock() timeSinceLastSave := time.Since(u.lastSaveTime) @@ -1174,9 +1174,9 @@ func (u *AggregatedUsage) scheduleSave() { remainingTime := saveInterval - timeSinceLastSave u.saveTimer = time.AfterFunc(remainingTime, func() { - u.saveMutex.Lock() + u.saveAccess.Lock() u.pendingSave = false - u.saveMutex.Unlock() + u.saveAccess.Unlock() u.saveAsync() }) } @@ -1191,8 +1191,8 @@ func (u *AggregatedUsage) saveAsync() { } func (u *AggregatedUsage) cancelPendingSave() { - u.saveMutex.Lock() - defer u.saveMutex.Unlock() + u.saveAccess.Lock() + defer u.saveAccess.Unlock() if u.saveTimer != nil { u.saveTimer.Stop() diff --git a/service/ocm/service_user.go b/service/ocm/service_user.go index b69655e9ac..5f76808372 100644 --- a/service/ocm/service_user.go +++ b/service/ocm/service_user.go @@ -7,8 +7,8 @@ import ( ) type UserManager struct { - access sync.RWMutex - tokenMap map[string]string + access sync.RWMutex + tokenMap map[string]string } func (m *UserManager) UpdateUsers(users []option.OCMUser) { diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 4b640d9c5c..1b7b5baace 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -98,7 +98,7 @@ func (s *Service) handleWebSocket( sessionID string, userConfig *option.OCMUser, provider credentialProvider, - selectedCredential credential, + selectedCredential Credential, selection credentialSelection, isNew bool, ) { @@ -307,7 +307,7 @@ func (s *Service) handleWebSocket( waitGroup.Wait() } -func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) { +func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential Credential, modelChannel chan<- string, isNew bool, username string, sessionID string) { logged := false for { data, opCode, err := wsutil.ReadClientData(clientConn) @@ -359,7 +359,7 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn } } -func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential Credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { usageTracker := selectedCredential.usageTrackerOrNil() var requestModel string for { @@ -413,7 +413,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe } } -func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential credential) { +func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential Credential) { var rateLimitsEvent struct { RateLimits struct { Primary *struct { @@ -462,7 +462,7 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential selectedCredential.updateStateFromHeaders(headers) } -func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredential credential) { +func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredential Credential) { var errorEvent struct { Headers map[string]string `json:"headers"` } From 56af7313b231bdbf9f5a2f928d396bace1c1b4ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 21:22:24 +0800 Subject: [PATCH 47/96] ccm,ocm: don't treat usage API 429 as account over-limit The usage API itself has rate limits. A 429 from it means "poll less frequently", not that the account exceeded its usage quota. Previously incrementPollFailures() was called, marking the credential unusable and interrupting in-flight connections. Now: parse Retry-After, store as usageAPIRetryDelay, and retry after that delay. The credential stays usable and relies on passive header updates for usage data in the meantime. --- service/ccm/credential.go | 1 + service/ccm/credential_default.go | 18 +++++++++++++++++- service/ocm/credential.go | 1 + service/ocm/credential_default.go | 18 +++++++++++++++++- 4 files changed, 36 insertions(+), 2 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index d5cae9e1e3..b732617177 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -60,6 +60,7 @@ type credentialState struct { remotePlanWeight float64 lastUpdated time.Time consecutivePollFailures int + usageAPIRetryDelay time.Duration unavailable bool lastCredentialLoadAttempt time.Time lastCredentialLoadError string diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index c44ec41036..a6ed7ec872 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -436,8 +436,12 @@ func (c *defaultCredential) incrementPollFailures() { func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { c.stateAccess.RLock() failures := c.state.consecutivePollFailures + retryDelay := c.state.usageAPIRetryDelay c.stateAccess.RUnlock() if failures <= 0 { + if retryDelay > 0 { + return retryDelay + } return baseInterval } backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) @@ -518,7 +522,18 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { if response.StatusCode != http.StatusOK { if response.StatusCode == http.StatusTooManyRequests { - c.logger.Warn("poll usage for ", c.tag, ": rate limited") + retryDelay := time.Minute + if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" { + seconds, err := strconv.ParseInt(retryAfter, 10, 64) + if err == nil && seconds > 0 { + retryDelay = time.Duration(seconds) * time.Second + } + } + c.logger.Warn("poll usage for ", c.tag, ": usage API rate limited, retry in ", log.FormatDuration(retryDelay)) + c.stateAccess.Lock() + c.state.usageAPIRetryDelay = retryDelay + c.stateAccess.Unlock() + return } body, _ := io.ReadAll(response.Body) c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) @@ -548,6 +563,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 + c.state.usageAPIRetryDelay = 0 c.state.fiveHourUtilization = usageResponse.FiveHour.Utilization if !usageResponse.FiveHour.ResetsAt.IsZero() { c.state.fiveHourReset = usageResponse.FiveHour.ResetsAt diff --git a/service/ocm/credential.go b/service/ocm/credential.go index e0ad9f5657..80c094cdb6 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -62,6 +62,7 @@ type credentialState struct { remotePlanWeight float64 lastUpdated time.Time consecutivePollFailures int + usageAPIRetryDelay time.Duration unavailable bool lastCredentialLoadAttempt time.Time lastCredentialLoadError string diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index b82af9d20f..70b6eb6c15 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -493,8 +493,12 @@ func (c *defaultCredential) incrementPollFailures() { func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { c.stateAccess.RLock() failures := c.state.consecutivePollFailures + retryDelay := c.state.usageAPIRetryDelay c.stateAccess.RUnlock() if failures <= 0 { + if retryDelay > 0 { + return retryDelay + } return baseInterval } backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) @@ -618,7 +622,18 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { if response.StatusCode != http.StatusOK { if response.StatusCode == http.StatusTooManyRequests { - c.logger.Warn("poll usage for ", c.tag, ": rate limited") + retryDelay := time.Minute + if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" { + seconds, err := strconv.ParseInt(retryAfter, 10, 64) + if err == nil && seconds > 0 { + retryDelay = time.Duration(seconds) * time.Second + } + } + c.logger.Warn("poll usage for ", c.tag, ": usage API rate limited, retry in ", log.FormatDuration(retryDelay)) + c.stateAccess.Lock() + c.state.usageAPIRetryDelay = retryDelay + c.stateAccess.Unlock() + return } body, _ := io.ReadAll(response.Body) c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) @@ -649,6 +664,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 + c.state.usageAPIRetryDelay = 0 if usageResponse.RateLimit != nil { if w := usageResponse.RateLimit.PrimaryWindow; w != nil { c.state.fiveHourUtilization = w.UsedPercent From bc6e72408d20991459bde4a2e6b00771c4b7f0e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 22:04:45 +0800 Subject: [PATCH 48/96] ccm,ocm: block API key headers from being forwarded upstream --- service/ccm/credential_default.go | 2 +- service/ccm/credential_external.go | 2 +- service/ccm/reverse.go | 6 ++++++ service/ccm/service.go | 9 +++++++++ service/ccm/service_handler.go | 6 ++++++ service/ccm/service_status.go | 6 ++++++ service/ocm/credential_default.go | 2 +- service/ocm/credential_external.go | 2 +- service/ocm/reverse.go | 6 ++++++ service/ocm/service.go | 9 +++++++++ service/ocm/service_handler.go | 6 ++++++ service/ocm/service_status.go | 6 ++++++ service/ocm/service_websocket.go | 2 ++ 13 files changed, 60 insertions(+), 4 deletions(-) diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index a6ed7ec872..0bbf188749 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -715,7 +715,7 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt } for key, values := range original.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && !isAPIKeyHeader(key) && key != "Authorization" { proxyRequest.Header[key] = values } } diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index eb75c5b082..a9876eaf50 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -371,7 +371,7 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht } for key, values := range original.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && !isAPIKeyHeader(key) && key != "Authorization" { proxyRequest.Header[key] = values } } diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index 97ef1751c1..b6b4c88a09 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -57,6 +57,12 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite return } + if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "API key authentication is not supported; use Authorization: Bearer with a CCM user token") + return + } + authHeader := r.Header.Get("Authorization") if authHeader == "" { writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") diff --git a/service/ccm/service.go b/service/ccm/service.go index 69964c02c7..043a147c15 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -139,6 +139,15 @@ func isReverseProxyHeader(header string) bool { } } +func isAPIKeyHeader(header string) bool { + switch strings.ToLower(header) { + case "x-api-key", "api-key": + return true + default: + return false + } +} + type Service struct { boxService.Adapter ctx context.Context diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 7a59cfe4a2..14ae3adeb2 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -104,6 +104,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "API key authentication is not supported; use Authorization: Bearer with a CCM user token") + return + } + var username string if len(s.options.Users) > 0 { authHeader := r.Header.Get("Authorization") diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index 75929c59f5..cfc8c76352 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -20,6 +20,12 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { return } + if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "API key authentication is not supported; use Authorization: Bearer with a CCM user token") + return + } + authHeader := r.Header.Get("Authorization") if authHeader == "" { writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index 70b6eb6c15..18612a5690 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -730,7 +730,7 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt } for key, values := range original.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && !isAPIKeyHeader(key) && key != "Authorization" { proxyRequest.Header[key] = values } } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 02675780b0..bcfe6d234b 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -395,7 +395,7 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht } for key, values := range original.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && !isAPIKeyHeader(key) && key != "Authorization" { proxyRequest.Header[key] = values } } diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index 494cb47162..b47826e60e 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -57,6 +57,12 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite return } + if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "API key authentication is not supported; use Authorization: Bearer with an OCM user token") + return + } + authHeader := r.Header.Get("Authorization") if authHeader == "" { writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") diff --git a/service/ocm/service.go b/service/ocm/service.go index 272bbb3a57..289152b64b 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -152,6 +152,15 @@ func isReverseProxyHeader(header string) bool { } } +func isAPIKeyHeader(header string) bool { + switch strings.ToLower(header) { + case "x-api-key", "api-key": + return true + default: + return false + } +} + type Service struct { boxService.Adapter ctx context.Context diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index 1a247d6cc4..905b1f5aef 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -80,6 +80,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "API key authentication is not supported; use Authorization: Bearer with an OCM user token") + return + } + var username string if len(s.options.Users) > 0 { authHeader := r.Header.Get("Authorization") diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index 915fb837db..e32d8244ed 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -20,6 +20,12 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { return } + if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "API key authentication is not supported; use Authorization: Bearer with an OCM user token") + return + } + authHeader := r.Header.Get("Authorization") if authHeader == "" { writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 1b7b5baace..ce20e1be77 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -82,6 +82,8 @@ func isForwardableWebSocketRequestHeader(key string) bool { switch { case lowerKey == "authorization": return false + case lowerKey == "x-api-key" || lowerKey == "api-key": + return false case strings.HasPrefix(lowerKey, "sec-websocket-"): return false default: From 8e9c61e6247bfa52009d180c17a4b10ae5740fa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 15 Mar 2026 12:24:46 +0800 Subject: [PATCH 49/96] ccm,ocm: normalize legacy fields into credentials at init, remove dual code path --- service/ccm/credential_builder.go | 150 ++++++++++++---------------- service/ccm/service.go | 58 +++++------ service/ccm/service_handler.go | 4 +- service/ccm/service_status.go | 73 +++++++------- service/ocm/credential_builder.go | 156 ++++++++++++------------------ service/ocm/service.go | 66 ++++++------- service/ocm/service_handler.go | 8 +- service/ocm/service_status.go | 73 +++++++------- 8 files changed, 266 insertions(+), 322 deletions(-) diff --git a/service/ccm/credential_builder.go b/service/ccm/credential_builder.go index 63bfd03951..94e0af957a 100644 --- a/service/ccm/credential_builder.go +++ b/service/ccm/credential_builder.go @@ -71,84 +71,68 @@ func resolveCredentialTags(tags []string, allCredentials map[string]Credential, } func validateCCMOptions(options option.CCMServiceOptions) error { - hasCredentials := len(options.Credentials) > 0 - hasLegacyPath := options.CredentialPath != "" - hasLegacyUsages := options.UsagesPath != "" - hasLegacyDetour := options.Detour != "" - - if hasCredentials && hasLegacyPath { - return E.New("credential_path and credentials are mutually exclusive") - } - if hasCredentials && hasLegacyUsages { - return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") - } - if hasCredentials && hasLegacyDetour { - return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") - } - - if hasCredentials { - tags := make(map[string]bool) - credentialTypes := make(map[string]string) - for _, credential := range options.Credentials { - if tags[credential.Tag] { - return E.New("duplicate credential tag: ", credential.Tag) + tags := make(map[string]bool) + credentialTypes := make(map[string]string) + for _, credential := range options.Credentials { + if tags[credential.Tag] { + return E.New("duplicate credential tag: ", credential.Tag) + } + tags[credential.Tag] = true + credentialTypes[credential.Tag] = credential.Type + if credential.Type == "default" || credential.Type == "" { + if credential.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") + } + if credential.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") } - tags[credential.Tag] = true - credentialTypes[credential.Tag] = credential.Type - if credential.Type == "default" || credential.Type == "" { - if credential.DefaultOptions.Reserve5h > 99 { - return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") - } - if credential.DefaultOptions.ReserveWeekly > 99 { - return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") - } - if credential.DefaultOptions.Limit5h > 100 { - return E.New("credential ", credential.Tag, ": limit_5h must be at most 100") - } - if credential.DefaultOptions.LimitWeekly > 100 { - return E.New("credential ", credential.Tag, ": limit_weekly must be at most 100") - } - if credential.DefaultOptions.Reserve5h > 0 && credential.DefaultOptions.Limit5h > 0 { - return E.New("credential ", credential.Tag, ": reserve_5h and limit_5h are mutually exclusive") - } - if credential.DefaultOptions.ReserveWeekly > 0 && credential.DefaultOptions.LimitWeekly > 0 { - return E.New("credential ", credential.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") - } + if credential.DefaultOptions.Limit5h > 100 { + return E.New("credential ", credential.Tag, ": limit_5h must be at most 100") } - if credential.Type == "external" { - if credential.ExternalOptions.Token == "" { - return E.New("credential ", credential.Tag, ": external credential requires token") - } - if credential.ExternalOptions.Reverse && credential.ExternalOptions.URL == "" { - return E.New("credential ", credential.Tag, ": reverse external credential requires url") - } + if credential.DefaultOptions.LimitWeekly > 100 { + return E.New("credential ", credential.Tag, ": limit_weekly must be at most 100") } - if credential.Type == "balancer" { - switch credential.BalancerOptions.Strategy { - case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: - default: - return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) - } - if credential.BalancerOptions.RebalanceThreshold < 0 { - return E.New("credential ", credential.Tag, ": rebalance_threshold must not be negative") - } + if credential.DefaultOptions.Reserve5h > 0 && credential.DefaultOptions.Limit5h > 0 { + return E.New("credential ", credential.Tag, ": reserve_5h and limit_5h are mutually exclusive") + } + if credential.DefaultOptions.ReserveWeekly > 0 && credential.DefaultOptions.LimitWeekly > 0 { + return E.New("credential ", credential.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") } } - - for _, user := range options.Users { - if user.Credential == "" { - return E.New("user ", user.Name, " must specify credential in multi-credential mode") + if credential.Type == "external" { + if credential.ExternalOptions.Token == "" { + return E.New("credential ", credential.Tag, ": external credential requires token") } - if !tags[user.Credential] { - return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + if credential.ExternalOptions.Reverse && credential.ExternalOptions.URL == "" { + return E.New("credential ", credential.Tag, ": reverse external credential requires url") + } + } + if credential.Type == "balancer" { + switch credential.BalancerOptions.Strategy { + case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: + default: + return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) } - if user.ExternalCredential != "" { - if !tags[user.ExternalCredential] { - return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) - } - if credentialTypes[user.ExternalCredential] != "external" { - return E.New("user ", user.Name, ": external_credential must reference an external type credential") - } + if credential.BalancerOptions.RebalanceThreshold < 0 { + return E.New("credential ", credential.Tag, ": rebalance_threshold must not be negative") + } + } + } + + singleCredential := len(options.Credentials) == 1 + for _, user := range options.Users { + if user.Credential == "" && !singleCredential { + return E.New("user ", user.Name, " must specify credential in multi-credential mode") + } + if user.Credential != "" && !tags[user.Credential] { + return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + } + if user.ExternalCredential != "" { + if !tags[user.ExternalCredential] { + return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) + } + if credentialTypes[user.ExternalCredential] != "external" { + return E.New("user ", user.Name, ": external_credential must reference an external type credential") } } } @@ -159,16 +143,18 @@ func validateCCMOptions(options option.CCMServiceOptions) error { func credentialForUser( userConfigMap map[string]*option.CCMUser, providers map[string]credentialProvider, - legacyProvider credentialProvider, username string, ) (credentialProvider, error) { - if legacyProvider != nil { - return legacyProvider, nil - } userConfig, exists := userConfigMap[username] if !exists { return nil, E.New("no credential mapping for user: ", username) } + if userConfig.Credential == "" { + for _, provider := range providers { + return provider, nil + } + return nil, E.New("no credential available") + } provider, exists := providers[userConfig.Credential] if !exists { return nil, E.New("unknown credential: ", userConfig.Credential) @@ -176,17 +162,3 @@ func credentialForUser( return provider, nil } -func noUserCredentialProvider( - providers map[string]credentialProvider, - legacyProvider credentialProvider, - options option.CCMServiceOptions, -) credentialProvider { - if legacyProvider != nil { - return legacyProvider - } - if len(options.Credentials) > 0 { - tag := options.Credentials[0].Tag - return providers[tag] - } - return nil -} diff --git a/service/ccm/service.go b/service/ccm/service.go index 043a147c15..74952173c0 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -161,11 +161,6 @@ type Service struct { trackingGroup sync.WaitGroup shuttingDown bool - // Legacy mode (single credential) - legacyCredential *defaultCredential - legacyProvider credentialProvider - - // Multi-credential mode providers map[string]credentialProvider allCredentials []Credential userConfigMap map[string]*option.CCMUser @@ -174,6 +169,25 @@ type Service struct { func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) { initCCMUserAgent(logger) + hasLegacy := options.CredentialPath != "" || options.UsagesPath != "" || options.Detour != "" + if hasLegacy && len(options.Credentials) > 0 { + return nil, E.New("credential_path/usages_path/detour and credentials are mutually exclusive") + } + if len(options.Credentials) == 0 { + options.Credentials = []option.CCMCredential{{ + Type: "default", + Tag: "default", + DefaultOptions: option.CCMDefaultCredentialOptions{ + CredentialPath: options.CredentialPath, + UsagesPath: options.UsagesPath, + Detour: options.Detour, + }, + }} + options.CredentialPath = "" + options.UsagesPath = "" + options.Detour = "" + } + err := validateCCMOptions(options) if err != nil { return nil, E.Cause(err, "validate options") @@ -198,32 +212,18 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio userManager: userManager, } - if len(options.Credentials) > 0 { - providers, allCredentials, err := buildCredentialProviders(ctx, options, logger) - if err != nil { - return nil, E.Cause(err, "build credential providers") - } - service.providers = providers - service.allCredentials = allCredentials + providers, allCredentials, err := buildCredentialProviders(ctx, options, logger) + if err != nil { + return nil, E.Cause(err, "build credential providers") + } + service.providers = providers + service.allCredentials = allCredentials - userConfigMap := make(map[string]*option.CCMUser) - for i := range options.Users { - userConfigMap[options.Users[i].Name] = &options.Users[i] - } - service.userConfigMap = userConfigMap - } else { - credential, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{ - CredentialPath: options.CredentialPath, - UsagesPath: options.UsagesPath, - Detour: options.Detour, - }, logger) - if err != nil { - return nil, err - } - service.legacyCredential = credential - service.legacyProvider = &singleCredentialProvider{credential: credential} - service.allCredentials = []Credential{credential} + userConfigMap := make(map[string]*option.CCMUser) + for i := range options.Users { + userConfigMap[options.Users[i].Name] = &options.Users[i] } + service.userConfigMap = userConfigMap if options.TLS != nil { tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 14ae3adeb2..fdbb682033 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -168,14 +168,14 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if len(s.options.Users) > 0 { userConfig = s.userConfigMap[username] var err error - provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + provider, err = credentialForUser(s.userConfigMap, s.providers, username) if err != nil { s.logger.ErrorContext(ctx, "resolve credential: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) return } } else { - provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + provider = s.providers[s.options.Credentials[0].Tag] } if provider == nil { writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index cfc8c76352..bd8aa4b22b 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -15,42 +15,43 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { return } - if len(s.options.Users) == 0 { - writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") - return - } - - if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "API key authentication is not supported; use Authorization: Bearer with a CCM user token") - return - } + var provider credentialProvider + var userConfig *option.CCMUser + if len(s.options.Users) > 0 { + if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "API key authentication is not supported; use Authorization: Bearer with a CCM user token") + return + } - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") - return - } - clientToken := strings.TrimPrefix(authHeader, "Bearer ") - if clientToken == authHeader { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") - return - } - username, ok := s.userManager.Authenticate(clientToken) - if !ok { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") - return - } + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + username, ok := s.userManager.Authenticate(clientToken) + if !ok { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } - userConfig := s.userConfigMap[username] - if userConfig == nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") - return + userConfig = s.userConfigMap[username] + var err error + provider, err = credentialForUser(s.userConfigMap, s.providers, username) + if err != nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + } else { + provider = s.providers[s.options.Credentials[0].Tag] } - - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - if err != nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + if provider == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") return } @@ -72,10 +73,10 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user if !credential.isAvailable() { continue } - if userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { + if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { continue } - if !userConfig.AllowExternalUsage && credential.isExternal() { + if userConfig != nil && !userConfig.AllowExternalUsage && credential.isExternal() { continue } weight := credential.planWeight() @@ -100,7 +101,7 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user } func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) { - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) + provider, err := credentialForUser(s.userConfigMap, s.providers, userConfig.Name) if err != nil { return } diff --git a/service/ocm/credential_builder.go b/service/ocm/credential_builder.go index e308d04d1b..c800e039d6 100644 --- a/service/ocm/credential_builder.go +++ b/service/ocm/credential_builder.go @@ -71,84 +71,68 @@ func resolveCredentialTags(tags []string, allCredentials map[string]Credential, } func validateOCMOptions(options option.OCMServiceOptions) error { - hasCredentials := len(options.Credentials) > 0 - hasLegacyPath := options.CredentialPath != "" - hasLegacyUsages := options.UsagesPath != "" - hasLegacyDetour := options.Detour != "" - - if hasCredentials && hasLegacyPath { - return E.New("credential_path and credentials are mutually exclusive") - } - if hasCredentials && hasLegacyUsages { - return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") - } - if hasCredentials && hasLegacyDetour { - return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") - } - - if hasCredentials { - tags := make(map[string]bool) - credentialTypes := make(map[string]string) - for _, credential := range options.Credentials { - if tags[credential.Tag] { - return E.New("duplicate credential tag: ", credential.Tag) - } - tags[credential.Tag] = true - credentialTypes[credential.Tag] = credential.Type - if credential.Type == "default" || credential.Type == "" { - if credential.DefaultOptions.Reserve5h > 99 { - return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") - } - if credential.DefaultOptions.ReserveWeekly > 99 { - return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") - } - if credential.DefaultOptions.Limit5h > 100 { - return E.New("credential ", credential.Tag, ": limit_5h must be at most 100") - } - if credential.DefaultOptions.LimitWeekly > 100 { - return E.New("credential ", credential.Tag, ": limit_weekly must be at most 100") - } - if credential.DefaultOptions.Reserve5h > 0 && credential.DefaultOptions.Limit5h > 0 { - return E.New("credential ", credential.Tag, ": reserve_5h and limit_5h are mutually exclusive") - } - if credential.DefaultOptions.ReserveWeekly > 0 && credential.DefaultOptions.LimitWeekly > 0 { - return E.New("credential ", credential.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") - } - } - if credential.Type == "external" { - if credential.ExternalOptions.Token == "" { - return E.New("credential ", credential.Tag, ": external credential requires token") - } - if credential.ExternalOptions.Reverse && credential.ExternalOptions.URL == "" { - return E.New("credential ", credential.Tag, ": reverse external credential requires url") - } - } - if credential.Type == "balancer" { - switch credential.BalancerOptions.Strategy { - case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: - default: - return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) - } - if credential.BalancerOptions.RebalanceThreshold < 0 { - return E.New("credential ", credential.Tag, ": rebalance_threshold must not be negative") - } + tags := make(map[string]bool) + credentialTypes := make(map[string]string) + for _, credential := range options.Credentials { + if tags[credential.Tag] { + return E.New("duplicate credential tag: ", credential.Tag) + } + tags[credential.Tag] = true + credentialTypes[credential.Tag] = credential.Type + if credential.Type == "default" || credential.Type == "" { + if credential.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") + } + if credential.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") + } + if credential.DefaultOptions.Limit5h > 100 { + return E.New("credential ", credential.Tag, ": limit_5h must be at most 100") + } + if credential.DefaultOptions.LimitWeekly > 100 { + return E.New("credential ", credential.Tag, ": limit_weekly must be at most 100") + } + if credential.DefaultOptions.Reserve5h > 0 && credential.DefaultOptions.Limit5h > 0 { + return E.New("credential ", credential.Tag, ": reserve_5h and limit_5h are mutually exclusive") + } + if credential.DefaultOptions.ReserveWeekly > 0 && credential.DefaultOptions.LimitWeekly > 0 { + return E.New("credential ", credential.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") } } - - for _, user := range options.Users { - if user.Credential == "" { - return E.New("user ", user.Name, " must specify credential in multi-credential mode") + if credential.Type == "external" { + if credential.ExternalOptions.Token == "" { + return E.New("credential ", credential.Tag, ": external credential requires token") + } + if credential.ExternalOptions.Reverse && credential.ExternalOptions.URL == "" { + return E.New("credential ", credential.Tag, ": reverse external credential requires url") + } + } + if credential.Type == "balancer" { + switch credential.BalancerOptions.Strategy { + case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: + default: + return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) } - if !tags[user.Credential] { - return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + if credential.BalancerOptions.RebalanceThreshold < 0 { + return E.New("credential ", credential.Tag, ": rebalance_threshold must not be negative") } - if user.ExternalCredential != "" { - if !tags[user.ExternalCredential] { - return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) - } - if credentialTypes[user.ExternalCredential] != "external" { - return E.New("user ", user.Name, ": external_credential must reference an external type credential") - } + } + } + + singleCredential := len(options.Credentials) == 1 + for _, user := range options.Users { + if user.Credential == "" && !singleCredential { + return E.New("user ", user.Name, " must specify credential in multi-credential mode") + } + if user.Credential != "" && !tags[user.Credential] { + return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + } + if user.ExternalCredential != "" { + if !tags[user.ExternalCredential] { + return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) + } + if credentialTypes[user.ExternalCredential] != "external" { + return E.New("user ", user.Name, ": external_credential must reference an external type credential") } } } @@ -190,16 +174,18 @@ func validateOCMCompositeCredentialModes( func credentialForUser( userConfigMap map[string]*option.OCMUser, providers map[string]credentialProvider, - legacyProvider credentialProvider, username string, ) (credentialProvider, error) { - if legacyProvider != nil { - return legacyProvider, nil - } userConfig, exists := userConfigMap[username] if !exists { return nil, E.New("no credential mapping for user: ", username) } + if userConfig.Credential == "" { + for _, provider := range providers { + return provider, nil + } + return nil, E.New("no credential available") + } provider, exists := providers[userConfig.Credential] if !exists { return nil, E.New("unknown credential: ", userConfig.Credential) @@ -207,17 +193,3 @@ func credentialForUser( return provider, nil } -func noUserCredentialProvider( - providers map[string]credentialProvider, - legacyProvider credentialProvider, - options option.OCMServiceOptions, -) credentialProvider { - if legacyProvider != nil { - return legacyProvider - } - if len(options.Credentials) > 0 { - tag := options.Credentials[0].Tag - return providers[tag] - } - return nil -} diff --git a/service/ocm/service.go b/service/ocm/service.go index 289152b64b..641872e5db 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -176,17 +176,31 @@ type Service struct { webSocketConns map[*webSocketSession]struct{} shuttingDown bool - // Legacy mode - legacyCredential *defaultCredential - legacyProvider credentialProvider - - // Multi-credential mode providers map[string]credentialProvider allCredentials []Credential userConfigMap map[string]*option.OCMUser } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) { + hasLegacy := options.CredentialPath != "" || options.UsagesPath != "" || options.Detour != "" + if hasLegacy && len(options.Credentials) > 0 { + return nil, E.New("credential_path/usages_path/detour and credentials are mutually exclusive") + } + if len(options.Credentials) == 0 { + options.Credentials = []option.OCMCredential{{ + Type: "default", + Tag: "default", + DefaultOptions: option.OCMDefaultCredentialOptions{ + CredentialPath: options.CredentialPath, + UsagesPath: options.UsagesPath, + Detour: options.Detour, + }, + }} + options.CredentialPath = "" + options.UsagesPath = "" + options.Detour = "" + } + err := validateOCMOptions(options) if err != nil { return nil, E.Cause(err, "validate options") @@ -212,32 +226,18 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio webSocketConns: make(map[*webSocketSession]struct{}), } - if len(options.Credentials) > 0 { - providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger) - if err != nil { - return nil, E.Cause(err, "build credential providers") - } - service.providers = providers - service.allCredentials = allCredentials + providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger) + if err != nil { + return nil, E.Cause(err, "build credential providers") + } + service.providers = providers + service.allCredentials = allCredentials - userConfigMap := make(map[string]*option.OCMUser) - for i := range options.Users { - userConfigMap[options.Users[i].Name] = &options.Users[i] - } - service.userConfigMap = userConfigMap - } else { - credential, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{ - CredentialPath: options.CredentialPath, - UsagesPath: options.UsagesPath, - Detour: options.Detour, - }, logger) - if err != nil { - return nil, err - } - service.legacyCredential = credential - service.legacyProvider = &singleCredentialProvider{credential: credential} - service.allCredentials = []Credential{credential} + userConfigMap := make(map[string]*option.OCMUser) + for i := range options.Users { + userConfigMap[options.Users[i].Name] = &options.Users[i] } + service.userConfigMap = userConfigMap if options.TLS != nil { tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) @@ -270,11 +270,9 @@ func (s *Service) Start(stage adapter.StartStage) error { s.interruptWebSocketSessionsForCredential(tag) }) } - if len(s.options.Credentials) > 0 { - err := validateOCMCompositeCredentialModes(s.options, s.providers) - if err != nil { - return E.Cause(err, "validate loaded credentials") - } + err := validateOCMCompositeCredentialModes(s.options, s.providers) + if err != nil { + return E.Cause(err, "validate loaded credentials") } router := chi.NewRouter() diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index 905b1f5aef..7c9242f5af 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -53,9 +53,9 @@ func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) { if len(s.options.Users) > 0 { - return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + return credentialForUser(s.userConfigMap, s.providers, username) } - provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + provider := s.providers[s.options.Credentials[0].Tag] if provider == nil { return nil, E.New("no credential available") } @@ -117,14 +117,14 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if len(s.options.Users) > 0 { userConfig = s.userConfigMap[username] var err error - provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + provider, err = credentialForUser(s.userConfigMap, s.providers, username) if err != nil { s.logger.ErrorContext(ctx, "resolve credential: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) return } } else { - provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + provider = s.providers[s.options.Credentials[0].Tag] } if provider == nil { writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index e32d8244ed..327d3a2da0 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -15,42 +15,43 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { return } - if len(s.options.Users) == 0 { - writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") - return - } - - if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "API key authentication is not supported; use Authorization: Bearer with an OCM user token") - return - } + var provider credentialProvider + var userConfig *option.OCMUser + if len(s.options.Users) > 0 { + if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "API key authentication is not supported; use Authorization: Bearer with an OCM user token") + return + } - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") - return - } - clientToken := strings.TrimPrefix(authHeader, "Bearer ") - if clientToken == authHeader { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") - return - } - username, ok := s.userManager.Authenticate(clientToken) - if !ok { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") - return - } + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + username, ok := s.userManager.Authenticate(clientToken) + if !ok { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } - userConfig := s.userConfigMap[username] - if userConfig == nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") - return + userConfig = s.userConfigMap[username] + var err error + provider, err = credentialForUser(s.userConfigMap, s.providers, username) + if err != nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + } else { + provider = s.providers[s.options.Credentials[0].Tag] } - - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - if err != nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + if provider == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") return } @@ -72,10 +73,10 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user if !credential.isAvailable() { continue } - if userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { + if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { continue } - if !userConfig.AllowExternalUsage && credential.isExternal() { + if userConfig != nil && !userConfig.AllowExternalUsage && credential.isExternal() { continue } weight := credential.planWeight() @@ -100,7 +101,7 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user } func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) { - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) + provider, err := credentialForUser(s.userConfigMap, s.providers, userConfig.Name) if err != nil { return } From 656b09d1be1a8ac3705c2293a21f312d0ea12f23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 15 Mar 2026 13:03:31 +0800 Subject: [PATCH 50/96] ccm,ocm: never treat external usage endpoint failures as over-limit --- service/ccm/credential_external.go | 47 +++++------------------------- service/ocm/credential_external.go | 47 +++++------------------------- 2 files changed, 16 insertions(+), 78 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index a9876eaf50..9a3ca1e16c 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -549,10 +549,8 @@ func (c *externalCredential) pollUsage(ctx context.Context) { response, err := c.doPollUsageRequest(ctx) if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": ", err) - } - c.incrementPollFailures() + c.logger.Debug("poll usage for ", c.tag, ": ", err) + c.clearPollFailures() return } defer response.Body.Close() @@ -560,16 +558,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - // 404 means the remote does not have a status endpoint yet; - // usage will be updated passively from response headers. - if response.StatusCode == http.StatusNotFound { - c.stateAccess.Lock() - c.state.consecutivePollFailures = 0 - c.checkTransitionLocked() - c.stateAccess.Unlock() - } else { - c.incrementPollFailures() - } + c.clearPollFailures() return } @@ -581,7 +570,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { err = json.NewDecoder(response.Body).Decode(&statusResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.incrementPollFailures() + c.clearPollFailures() return } @@ -625,34 +614,14 @@ func (c *externalCredential) markUsagePollAttempted() { } func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration { - c.stateAccess.RLock() - failures := c.state.consecutivePollFailures - c.stateAccess.RUnlock() - if failures <= 0 { - return baseInterval - } - backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) - if backoff > httpRetryMaxBackoff { - return httpRetryMaxBackoff - } - return backoff -} - -func (c *externalCredential) isPollBackoffAtCap() bool { - c.stateAccess.RLock() - defer c.stateAccess.RUnlock() - failures := c.state.consecutivePollFailures - return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff + return baseInterval } -func (c *externalCredential) incrementPollFailures() { +func (c *externalCredential) clearPollFailures() { c.stateAccess.Lock() - c.state.consecutivePollFailures++ - shouldInterrupt := c.checkTransitionLocked() + c.state.consecutivePollFailures = 0 + c.checkTransitionLocked() c.stateAccess.Unlock() - if shouldInterrupt { - c.interruptConnections() - } } func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index bcfe6d234b..d924ae4ee1 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -588,10 +588,8 @@ func (c *externalCredential) pollUsage(ctx context.Context) { response, err := c.doPollUsageRequest(ctx) if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": ", err) - } - c.incrementPollFailures() + c.logger.Debug("poll usage for ", c.tag, ": ", err) + c.clearPollFailures() return } defer response.Body.Close() @@ -599,16 +597,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - // 404 means the remote does not have a status endpoint yet; - // usage will be updated passively from response headers. - if response.StatusCode == http.StatusNotFound { - c.stateAccess.Lock() - c.state.consecutivePollFailures = 0 - c.checkTransitionLocked() - c.stateAccess.Unlock() - } else { - c.incrementPollFailures() - } + c.clearPollFailures() return } @@ -620,7 +609,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { err = json.NewDecoder(response.Body).Decode(&statusResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.incrementPollFailures() + c.clearPollFailures() return } @@ -664,34 +653,14 @@ func (c *externalCredential) markUsagePollAttempted() { } func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration { - c.stateAccess.RLock() - failures := c.state.consecutivePollFailures - c.stateAccess.RUnlock() - if failures <= 0 { - return baseInterval - } - backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) - if backoff > httpRetryMaxBackoff { - return httpRetryMaxBackoff - } - return backoff -} - -func (c *externalCredential) isPollBackoffAtCap() bool { - c.stateAccess.RLock() - defer c.stateAccess.RUnlock() - failures := c.state.consecutivePollFailures - return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff + return baseInterval } -func (c *externalCredential) incrementPollFailures() { +func (c *externalCredential) clearPollFailures() { c.stateAccess.Lock() - c.state.consecutivePollFailures++ - shouldInterrupt := c.checkTransitionLocked() + c.state.consecutivePollFailures = 0 + c.checkTransitionLocked() c.stateAccess.Unlock() - if shouldInterrupt { - c.interruptConnections() - } } func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { From 763e0af010474ee01e0690f3b13580c149ae0536 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 15 Mar 2026 13:42:53 +0800 Subject: [PATCH 51/96] docs: complete ccm/ocm documentation for 1.14.0 features --- docs/configuration/service/ccm.md | 60 +++++++++++++++++++++++++--- docs/configuration/service/ccm.zh.md | 60 +++++++++++++++++++++++++--- docs/configuration/service/ocm.md | 60 +++++++++++++++++++++++++--- docs/configuration/service/ocm.zh.md | 60 +++++++++++++++++++++++++--- 4 files changed, 216 insertions(+), 24 deletions(-) diff --git a/docs/configuration/service/ccm.md b/docs/configuration/service/ccm.md index 691782968a..9aa7d95bb0 100644 --- a/docs/configuration/service/ccm.md +++ b/docs/configuration/service/ccm.md @@ -13,7 +13,10 @@ It handles OAuth authentication with Claude's API on your local machine while al !!! quote "Changes in sing-box 1.14.0" :material-plus: [credentials](#credentials) - :material-alert: [users](#users) + :material-alert: [credential_path](#credential_path) + :material-alert: [usages_path](#usages_path) + :material-alert: [users](#users) + :material-alert: [detour](#detour) ### Structure @@ -51,6 +54,8 @@ On macOS, credentials are read from the system keychain first, then fall back to Refreshed tokens are automatically written back to the same location. +!!! question "Since sing-box 1.14.0" + When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid. On macOS without an explicit `credential_path`, keychain changes are not watched. Automatic reload only applies to the credential file path. @@ -65,7 +70,7 @@ List of credential configurations for multi-credential mode. When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag. -Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a required `tag` field. +Each credential has a `type` field (`default`, `external`, `balancer`, or `fallback`) and a required `tag` field. ##### Default Credential @@ -76,7 +81,9 @@ Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a "usages_path": "/path/to/usages.json", "detour": "", "reserve_5h": 20, - "reserve_weekly": 20 + "reserve_weekly": 20, + "limit_5h": 0, + "limit_weekly": 0 } ``` @@ -85,8 +92,10 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de - `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`. - `usages_path`: Optional usage tracking file for this credential. - `detour`: Outbound tag for connecting to the Claude API with this credential. -- `reserve_5h`: Reserve threshold (1-99) for 5-hour window. Credential pauses at (100-N)% utilization. -- `reserve_weekly`: Reserve threshold (1-99) for weekly window. Credential pauses at (100-N)% utilization. +- `reserve_5h`: Reserve threshold (1-99) for 5-hour window. Credential pauses at (100-N)% utilization. Conflict with `limit_5h`. +- `reserve_weekly`: Reserve threshold (1-99) for weekly window. Credential pauses at (100-N)% utilization. Conflict with `limit_weekly`. +- `limit_5h`: Explicit utilization cap (1-99) for 5-hour window. Credential pauses when utilization reaches this value. Conflict with `reserve_5h`. +- `limit_weekly`: Explicit utilization cap (1-99) for weekly window. Credential pauses when utilization reaches this value. Conflict with `reserve_weekly`. ##### Balancer Credential @@ -122,6 +131,34 @@ Uses credentials in order. Falls through to the next when the current one is exh - `credentials`: ==Required== Ordered list of default credential tags. - `poll_interval`: How often to poll upstream usage API. Default `60s`. +##### External Credential + +```json +{ + "tag": "remote", + "type": "external", + "url": "", + "server": "", + "server_port": 0, + "token": "", + "reverse": false, + "detour": "", + "usages_path": "", + "poll_interval": "30m" +} +``` + +Proxies requests through a remote CCM instance instead of using a local OAuth credential. + +- `url`: URL of the remote CCM instance. Omit in reverse receiver mode. +- `server`: Override server address for dialing, separate from URL hostname. +- `server_port`: Override server port for dialing. +- `token`: ==Required== Authentication token for the remote instance. +- `reverse`: Enable reverse proxy mode. When `url` is set with `reverse`, acts as a connector that dials out to the remote instance. When `url` is empty, acts as a receiver waiting for inbound reverse connections. +- `detour`: Outbound tag for connecting to the remote instance. +- `usages_path`: Optional usage tracking file. +- `poll_interval`: How often to poll the remote status endpoint. Default `30m`. + #### usages_path Path to the file for storing aggregated API usage statistics. @@ -137,6 +174,8 @@ Statistics are organized by model, context window (200k standard vs 1M premium), The statistics file is automatically saved every minute and upon service shutdown. +!!! question "Since sing-box 1.14.0" + Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials. #### users @@ -151,7 +190,9 @@ Object format: { "name": "", "token": "", - "credential": "" + "credential": "", + "external_credential": "", + "allow_external_usage": false } ``` @@ -159,7 +200,12 @@ Object fields: - `name`: Username identifier for tracking purposes. - `token`: Bearer token for authentication. Claude Code authenticates by setting the `ANTHROPIC_AUTH_TOKEN` environment variable to their token value. + +!!! question "Since sing-box 1.14.0" + - `credential`: Credential tag to use for this user. ==Required== when `credentials` is set. +- `external_credential`: Tag of an external credential dedicated to serving this user. Response rate-limit headers are rewritten with aggregated utilization from all other credentials available to this user. +- `allow_external_usage`: Allow this user to use external credentials. `false` by default. #### headers @@ -171,6 +217,8 @@ These headers will override any existing headers with the same name. Outbound tag for connecting to the Claude API. +!!! question "Since sing-box 1.14.0" + Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials. #### tls diff --git a/docs/configuration/service/ccm.zh.md b/docs/configuration/service/ccm.zh.md index f555fc4d2d..9a0e0380c6 100644 --- a/docs/configuration/service/ccm.zh.md +++ b/docs/configuration/service/ccm.zh.md @@ -13,7 +13,10 @@ CCM(Claude Code 多路复用器)服务是一个多路复用服务,允许 !!! quote "sing-box 1.14.0 中的更改" :material-plus: [credentials](#credentials) - :material-alert: [users](#users) + :material-alert: [credential_path](#credential_path) + :material-alert: [usages_path](#usages_path) + :material-alert: [users](#users) + :material-alert: [detour](#detour) ### 结构 @@ -51,6 +54,8 @@ Claude Code OAuth 凭据文件的路径。 刷新的令牌会自动写回相同位置。 +!!! question "自 sing-box 1.14.0 起" + 当 `credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。 在 macOS 上如果未显式设置 `credential_path`,不会监听钥匙串变化。自动重载只作用于凭据文件路径。 @@ -65,7 +70,7 @@ Claude Code OAuth 凭据文件的路径。 设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。 -每个凭据有一个 `type` 字段(`default`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。 +每个凭据有一个 `type` 字段(`default`、`external`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。 ##### 默认凭据 @@ -76,7 +81,9 @@ Claude Code OAuth 凭据文件的路径。 "usages_path": "/path/to/usages.json", "detour": "", "reserve_5h": 20, - "reserve_weekly": 20 + "reserve_weekly": 20, + "limit_5h": 0, + "limit_weekly": 0 } ``` @@ -85,8 +92,10 @@ Claude Code OAuth 凭据文件的路径。 - `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。 - `usages_path`:此凭据的可选使用跟踪文件。 - `detour`:此凭据用于连接 Claude API 的出站标签。 -- `reserve_5h`:5 小时窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 -- `reserve_weekly`:每周窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 +- `reserve_5h`:5 小时窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_5h` 冲突。 +- `reserve_weekly`:每周窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_weekly` 冲突。 +- `limit_5h`:5 小时窗口的显式利用率上限(1-99)。凭据在利用率达到此值时暂停。与 `reserve_5h` 冲突。 +- `limit_weekly`:每周窗口的显式利用率上限(1-99)。凭据在利用率达到此值时暂停。与 `reserve_weekly` 冲突。 ##### 均衡凭据 @@ -122,6 +131,34 @@ Claude Code OAuth 凭据文件的路径。 - `credentials`:==必填== 有序的默认凭据标签列表。 - `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 +##### 外部凭据 + +```json +{ + "tag": "remote", + "type": "external", + "url": "", + "server": "", + "server_port": 0, + "token": "", + "reverse": false, + "detour": "", + "usages_path": "", + "poll_interval": "30m" +} +``` + +通过远程 CCM 实例代理请求,而非使用本地 OAuth 凭据。 + +- `url`:远程 CCM 实例的 URL。在反向接收模式下省略。 +- `server`:覆盖拨号的服务器地址,与 URL 主机名分开。 +- `server_port`:覆盖拨号的服务器端口。 +- `token`:==必填== 远程实例的身份验证令牌。 +- `reverse`:启用反向代理模式。当设置了 `url` 和 `reverse` 时,作为连接器主动拨出到远程实例。当 `url` 为空时,作为接收器等待入站反向连接。 +- `detour`:用于连接远程实例的出站标签。 +- `usages_path`:可选的使用跟踪文件。 +- `poll_interval`:轮询远程状态端点的间隔。默认 `30m`。 + #### usages_path 用于存储聚合 API 使用统计信息的文件路径。 @@ -137,6 +174,8 @@ Claude Code OAuth 凭据文件的路径。 统计文件每分钟自动保存一次,并在服务关闭时保存。 +!!! question "自 sing-box 1.14.0 起" + 与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`。 #### users @@ -151,7 +190,9 @@ Claude Code OAuth 凭据文件的路径。 { "name": "", "token": "", - "credential": "" + "credential": "", + "external_credential": "", + "allow_external_usage": false } ``` @@ -159,7 +200,12 @@ Claude Code OAuth 凭据文件的路径。 - `name`:用于跟踪的用户名标识符。 - `token`:用于身份验证的 Bearer 令牌。Claude Code 通过设置 `ANTHROPIC_AUTH_TOKEN` 环境变量为其令牌值进行身份验证。 + +!!! question "自 sing-box 1.14.0 起" + - `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。 +- `external_credential`:专用于为此用户提供服务的外部凭据标签。响应的速率限制头会被重写为来自此用户所有其他可用凭据的聚合利用率。 +- `allow_external_usage`:允许此用户使用外部凭据。默认为 `false`。 #### headers @@ -171,6 +217,8 @@ Claude Code OAuth 凭据文件的路径。 用于连接 Claude API 的出站标签。 +!!! question "自 sing-box 1.14.0 起" + 与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`。 #### tls diff --git a/docs/configuration/service/ocm.md b/docs/configuration/service/ocm.md index 4c63de0f78..63e232db52 100644 --- a/docs/configuration/service/ocm.md +++ b/docs/configuration/service/ocm.md @@ -13,7 +13,10 @@ It handles OAuth authentication with OpenAI's API on your local machine while al !!! quote "Changes in sing-box 1.14.0" :material-plus: [credentials](#credentials) - :material-alert: [users](#users) + :material-alert: [credential_path](#credential_path) + :material-alert: [usages_path](#usages_path) + :material-alert: [users](#users) + :material-alert: [detour](#detour) ### Structure @@ -49,6 +52,8 @@ If not specified, defaults to: Refreshed tokens are automatically written back to the same location. +!!! question "Since sing-box 1.14.0" + When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid. Conflict with `credentials`. @@ -61,7 +66,7 @@ List of credential configurations for multi-credential mode. When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag. -Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a required `tag` field. +Each credential has a `type` field (`default`, `external`, `balancer`, or `fallback`) and a required `tag` field. ##### Default Credential @@ -72,7 +77,9 @@ Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a "usages_path": "/path/to/usages.json", "detour": "", "reserve_5h": 20, - "reserve_weekly": 20 + "reserve_weekly": 20, + "limit_5h": 0, + "limit_weekly": 0 } ``` @@ -81,8 +88,10 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de - `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`. - `usages_path`: Optional usage tracking file for this credential. - `detour`: Outbound tag for connecting to the OpenAI API with this credential. -- `reserve_5h`: Reserve threshold (1-99) for primary rate limit window. Credential pauses at (100-N)% utilization. -- `reserve_weekly`: Reserve threshold (1-99) for secondary (weekly) rate limit window. Credential pauses at (100-N)% utilization. +- `reserve_5h`: Reserve threshold (1-99) for primary rate limit window. Credential pauses at (100-N)% utilization. Conflict with `limit_5h`. +- `reserve_weekly`: Reserve threshold (1-99) for secondary (weekly) rate limit window. Credential pauses at (100-N)% utilization. Conflict with `limit_weekly`. +- `limit_5h`: Explicit utilization cap (1-99) for primary rate limit window. Credential pauses when utilization reaches this value. Conflict with `reserve_5h`. +- `limit_weekly`: Explicit utilization cap (1-99) for secondary (weekly) rate limit window. Credential pauses when utilization reaches this value. Conflict with `reserve_weekly`. ##### Balancer Credential @@ -118,6 +127,34 @@ Uses credentials in order. Falls through to the next when the current one is exh - `credentials`: ==Required== Ordered list of default credential tags. - `poll_interval`: How often to poll upstream usage API. Default `60s`. +##### External Credential + +```json +{ + "tag": "remote", + "type": "external", + "url": "", + "server": "", + "server_port": 0, + "token": "", + "reverse": false, + "detour": "", + "usages_path": "", + "poll_interval": "30m" +} +``` + +Proxies requests through a remote OCM instance instead of using a local OAuth credential. + +- `url`: URL of the remote OCM instance. Omit in reverse receiver mode. +- `server`: Override server address for dialing, separate from URL hostname. +- `server_port`: Override server port for dialing. +- `token`: ==Required== Authentication token for the remote instance. +- `reverse`: Enable reverse proxy mode. When `url` is set with `reverse`, acts as a connector that dials out to the remote instance. When `url` is empty, acts as a receiver waiting for inbound reverse connections. +- `detour`: Outbound tag for connecting to the remote instance. +- `usages_path`: Optional usage tracking file. +- `poll_interval`: How often to poll the remote status endpoint. Default `30m`. + #### usages_path Path to the file for storing aggregated API usage statistics. @@ -133,6 +170,8 @@ Statistics are organized by model and optionally by user when authentication is The statistics file is automatically saved every minute and upon service shutdown. +!!! question "Since sing-box 1.14.0" + Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials. #### users @@ -147,7 +186,9 @@ Object format: { "name": "", "token": "", - "credential": "" + "credential": "", + "external_credential": "", + "allow_external_usage": false } ``` @@ -155,7 +196,12 @@ Object fields: - `name`: Username identifier for tracking purposes. - `token`: Bearer token for authentication. Clients authenticate by setting the `Authorization: Bearer ` header. + +!!! question "Since sing-box 1.14.0" + - `credential`: Credential tag to use for this user. ==Required== when `credentials` is set. +- `external_credential`: Tag of an external credential dedicated to serving this user. Response rate-limit headers are rewritten with aggregated utilization from all other credentials available to this user. +- `allow_external_usage`: Allow this user to use external credentials. `false` by default. #### headers @@ -167,6 +213,8 @@ These headers will override any existing headers with the same name. Outbound tag for connecting to the OpenAI API. +!!! question "Since sing-box 1.14.0" + Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials. #### tls diff --git a/docs/configuration/service/ocm.zh.md b/docs/configuration/service/ocm.zh.md index 81b222ef9f..9839acb803 100644 --- a/docs/configuration/service/ocm.zh.md +++ b/docs/configuration/service/ocm.zh.md @@ -13,7 +13,10 @@ OCM(OpenAI Codex 多路复用器)服务是一个多路复用服务,允许 !!! quote "sing-box 1.14.0 中的更改" :material-plus: [credentials](#credentials) - :material-alert: [users](#users) + :material-alert: [credential_path](#credential_path) + :material-alert: [usages_path](#usages_path) + :material-alert: [users](#users) + :material-alert: [detour](#detour) ### 结构 @@ -49,6 +52,8 @@ OpenAI OAuth 凭据文件的路径。 刷新的令牌会自动写回相同位置。 +!!! question "自 sing-box 1.14.0 起" + 当 `credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。 与 `credentials` 冲突。 @@ -61,7 +66,7 @@ OpenAI OAuth 凭据文件的路径。 设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。 -每个凭据有一个 `type` 字段(`default`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。 +每个凭据有一个 `type` 字段(`default`、`external`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。 ##### 默认凭据 @@ -72,7 +77,9 @@ OpenAI OAuth 凭据文件的路径。 "usages_path": "/path/to/usages.json", "detour": "", "reserve_5h": 20, - "reserve_weekly": 20 + "reserve_weekly": 20, + "limit_5h": 0, + "limit_weekly": 0 } ``` @@ -81,8 +88,10 @@ OpenAI OAuth 凭据文件的路径。 - `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。 - `usages_path`:此凭据的可选使用跟踪文件。 - `detour`:此凭据用于连接 OpenAI API 的出站标签。 -- `reserve_5h`:主要速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 -- `reserve_weekly`:次要(每周)速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 +- `reserve_5h`:主要速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_5h` 冲突。 +- `reserve_weekly`:次要(每周)速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_weekly` 冲突。 +- `limit_5h`:主要速率限制窗口的显式利用率上限(1-99)。凭据在利用率达到此值时暂停。与 `reserve_5h` 冲突。 +- `limit_weekly`:次要(每周)速率限制窗口的显式利用率上限(1-99)。凭据在利用率达到此值时暂停。与 `reserve_weekly` 冲突。 ##### 均衡凭据 @@ -118,6 +127,34 @@ OpenAI OAuth 凭据文件的路径。 - `credentials`:==必填== 有序的默认凭据标签列表。 - `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 +##### 外部凭据 + +```json +{ + "tag": "remote", + "type": "external", + "url": "", + "server": "", + "server_port": 0, + "token": "", + "reverse": false, + "detour": "", + "usages_path": "", + "poll_interval": "30m" +} +``` + +通过远程 OCM 实例代理请求,而非使用本地 OAuth 凭据。 + +- `url`:远程 OCM 实例的 URL。在反向接收模式下省略。 +- `server`:覆盖拨号的服务器地址,与 URL 主机名分开。 +- `server_port`:覆盖拨号的服务器端口。 +- `token`:==必填== 远程实例的身份验证令牌。 +- `reverse`:启用反向代理模式。当设置了 `url` 和 `reverse` 时,作为连接器主动拨出到远程实例。当 `url` 为空时,作为接收器等待入站反向连接。 +- `detour`:用于连接远程实例的出站标签。 +- `usages_path`:可选的使用跟踪文件。 +- `poll_interval`:轮询远程状态端点的间隔。默认 `30m`。 + #### usages_path 用于存储聚合 API 使用统计信息的文件路径。 @@ -133,6 +170,8 @@ OpenAI OAuth 凭据文件的路径。 统计文件每分钟自动保存一次,并在服务关闭时保存。 +!!! question "自 sing-box 1.14.0 起" + 与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`。 #### users @@ -147,7 +186,9 @@ OpenAI OAuth 凭据文件的路径。 { "name": "", "token": "", - "credential": "" + "credential": "", + "external_credential": "", + "allow_external_usage": false } ``` @@ -155,7 +196,12 @@ OpenAI OAuth 凭据文件的路径。 - `name`:用于跟踪的用户名标识符。 - `token`:用于身份验证的 Bearer 令牌。客户端通过设置 `Authorization: Bearer ` 头进行身份验证。 + +!!! question "自 sing-box 1.14.0 起" + - `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。 +- `external_credential`:专用于为此用户提供服务的外部凭据标签。响应的速率限制头会被重写为来自此用户所有其他可用凭据的聚合利用率。 +- `allow_external_usage`:允许此用户使用外部凭据。默认为 `false`。 #### headers @@ -167,6 +213,8 @@ OpenAI OAuth 凭据文件的路径。 用于连接 OpenAI API 的出站标签。 +!!! question "自 sing-box 1.14.0 起" + 与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`。 #### tls From 9e3ec30d725285f53837c97179a3a77486210b3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 15 Mar 2026 20:41:47 +0800 Subject: [PATCH 52/96] docs: fix ccm and ocm credential docs --- docs/configuration/service/ccm.md | 21 +++++++++++---------- docs/configuration/service/ccm.zh.md | 21 +++++++++++---------- docs/configuration/service/ocm.md | 21 +++++++++++---------- docs/configuration/service/ocm.zh.md | 21 +++++++++++---------- 4 files changed, 44 insertions(+), 40 deletions(-) diff --git a/docs/configuration/service/ccm.md b/docs/configuration/service/ccm.md index 9aa7d95bb0..5823901395 100644 --- a/docs/configuration/service/ccm.md +++ b/docs/configuration/service/ccm.md @@ -70,7 +70,7 @@ List of credential configurations for multi-credential mode. When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag. -Each credential has a `type` field (`default`, `external`, `balancer`, or `fallback`) and a required `tag` field. +Each credential has a `type` field (`default`, `external`, or `balancer`) and a required `tag` field. ##### Default Credential @@ -94,8 +94,8 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de - `detour`: Outbound tag for connecting to the Claude API with this credential. - `reserve_5h`: Reserve threshold (1-99) for 5-hour window. Credential pauses at (100-N)% utilization. Conflict with `limit_5h`. - `reserve_weekly`: Reserve threshold (1-99) for weekly window. Credential pauses at (100-N)% utilization. Conflict with `limit_weekly`. -- `limit_5h`: Explicit utilization cap (1-99) for 5-hour window. Credential pauses when utilization reaches this value. Conflict with `reserve_5h`. -- `limit_weekly`: Explicit utilization cap (1-99) for weekly window. Credential pauses when utilization reaches this value. Conflict with `reserve_weekly`. +- `limit_5h`: Explicit utilization cap (0-100) for 5-hour window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_5h`. +- `limit_weekly`: Explicit utilization cap (0-100) for weekly window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_weekly`. ##### Balancer Credential @@ -111,22 +111,23 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit. -- `strategy`: Selection strategy. One of `least_used` `round_robin` `random`. `least_used` will be used by default. +- `strategy`: Selection strategy. One of `least_used` `round_robin` `random` `fallback`. `least_used` will be used by default. - `credentials`: ==Required== List of default credential tags. - `poll_interval`: How often to poll upstream usage API. Default `60s`. -##### Fallback Credential +##### Fallback Strategy ```json { "tag": "backup", - "type": "fallback", + "type": "balancer", + "strategy": "fallback", "credentials": ["a", "b"], "poll_interval": "30s" } ``` -Uses credentials in order. Falls through to the next when the current one is exhausted. +A balancer with `strategy: "fallback"` uses credentials in order. It falls through to the next when the current one is exhausted. - `credentials`: ==Required== Ordered list of default credential tags. - `poll_interval`: How often to poll upstream usage API. Default `60s`. @@ -150,11 +151,11 @@ Uses credentials in order. Falls through to the next when the current one is exh Proxies requests through a remote CCM instance instead of using a local OAuth credential. -- `url`: URL of the remote CCM instance. Omit in reverse receiver mode. +- `url`: URL of the remote CCM instance. Omit to create a receiver that only waits for inbound reverse connections. - `server`: Override server address for dialing, separate from URL hostname. - `server_port`: Override server port for dialing. - `token`: ==Required== Authentication token for the remote instance. -- `reverse`: Enable reverse proxy mode. When `url` is set with `reverse`, acts as a connector that dials out to the remote instance. When `url` is empty, acts as a receiver waiting for inbound reverse connections. +- `reverse`: Enable connector mode. Requires `url`. A connector dials out to `/ccm/v1/reverse` on the remote instance and cannot serve local requests directly. When `url` is set without `reverse`, the credential proxies requests through the remote instance normally and prefers an established reverse connection when one is available. - `detour`: Outbound tag for connecting to the remote instance. - `usages_path`: Optional usage tracking file. - `poll_interval`: How often to poll the remote status endpoint. Default `30m`. @@ -204,7 +205,7 @@ Object fields: !!! question "Since sing-box 1.14.0" - `credential`: Credential tag to use for this user. ==Required== when `credentials` is set. -- `external_credential`: Tag of an external credential dedicated to serving this user. Response rate-limit headers are rewritten with aggregated utilization from all other credentials available to this user. +- `external_credential`: Tag of an external credential used only to rewrite response rate-limit headers with aggregated utilization from this user's other available credentials. It does not control request routing; request selection still comes from `credential` and `allow_external_usage`. - `allow_external_usage`: Allow this user to use external credentials. `false` by default. #### headers diff --git a/docs/configuration/service/ccm.zh.md b/docs/configuration/service/ccm.zh.md index 9a0e0380c6..586cb5bb15 100644 --- a/docs/configuration/service/ccm.zh.md +++ b/docs/configuration/service/ccm.zh.md @@ -70,7 +70,7 @@ Claude Code OAuth 凭据文件的路径。 设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。 -每个凭据有一个 `type` 字段(`default`、`external`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。 +每个凭据有一个 `type` 字段(`default`、`external` 或 `balancer`)和一个必填的 `tag` 字段。 ##### 默认凭据 @@ -94,8 +94,8 @@ Claude Code OAuth 凭据文件的路径。 - `detour`:此凭据用于连接 Claude API 的出站标签。 - `reserve_5h`:5 小时窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_5h` 冲突。 - `reserve_weekly`:每周窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_weekly` 冲突。 -- `limit_5h`:5 小时窗口的显式利用率上限(1-99)。凭据在利用率达到此值时暂停。与 `reserve_5h` 冲突。 -- `limit_weekly`:每周窗口的显式利用率上限(1-99)。凭据在利用率达到此值时暂停。与 `reserve_weekly` 冲突。 +- `limit_5h`:5 小时窗口的显式利用率上限(0-100)。`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_5h` 冲突。 +- `limit_weekly`:每周窗口的显式利用率上限(0-100)。`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_weekly` 冲突。 ##### 均衡凭据 @@ -111,22 +111,23 @@ Claude Code OAuth 凭据文件的路径。 根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。 -- `strategy`:选择策略。可选值:`least_used` `round_robin` `random`。默认使用 `least_used`。 +- `strategy`:选择策略。可选值:`least_used` `round_robin` `random` `fallback`。默认使用 `least_used`。 - `credentials`:==必填== 默认凭据标签列表。 - `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 -##### 回退凭据 +##### 回退策略 ```json { "tag": "backup", - "type": "fallback", + "type": "balancer", + "strategy": "fallback", "credentials": ["a", "b"], "poll_interval": "30s" } ``` -按顺序使用凭据。当前凭据耗尽后切换到下一个。 +将 `strategy` 设为 `fallback` 的均衡凭据会按顺序使用凭据。当前凭据耗尽后切换到下一个。 - `credentials`:==必填== 有序的默认凭据标签列表。 - `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 @@ -150,11 +151,11 @@ Claude Code OAuth 凭据文件的路径。 通过远程 CCM 实例代理请求,而非使用本地 OAuth 凭据。 -- `url`:远程 CCM 实例的 URL。在反向接收模式下省略。 +- `url`:远程 CCM 实例的 URL。省略时,此凭据作为仅等待入站反向连接的接收器。 - `server`:覆盖拨号的服务器地址,与 URL 主机名分开。 - `server_port`:覆盖拨号的服务器端口。 - `token`:==必填== 远程实例的身份验证令牌。 -- `reverse`:启用反向代理模式。当设置了 `url` 和 `reverse` 时,作为连接器主动拨出到远程实例。当 `url` 为空时,作为接收器等待入站反向连接。 +- `reverse`:启用连接器模式。要求设置 `url`。启用后,此凭据会主动拨出到远程实例的 `/ccm/v1/reverse`,且不能直接为本地请求提供服务。当设置了 `url` 但未启用 `reverse` 时,此凭据会正常通过远程实例转发请求,并在反向连接建立后优先使用该反向连接。 - `detour`:用于连接远程实例的出站标签。 - `usages_path`:可选的使用跟踪文件。 - `poll_interval`:轮询远程状态端点的间隔。默认 `30m`。 @@ -204,7 +205,7 @@ Claude Code OAuth 凭据文件的路径。 !!! question "自 sing-box 1.14.0 起" - `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。 -- `external_credential`:专用于为此用户提供服务的外部凭据标签。响应的速率限制头会被重写为来自此用户所有其他可用凭据的聚合利用率。 +- `external_credential`:仅用于用此用户其他可用凭据的聚合利用率重写响应速率限制头的外部凭据标签。它不参与请求路由;请求选择仍由 `credential` 和 `allow_external_usage` 决定。 - `allow_external_usage`:允许此用户使用外部凭据。默认为 `false`。 #### headers diff --git a/docs/configuration/service/ocm.md b/docs/configuration/service/ocm.md index 63e232db52..43027d9dbb 100644 --- a/docs/configuration/service/ocm.md +++ b/docs/configuration/service/ocm.md @@ -66,7 +66,7 @@ List of credential configurations for multi-credential mode. When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag. -Each credential has a `type` field (`default`, `external`, `balancer`, or `fallback`) and a required `tag` field. +Each credential has a `type` field (`default`, `external`, or `balancer`) and a required `tag` field. ##### Default Credential @@ -90,8 +90,8 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de - `detour`: Outbound tag for connecting to the OpenAI API with this credential. - `reserve_5h`: Reserve threshold (1-99) for primary rate limit window. Credential pauses at (100-N)% utilization. Conflict with `limit_5h`. - `reserve_weekly`: Reserve threshold (1-99) for secondary (weekly) rate limit window. Credential pauses at (100-N)% utilization. Conflict with `limit_weekly`. -- `limit_5h`: Explicit utilization cap (1-99) for primary rate limit window. Credential pauses when utilization reaches this value. Conflict with `reserve_5h`. -- `limit_weekly`: Explicit utilization cap (1-99) for secondary (weekly) rate limit window. Credential pauses when utilization reaches this value. Conflict with `reserve_weekly`. +- `limit_5h`: Explicit utilization cap (0-100) for primary rate limit window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_5h`. +- `limit_weekly`: Explicit utilization cap (0-100) for secondary (weekly) rate limit window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_weekly`. ##### Balancer Credential @@ -107,22 +107,23 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit. -- `strategy`: Selection strategy. One of `least_used` `round_robin` `random`. `least_used` will be used by default. +- `strategy`: Selection strategy. One of `least_used` `round_robin` `random` `fallback`. `least_used` will be used by default. - `credentials`: ==Required== List of default credential tags. - `poll_interval`: How often to poll upstream usage API. Default `60s`. -##### Fallback Credential +##### Fallback Strategy ```json { "tag": "backup", - "type": "fallback", + "type": "balancer", + "strategy": "fallback", "credentials": ["a", "b"], "poll_interval": "30s" } ``` -Uses credentials in order. Falls through to the next when the current one is exhausted. +A balancer with `strategy: "fallback"` uses credentials in order. It falls through to the next when the current one is exhausted. - `credentials`: ==Required== Ordered list of default credential tags. - `poll_interval`: How often to poll upstream usage API. Default `60s`. @@ -146,11 +147,11 @@ Uses credentials in order. Falls through to the next when the current one is exh Proxies requests through a remote OCM instance instead of using a local OAuth credential. -- `url`: URL of the remote OCM instance. Omit in reverse receiver mode. +- `url`: URL of the remote OCM instance. Omit to create a receiver that only waits for inbound reverse connections. - `server`: Override server address for dialing, separate from URL hostname. - `server_port`: Override server port for dialing. - `token`: ==Required== Authentication token for the remote instance. -- `reverse`: Enable reverse proxy mode. When `url` is set with `reverse`, acts as a connector that dials out to the remote instance. When `url` is empty, acts as a receiver waiting for inbound reverse connections. +- `reverse`: Enable connector mode. Requires `url`. A connector dials out to `/ocm/v1/reverse` on the remote instance and cannot serve local requests directly. When `url` is set without `reverse`, the credential proxies requests through the remote instance normally and prefers an established reverse connection when one is available. - `detour`: Outbound tag for connecting to the remote instance. - `usages_path`: Optional usage tracking file. - `poll_interval`: How often to poll the remote status endpoint. Default `30m`. @@ -200,7 +201,7 @@ Object fields: !!! question "Since sing-box 1.14.0" - `credential`: Credential tag to use for this user. ==Required== when `credentials` is set. -- `external_credential`: Tag of an external credential dedicated to serving this user. Response rate-limit headers are rewritten with aggregated utilization from all other credentials available to this user. +- `external_credential`: Tag of an external credential used only to rewrite response rate-limit headers with aggregated utilization from this user's other available credentials. It does not control request routing; request selection still comes from `credential` and `allow_external_usage`. - `allow_external_usage`: Allow this user to use external credentials. `false` by default. #### headers diff --git a/docs/configuration/service/ocm.zh.md b/docs/configuration/service/ocm.zh.md index 9839acb803..2d06206f07 100644 --- a/docs/configuration/service/ocm.zh.md +++ b/docs/configuration/service/ocm.zh.md @@ -66,7 +66,7 @@ OpenAI OAuth 凭据文件的路径。 设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。 -每个凭据有一个 `type` 字段(`default`、`external`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。 +每个凭据有一个 `type` 字段(`default`、`external` 或 `balancer`)和一个必填的 `tag` 字段。 ##### 默认凭据 @@ -90,8 +90,8 @@ OpenAI OAuth 凭据文件的路径。 - `detour`:此凭据用于连接 OpenAI API 的出站标签。 - `reserve_5h`:主要速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_5h` 冲突。 - `reserve_weekly`:次要(每周)速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_weekly` 冲突。 -- `limit_5h`:主要速率限制窗口的显式利用率上限(1-99)。凭据在利用率达到此值时暂停。与 `reserve_5h` 冲突。 -- `limit_weekly`:次要(每周)速率限制窗口的显式利用率上限(1-99)。凭据在利用率达到此值时暂停。与 `reserve_weekly` 冲突。 +- `limit_5h`:主要速率限制窗口的显式利用率上限(0-100)。`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_5h` 冲突。 +- `limit_weekly`:次要(每周)速率限制窗口的显式利用率上限(0-100)。`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_weekly` 冲突。 ##### 均衡凭据 @@ -107,22 +107,23 @@ OpenAI OAuth 凭据文件的路径。 根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。 -- `strategy`:选择策略。可选值:`least_used` `round_robin` `random`。默认使用 `least_used`。 +- `strategy`:选择策略。可选值:`least_used` `round_robin` `random` `fallback`。默认使用 `least_used`。 - `credentials`:==必填== 默认凭据标签列表。 - `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 -##### 回退凭据 +##### 回退策略 ```json { "tag": "backup", - "type": "fallback", + "type": "balancer", + "strategy": "fallback", "credentials": ["a", "b"], "poll_interval": "30s" } ``` -按顺序使用凭据。当前凭据耗尽后切换到下一个。 +将 `strategy` 设为 `fallback` 的均衡凭据会按顺序使用凭据。当前凭据耗尽后切换到下一个。 - `credentials`:==必填== 有序的默认凭据标签列表。 - `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 @@ -146,11 +147,11 @@ OpenAI OAuth 凭据文件的路径。 通过远程 OCM 实例代理请求,而非使用本地 OAuth 凭据。 -- `url`:远程 OCM 实例的 URL。在反向接收模式下省略。 +- `url`:远程 OCM 实例的 URL。省略时,此凭据作为仅等待入站反向连接的接收器。 - `server`:覆盖拨号的服务器地址,与 URL 主机名分开。 - `server_port`:覆盖拨号的服务器端口。 - `token`:==必填== 远程实例的身份验证令牌。 -- `reverse`:启用反向代理模式。当设置了 `url` 和 `reverse` 时,作为连接器主动拨出到远程实例。当 `url` 为空时,作为接收器等待入站反向连接。 +- `reverse`:启用连接器模式。要求设置 `url`。启用后,此凭据会主动拨出到远程实例的 `/ocm/v1/reverse`,且不能直接为本地请求提供服务。当设置了 `url` 但未启用 `reverse` 时,此凭据会正常通过远程实例转发请求,并在反向连接建立后优先使用该反向连接。 - `detour`:用于连接远程实例的出站标签。 - `usages_path`:可选的使用跟踪文件。 - `poll_interval`:轮询远程状态端点的间隔。默认 `30m`。 @@ -200,7 +201,7 @@ OpenAI OAuth 凭据文件的路径。 !!! question "自 sing-box 1.14.0 起" - `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。 -- `external_credential`:专用于为此用户提供服务的外部凭据标签。响应的速率限制头会被重写为来自此用户所有其他可用凭据的聚合利用率。 +- `external_credential`:仅用于用此用户其他可用凭据的聚合利用率重写响应速率限制头的外部凭据标签。它不参与请求路由;请求选择仍由 `credential` 和 `allow_external_usage` 决定。 - `allow_external_usage`:允许此用户使用外部凭据。默认为 `false`。 #### headers From 14ade769563aba6008a8689cf690a48385ef5b28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 15 Mar 2026 19:09:29 +0800 Subject: [PATCH 53/96] ccm,ocm: remove dead code, fix timer leaks, eliminate redundant lookups - Remove unused onBecameUnusable field from CCM credential structs (OCM wires it for WebSocket interruption; CCM has no equivalent) - Replace time.After with time.NewTimer in doHTTPWithRetry and connectorLoop to avoid timer leaks on context cancellation - Pass already-resolved provider to rewriteResponseHeadersForExternalUser instead of re-resolving via credentialForUser - Hoist reverseYamuxConfig to package-level var (immutable, no need to allocate on every call) --- service/ccm/credential.go | 4 +++- service/ccm/credential_default.go | 6 +----- service/ccm/credential_external.go | 6 +----- service/ccm/reverse.go | 12 +++++++----- service/ccm/service_handler.go | 2 +- service/ccm/service_status.go | 7 +------ service/ocm/credential.go | 4 +++- service/ocm/reverse.go | 12 +++++++----- service/ocm/service_handler.go | 2 +- service/ocm/service_status.go | 7 +------ service/ocm/service_websocket.go | 2 +- 11 files changed, 27 insertions(+), 37 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index b732617177..9e16141667 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -26,10 +26,12 @@ func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func for attempt := range httpRetryMaxAttempts { if attempt > 0 { delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) + timer := time.NewTimer(delay) select { case <-ctx.Done(): + timer.Stop() return nil, lastError - case <-time.After(delay): + case <-timer.C: } } request, err := buildRequest() diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 0bbf188749..021df5d272 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -44,8 +44,7 @@ type defaultCredential struct { watcherRetryAt time.Time // Connection interruption - onBecameUnusable func() - interrupted bool + interrupted bool requestContext context.Context cancelRequests context.CancelFunc requestAccess sync.Mutex @@ -353,9 +352,6 @@ func (c *defaultCredential) interruptConnections() { c.cancelRequests() c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) c.requestAccess.Unlock() - if c.onBecameUnusable != nil { - c.onBecameUnusable() - } } func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 9a3ca1e16c..21e695b241 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -40,8 +40,7 @@ type externalCredential struct { usageTracker *AggregatedUsage logger log.ContextLogger - onBecameUnusable func() - interrupted bool + interrupted bool requestContext context.Context cancelRequests context.CancelFunc requestAccess sync.Mutex @@ -494,9 +493,6 @@ func (c *externalCredential) interruptConnections() { c.cancelRequests() c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) c.requestAccess.Unlock() - if c.onBecameUnusable != nil { - c.onBecameUnusable() - } } func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Response, error) { diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index b6b4c88a09..0da0c567cf 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -17,14 +17,14 @@ import ( "github.com/hashicorp/yamux" ) -func reverseYamuxConfig() *yamux.Config { +var defaultYamuxConfig = func() *yamux.Config { config := yamux.DefaultConfig() config.KeepAliveInterval = 15 * time.Second config.ConnectionWriteTimeout = 10 * time.Second config.MaxStreamWindowSize = 512 * 1024 config.LogOutput = io.Discard return config -} +}() type bufferedConn struct { reader *bufio.Reader @@ -108,7 +108,7 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite return } - session, err := yamux.Client(conn, reverseYamuxConfig()) + session, err := yamux.Client(conn, defaultYamuxConfig) if err != nil { conn.Close() s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) @@ -161,9 +161,11 @@ func (c *externalCredential) connectorLoop() { consecutiveFailures++ backoff := connectorBackoff(consecutiveFailures) c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) + timer := time.NewTimer(backoff) select { - case <-time.After(backoff): + case <-timer.C: case <-ctx.Done(): + timer.Stop() return } } @@ -236,7 +238,7 @@ func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duratio } } - session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, reverseYamuxConfig()) + session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, defaultYamuxConfig) if err != nil { conn.Close() return 0, E.Cause(err, "create yamux server") diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index fdbb682033..33d6317de7 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -313,7 +313,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Rewrite response headers for external users if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) + s.rewriteResponseHeadersForExternalUser(response.Header, provider, userConfig) } for key, values := range response.Header { diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index bd8aa4b22b..20c4780e53 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -100,12 +100,7 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user totalWeight } -func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) { - provider, err := credentialForUser(s.userConfigMap, s.providers, userConfig.Name) - if err != nil { - return - } - +func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) { avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64)) diff --git a/service/ocm/credential.go b/service/ocm/credential.go index 80c094cdb6..8a56bfcec9 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -29,10 +29,12 @@ func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func for attempt := range httpRetryMaxAttempts { if attempt > 0 { delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) + timer := time.NewTimer(delay) select { case <-ctx.Done(): + timer.Stop() return nil, lastError - case <-time.After(delay): + case <-timer.C: } } request, err := buildRequest() diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index b47826e60e..f97df5b87a 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -17,14 +17,14 @@ import ( "github.com/hashicorp/yamux" ) -func reverseYamuxConfig() *yamux.Config { +var defaultYamuxConfig = func() *yamux.Config { config := yamux.DefaultConfig() config.KeepAliveInterval = 15 * time.Second config.ConnectionWriteTimeout = 10 * time.Second config.MaxStreamWindowSize = 512 * 1024 config.LogOutput = io.Discard return config -} +}() type bufferedConn struct { reader *bufio.Reader @@ -108,7 +108,7 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite return } - session, err := yamux.Client(conn, reverseYamuxConfig()) + session, err := yamux.Client(conn, defaultYamuxConfig) if err != nil { conn.Close() s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) @@ -161,9 +161,11 @@ func (c *externalCredential) connectorLoop() { consecutiveFailures++ backoff := connectorBackoff(consecutiveFailures) c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) + timer := time.NewTimer(backoff) select { - case <-time.After(backoff): + case <-timer.C: case <-ctx.Done(): + timer.Stop() return } } @@ -236,7 +238,7 @@ func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duratio } } - session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, reverseYamuxConfig()) + session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, defaultYamuxConfig) if err != nil { conn.Close() return 0, E.Cause(err, "create yamux server") diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index 7c9242f5af..0a7698d005 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -293,7 +293,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Rewrite response headers for external users if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) + s.rewriteResponseHeadersForExternalUser(response.Header, provider, userConfig) } for key, values := range response.Header { diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index 327d3a2da0..7057b60f72 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -100,12 +100,7 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user totalWeight } -func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) { - provider, err := credentialForUser(s.userConfigMap, s.providers, userConfig.Name) - if err != nil { - return - } - +func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) { avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index ce20e1be77..25ded03676 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -253,7 +253,7 @@ func (s *Service) handleWebSocket( } } if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, userConfig) + s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, provider, userConfig) } clientUpgrader := ws.HTTPUpgrader{ From 2dd093a32e8fae75f3621afe29fad1f7f631a436 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 15 Mar 2026 21:09:08 +0800 Subject: [PATCH 54/96] ccm,ocm: fix data race, remove dead code, clean up inefficiencies --- service/ccm/credential_default.go | 3 ++- service/ccm/credential_provider.go | 8 ++++++++ service/ccm/service.go | 6 ++---- service/ccm/service_usage.go | 2 +- service/ocm/credential_default.go | 7 +------ service/ocm/credential_provider.go | 8 ++++++++ service/ocm/service_usage.go | 2 +- 7 files changed, 23 insertions(+), 13 deletions(-) diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 021df5d272..50791064b1 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -642,8 +642,9 @@ func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.C if rateLimitTier != "" { c.state.rateLimitTier = rateLimitTier } + resolvedAccountType := c.state.accountType c.stateAccess.Unlock() - c.logger.Info("fetched profile for ", c.tag, ": type=", c.state.accountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(c.state.accountType, rateLimitTier)) + c.logger.Info("fetched profile for ", c.tag, ": type=", resolvedAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(resolvedAccountType, rateLimitTier)) } func (c *defaultCredential) close() { diff --git a/service/ccm/credential_provider.go b/service/ccm/credential_provider.go index 5500df6a14..d4f5abe2f4 100644 --- a/service/ccm/credential_provider.go +++ b/service/ccm/credential_provider.go @@ -348,6 +348,14 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) { } p.sessionAccess.Unlock() + p.interruptAccess.Lock() + for key, entry := range p.credentialInterrupts { + if entry.context.Err() != nil { + delete(p.credentialInterrupts, key) + } + } + p.interruptAccess.Unlock() + for _, credential := range p.credentials { if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { credential.pollUsage(ctx) diff --git a/service/ccm/service.go b/service/ccm/service.go index 74952173c0..b1c637d141 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -5,7 +5,7 @@ import ( "encoding/json" "net/http" "strings" - "sync" + "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" @@ -157,9 +157,7 @@ type Service struct { listener *listener.Listener tlsConfig tls.ServerConfig httpServer *http.Server - userManager *UserManager - trackingGroup sync.WaitGroup - shuttingDown bool + userManager *UserManager providers map[string]credentialProvider allCredentials []Credential diff --git a/service/ccm/service_usage.go b/service/ccm/service_usage.go index e23db66542..ff14b05430 100644 --- a/service/ccm/service_usage.go +++ b/service/ccm/service_usage.go @@ -652,7 +652,7 @@ func (u *AggregatedUsage) AddUsageWithCycleHint( addUsageToCombinations(&u.Combinations, model, contextWindow, weekStartUnix, messagesCount, inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens, cacheCreation5MinuteTokens, cacheCreation1HourTokens, user) - go u.scheduleSave() + u.scheduleSave() return nil } diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index 18612a5690..ed4214fab7 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -587,12 +587,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { return } - var usageURL string - if c.isAPIKeyMode() { - usageURL = openaiAPIBaseURL + "/api/codex/usage" - } else { - usageURL = strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage" - } + usageURL := strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage" accountID := c.getAccountID() pollClient := &http.Client{ diff --git a/service/ocm/credential_provider.go b/service/ocm/credential_provider.go index 6f3da6b43c..092b40c679 100644 --- a/service/ocm/credential_provider.go +++ b/service/ocm/credential_provider.go @@ -375,6 +375,14 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) { } p.sessionAccess.Unlock() + p.interruptAccess.Lock() + for key, entry := range p.credentialInterrupts { + if entry.context.Err() != nil { + delete(p.credentialInterrupts, key) + } + } + p.interruptAccess.Unlock() + for _, credential := range p.credentials { if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { credential.pollUsage(ctx) diff --git a/service/ocm/service_usage.go b/service/ocm/service_usage.go index 19a853a7cd..10ebc23558 100644 --- a/service/ocm/service_usage.go +++ b/service/ocm/service_usage.go @@ -1148,7 +1148,7 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(model string, contextWindow int, addUsageToCombinations(&u.Combinations, model, normalizedServiceTier, contextWindow, weekStartUnix, user, inputTokens, outputTokens, cachedTokens) - go u.scheduleSave() + u.scheduleSave() return nil } From f3c30220946ea6406b1c4af5971d7bdf084de65d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 16 Mar 2026 22:10:10 +0800 Subject: [PATCH 55/96] ccm,ocm: fix session race, track fallback sessions, skip warmup logging Fix data race in selectCredential where concurrent goroutines could overwrite each other's session entries by adding compare-and-delete and store-if-absent patterns with retry loop. Track sessions for fallback strategy so isNew is reported correctly. Skip logging and usage tracking for websocket warmup requests (generate: false). --- service/ccm/credential_provider.go | 103 +++++++++++++++++------------ service/ocm/credential_provider.go | 103 +++++++++++++++++------------ service/ocm/service_websocket.go | 6 +- 3 files changed, 126 insertions(+), 86 deletions(-) diff --git a/service/ccm/credential_provider.go b/service/ccm/credential_provider.go index d4f5abe2f4..8d993c6cae 100644 --- a/service/ccm/credential_provider.go +++ b/service/ccm/credential_provider.go @@ -130,63 +130,82 @@ func newBalancerProvider(credentials []Credential, strategy string, pollInterval } func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) { - if p.strategy == C.BalancerStrategyFallback { - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allCredentialsUnavailableError(p.credentials) + selectionScope := selection.scopeOrDefault() + for { + if p.strategy == C.BalancerStrategyFallback { + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allCredentialsUnavailableError(p.credentials) + } + return best, p.storeSessionIfAbsent(sessionID, sessionEntry{createdAt: time.Now()}), nil } - return best, false, nil - } - selectionScope := selection.scopeOrDefault() - if sessionID != "" { - p.sessionAccess.RLock() - entry, exists := p.sessions[sessionID] - p.sessionAccess.RUnlock() - if exists { - if entry.selectionScope == selectionScope { - for _, credential := range p.credentials { - if credential.tagName() == entry.tag && selection.allows(credential) && credential.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { - better := p.pickLeastUsed(selection.filter) - if better != nil && better.tagName() != credential.tagName() { - effectiveThreshold := p.rebalanceThreshold / credential.planWeight() - delta := credential.weeklyUtilization() - better.weeklyUtilization() - if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", credential.tagName(), - ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", credential.planWeight(), ")") - p.rebalanceCredential(credential.tagName(), selectionScope) - break + if sessionID != "" { + p.sessionAccess.RLock() + entry, exists := p.sessions[sessionID] + p.sessionAccess.RUnlock() + if exists { + if entry.selectionScope == selectionScope { + for _, credential := range p.credentials { + if credential.tagName() == entry.tag && selection.allows(credential) && credential.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { + better := p.pickLeastUsed(selection.filter) + if better != nil && better.tagName() != credential.tagName() { + effectiveThreshold := p.rebalanceThreshold / credential.planWeight() + delta := credential.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", credential.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", credential.planWeight(), ")") + p.rebalanceCredential(credential.tagName(), selectionScope) + break + } } } + return credential, false, nil } - return credential, false, nil } } + p.sessionAccess.Lock() + currentEntry, stillExists := p.sessions[sessionID] + if stillExists && currentEntry == entry { + delete(p.sessions, sessionID) + p.sessionAccess.Unlock() + } else { + p.sessionAccess.Unlock() + continue + } } - p.sessionAccess.Lock() - delete(p.sessions, sessionID) - p.sessionAccess.Unlock() } - } - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allCredentialsUnavailableError(p.credentials) - } - - isNew := sessionID != "" - if isNew { - p.sessionAccess.Lock() - p.sessions[sessionID] = sessionEntry{ + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allCredentialsUnavailableError(p.credentials) + } + if p.storeSessionIfAbsent(sessionID, sessionEntry{ tag: best.tagName(), selectionScope: selectionScope, createdAt: time.Now(), + }) { + return best, true, nil } - p.sessionAccess.Unlock() + if sessionID == "" { + return best, false, nil + } + } +} + +func (p *balancerProvider) storeSessionIfAbsent(sessionID string, entry sessionEntry) bool { + if sessionID == "" { + return false + } + p.sessionAccess.Lock() + defer p.sessionAccess.Unlock() + if _, exists := p.sessions[sessionID]; exists { + return false } - return best, isNew, nil + p.sessions[sessionID] = entry + return true } func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { diff --git a/service/ocm/credential_provider.go b/service/ocm/credential_provider.go index 092b40c679..421258cd63 100644 --- a/service/ocm/credential_provider.go +++ b/service/ocm/credential_provider.go @@ -134,63 +134,82 @@ func newBalancerProvider(credentials []Credential, strategy string, pollInterval } func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) { - if p.strategy == C.BalancerStrategyFallback { - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allRateLimitedError(p.credentials) + selectionScope := selection.scopeOrDefault() + for { + if p.strategy == C.BalancerStrategyFallback { + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allRateLimitedError(p.credentials) + } + return best, p.storeSessionIfAbsent(sessionID, sessionEntry{createdAt: time.Now()}), nil } - return best, false, nil - } - selectionScope := selection.scopeOrDefault() - if sessionID != "" { - p.sessionAccess.RLock() - entry, exists := p.sessions[sessionID] - p.sessionAccess.RUnlock() - if exists { - if entry.selectionScope == selectionScope { - for _, credential := range p.credentials { - if credential.tagName() == entry.tag && compositeCredentialSelectable(credential) && selection.allows(credential) && credential.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { - better := p.pickLeastUsed(selection.filter) - if better != nil && better.tagName() != credential.tagName() { - effectiveThreshold := p.rebalanceThreshold / credential.planWeight() - delta := credential.weeklyUtilization() - better.weeklyUtilization() - if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", credential.tagName(), - ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", credential.planWeight(), ")") - p.rebalanceCredential(credential.tagName(), selectionScope) - break + if sessionID != "" { + p.sessionAccess.RLock() + entry, exists := p.sessions[sessionID] + p.sessionAccess.RUnlock() + if exists { + if entry.selectionScope == selectionScope { + for _, credential := range p.credentials { + if credential.tagName() == entry.tag && compositeCredentialSelectable(credential) && selection.allows(credential) && credential.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { + better := p.pickLeastUsed(selection.filter) + if better != nil && better.tagName() != credential.tagName() { + effectiveThreshold := p.rebalanceThreshold / credential.planWeight() + delta := credential.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", credential.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", credential.planWeight(), ")") + p.rebalanceCredential(credential.tagName(), selectionScope) + break + } } } + return credential, false, nil } - return credential, false, nil } } + p.sessionAccess.Lock() + currentEntry, stillExists := p.sessions[sessionID] + if stillExists && currentEntry == entry { + delete(p.sessions, sessionID) + p.sessionAccess.Unlock() + } else { + p.sessionAccess.Unlock() + continue + } } - p.sessionAccess.Lock() - delete(p.sessions, sessionID) - p.sessionAccess.Unlock() } - } - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allRateLimitedError(p.credentials) - } - - isNew := sessionID != "" - if isNew { - p.sessionAccess.Lock() - p.sessions[sessionID] = sessionEntry{ + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allRateLimitedError(p.credentials) + } + if p.storeSessionIfAbsent(sessionID, sessionEntry{ tag: best.tagName(), selectionScope: selectionScope, createdAt: time.Now(), + }) { + return best, true, nil } - p.sessionAccess.Unlock() + if sessionID == "" { + return best, false, nil + } + } +} + +func (p *balancerProvider) storeSessionIfAbsent(sessionID string, entry sessionEntry) bool { + if sessionID == "" { + return false + } + p.sessionAccess.Lock() + defer p.sessionAccess.Unlock() + if _, exists := p.sessions[sessionID]; exists { + return false } - return best, isNew, nil + p.sessions[sessionID] = entry + return true } func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 25ded03676..066692ace7 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -325,9 +325,11 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn Type string `json:"type"` Model string `json:"model"` ServiceTier string `json:"service_tier"` + Generate *bool `json:"generate"` } if json.Unmarshal(data, &request) == nil && request.Type == "response.create" && request.Model != "" { - if isNew && !logged { + isWarmup := request.Generate != nil && !*request.Generate + if !isWarmup && isNew && !logged { logged = true logParts := []any{"assigned credential ", selectedCredential.tagName()} if sessionID != "" { @@ -342,7 +344,7 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn } s.logger.DebugContext(ctx, logParts...) } - if selectedCredential.usageTrackerOrNil() != nil { + if !isWarmup && selectedCredential.usageTrackerOrNil() != nil { select { case modelChannel <- request.Model: default: From f84832a36981b9250fe51782e627eb5dfd57c5dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 16:03:35 +0800 Subject: [PATCH 56/96] Add stream watch endpoint --- service/ccm/credential.go | 3 + service/ccm/credential_builder.go | 1 - service/ccm/credential_default.go | 49 ++++- service/ccm/credential_external.go | 205 ++++++++++++++++++- service/ccm/credential_file.go | 12 ++ service/ccm/credential_status_test.go | 280 ++++++++++++++++++++++++++ service/ccm/service.go | 26 ++- service/ccm/service_status.go | 76 +++++++ service/ocm/credential.go | 2 + service/ocm/credential_builder.go | 1 - service/ocm/credential_default.go | 28 +++ service/ocm/credential_external.go | 198 +++++++++++++++++- service/ocm/credential_file.go | 10 + service/ocm/credential_status_test.go | 263 ++++++++++++++++++++++++ service/ocm/service.go | 19 +- service/ocm/service_status.go | 76 +++++++ 16 files changed, 1225 insertions(+), 24 deletions(-) create mode 100644 service/ccm/credential_status_test.go create mode 100644 service/ocm/credential_status_test.go diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 9e16141667..9defed434c 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -6,6 +6,8 @@ import ( "strconv" "sync" "time" + + "github.com/sagernet/sing/common/observable" ) const ( @@ -115,6 +117,7 @@ type Credential interface { wrapRequestContext(ctx context.Context) *credentialRequestContext interruptConnections() + setStatusSubscriber(*observable.Subscriber[struct{}]) start() error pollUsage(ctx context.Context) lastUpdatedTime() time.Time diff --git a/service/ccm/credential_builder.go b/service/ccm/credential_builder.go index 94e0af957a..8ffd42b0dd 100644 --- a/service/ccm/credential_builder.go +++ b/service/ccm/credential_builder.go @@ -161,4 +161,3 @@ func credentialForUser( } return provider, nil } - diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 50791064b1..4a94050738 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -21,6 +21,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/sing/common/observable" ) type defaultCredential struct { @@ -43,11 +44,13 @@ type defaultCredential struct { watcher *fswatch.Watcher watcherRetryAt time.Time + statusSubscriber *observable.Subscriber[struct{}] + // Connection interruption interrupted bool - requestContext context.Context - cancelRequests context.CancelFunc - requestAccess sync.Mutex + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex } func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { @@ -139,6 +142,23 @@ func (c *defaultCredential) start() error { return nil } +func (c *defaultCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) { + c.statusSubscriber = subscriber +} + +func (c *defaultCredential) emitStatusUpdate() { + if c.statusSubscriber != nil { + c.statusSubscriber.Emit(struct{}{}) + } +} + +func (c *defaultCredential) statusAggregateStateLocked() (bool, float64) { + if c.state.unavailable { + return false, 0 + } + return true, ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) +} + func (c *defaultCredential) getAccessToken() (string, error) { c.retryCredentialReloadIfNeeded() @@ -186,13 +206,19 @@ func (c *defaultCredential) getAccessToken() (string, error) { if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { c.credentials = latestCredentials c.stateAccess.Lock() + wasAvailable, oldWeight := c.statusAggregateStateLocked() c.state.unavailable = false c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" c.state.accountType = latestCredentials.SubscriptionType c.state.rateLimitTier = latestCredentials.RateLimitTier c.checkTransitionLocked() + isAvailable, newWeight := c.statusAggregateStateLocked() + shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight c.stateAccess.Unlock() + if shouldEmit { + c.emitStatusUpdate() + } if !latestCredentials.needsRefresh() { return latestCredentials.AccessToken, nil } @@ -201,13 +227,19 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.credentials = newCredentials c.stateAccess.Lock() + wasAvailable, oldWeight := c.statusAggregateStateLocked() c.state.unavailable = false c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" c.state.accountType = newCredentials.SubscriptionType c.state.rateLimitTier = newCredentials.RateLimitTier c.checkTransitionLocked() + isAvailable, newWeight := c.statusAggregateStateLocked() + shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight c.stateAccess.Unlock() + if shouldEmit { + c.emitStatusUpdate() + } err = platformWriteCredentials(newCredentials, c.credentialPath) if err != nil { @@ -277,6 +309,9 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { if shouldInterrupt { c.interruptConnections() } + if hadData { + c.emitStatusUpdate() + } } func (c *defaultCredential) markRateLimited(resetAt time.Time) { @@ -289,6 +324,7 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) { if shouldInterrupt { c.interruptConnections() } + c.emitStatusUpdate() } func (c *defaultCredential) isUsable() bool { @@ -584,6 +620,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { if shouldInterrupt { c.interruptConnections() } + c.emitStatusUpdate() if needsProfileFetch { c.fetchProfile(ctx, httpClient, accessToken) @@ -636,6 +673,7 @@ func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.C rateLimitTier := profileResponse.Organization.RateLimitTier c.stateAccess.Lock() + wasAvailable, oldWeight := c.statusAggregateStateLocked() if accountType != "" && c.state.accountType == "" { c.state.accountType = accountType } @@ -643,7 +681,12 @@ func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.C c.state.rateLimitTier = rateLimitTier } resolvedAccountType := c.state.accountType + isAvailable, newWeight := c.statusAggregateStateLocked() + shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight c.stateAccess.Unlock() + if shouldEmit { + c.emitStatusUpdate() + } c.logger.Info("fetched profile for ", c.tag, ": type=", resolvedAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(resolvedAccountType, rateLimitTier)) } diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 21e695b241..c66a54e8b3 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -22,11 +22,15 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/sing/common/observable" "github.com/hashicorp/yamux" ) -const reverseProxyBaseURL = "http://reverse-proxy" +const ( + reverseProxyBaseURL = "http://reverse-proxy" + statusStreamHeader = "X-CCM-Status-Stream" +) type externalCredential struct { tag string @@ -40,10 +44,12 @@ type externalCredential struct { usageTracker *AggregatedUsage logger log.ContextLogger + statusSubscriber *observable.Subscriber[struct{}] + interrupted bool - requestContext context.Context - cancelRequests context.CancelFunc - requestAccess sync.Mutex + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex // Reverse proxy fields reverse bool @@ -61,6 +67,12 @@ type externalCredential struct { reverseService http.Handler } +type statusStreamResult struct { + duration time.Duration + frames int + oneShot bool +} + func externalCredentialURLPort(parsedURL *url.URL) uint16 { portString := parsedURL.Port() if portString != "" { @@ -218,6 +230,16 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx return credential, nil } +func (c *externalCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) { + c.statusSubscriber = subscriber +} + +func (c *externalCredential) emitStatusUpdate() { + if c.statusSubscriber != nil { + c.statusSubscriber.Emit(struct{}{}) + } +} + func (c *externalCredential) start() error { if c.usageTracker != nil { err := c.usageTracker.Load() @@ -227,6 +249,8 @@ func (c *externalCredential) start() error { } if c.reverse && c.connectorURL != nil { go c.connectorLoop() + } else { + go c.statusStreamLoop() } return nil } @@ -317,6 +341,7 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) { if shouldInterrupt { c.interruptConnections() } + c.emitStatusUpdate() } func (c *externalCredential) earliestReset() time.Time { @@ -458,6 +483,9 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { if shouldInterrupt { c.interruptConnections() } + if hadData { + c.emitStatusUpdate() + } } func (c *externalCredential) checkTransitionLocked() bool { @@ -595,6 +623,142 @@ func (c *externalCredential) pollUsage(ctx context.Context) { if shouldInterrupt { c.interruptConnections() } + c.emitStatusUpdate() +} + +func (c *externalCredential) statusStreamLoop() { + var consecutiveFailures int + ctx := c.getReverseContext() + for { + select { + case <-ctx.Done(): + return + default: + } + + result, err := c.connectStatusStream(ctx) + if ctx.Err() != nil { + return + } + var backoff time.Duration + var oneShot bool + consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures) + if oneShot { + c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff) + } else { + c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) + } + timer := time.NewTimer(backoff) + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return + } + } +} + +func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStreamResult, error) { + startTime := time.Now() + result := statusStreamResult{} + response, err := c.doStreamStatusRequest(ctx) + if err != nil { + result.duration = time.Since(startTime) + return result, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + result.duration = time.Since(startTime) + return result, E.New("status ", response.StatusCode, " ", string(body)) + } + + decoder := json.NewDecoder(response.Body) + isStatusStream := response.Header.Get(statusStreamHeader) == "true" + previousLastUpdated := c.lastUpdatedTime() + var firstFrameUpdatedAt time.Time + for { + var statusResponse struct { + FiveHourUtilization float64 `json:"five_hour_utilization"` + WeeklyUtilization float64 `json:"weekly_utilization"` + PlanWeight float64 `json:"plan_weight"` + } + err = decoder.Decode(&statusResponse) + if err != nil { + result.duration = time.Since(startTime) + if result.frames == 1 && err == io.EOF && !isStatusStream { + result.oneShot = true + c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated) + } + return result, err + } + + c.stateAccess.Lock() + c.state.consecutivePollFailures = 0 + c.state.fiveHourUtilization = statusResponse.FiveHourUtilization + c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.PlanWeight > 0 { + c.state.remotePlanWeight = statusResponse.PlanWeight + } + if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + result.frames++ + updatedAt := c.markUsageStreamUpdated() + if result.frames == 1 { + firstFrameUpdatedAt = updatedAt + } + c.emitStatusUpdate() + } +} + +func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) { + if result.oneShot { + return 0, c.pollInterval, true + } + if result.duration >= connectorBackoffResetThreshold { + consecutiveFailures = 0 + } + consecutiveFailures++ + return consecutiveFailures, connectorBackoff(consecutiveFailures), false +} + +func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) { + buildRequest := func(baseURL string) (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ccm/v1/status?watch=true", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+c.token) + return request, nil + } + if c.reverseHTTPClient != nil { + session := c.getReverseSession() + if session != nil && !session.IsClosed() { + request, err := buildRequest(reverseProxyBaseURL) + if err != nil { + return nil, err + } + response, err := c.reverseHTTPClient.Do(request) + if err == nil { + return response, nil + } + } + } + if c.forwardHTTPClient != nil { + request, err := buildRequest(c.baseURL) + if err != nil { + return nil, err + } + return c.forwardHTTPClient.Do(request) + } + return nil, E.New("no transport available") } func (c *externalCredential) lastUpdatedTime() time.Time { @@ -603,6 +767,25 @@ func (c *externalCredential) lastUpdatedTime() time.Time { return c.state.lastUpdated } +func (c *externalCredential) markUsageStreamUpdated() time.Time { + c.stateAccess.Lock() + defer c.stateAccess.Unlock() + now := time.Now() + c.state.lastUpdated = now + return now +} + +func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) { + if expectedCurrent.IsZero() { + return + } + c.stateAccess.Lock() + defer c.stateAccess.Unlock() + if c.state.lastUpdated.Equal(expectedCurrent) { + c.state.lastUpdated = previous + } +} + func (c *externalCredential) markUsagePollAttempted() { c.stateAccess.Lock() defer c.stateAccess.Unlock() @@ -665,26 +848,40 @@ func (c *externalCredential) getReverseSession() *yamux.Session { } func (c *externalCredential) setReverseSession(session *yamux.Session) bool { + var emitStatus bool c.reverseAccess.Lock() if c.closed { c.reverseAccess.Unlock() return false } + wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed() old := c.reverseSession c.reverseSession = session + isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed() + emitStatus = wasAvailable != isAvailable c.reverseAccess.Unlock() if old != nil { old.Close() } + if emitStatus { + c.emitStatusUpdate() + } return true } func (c *externalCredential) clearReverseSession(session *yamux.Session) { + var emitStatus bool c.reverseAccess.Lock() + wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed() if c.reverseSession == session { c.reverseSession = nil } + isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed() + emitStatus = wasAvailable != isAvailable c.reverseAccess.Unlock() + if emitStatus { + c.emitStatusUpdate() + } } func (c *externalCredential) getReverseContext() context.Context { diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index 72d9da0100..c0ada93abe 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -111,12 +111,18 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.access.Unlock() c.stateAccess.Lock() + wasAvailable, oldWeight := c.statusAggregateStateLocked() c.state.unavailable = false c.state.lastCredentialLoadError = "" c.state.accountType = credentials.SubscriptionType c.state.rateLimitTier = credentials.RateLimitTier c.checkTransitionLocked() + isAvailable, newWeight := c.statusAggregateStateLocked() + shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight c.stateAccess.Unlock() + if shouldEmit { + c.emitStatusUpdate() + } return nil } @@ -128,16 +134,22 @@ func (c *defaultCredential) markCredentialsUnavailable(err error) error { c.access.Unlock() c.stateAccess.Lock() + wasAvailable, oldWeight := c.statusAggregateStateLocked() c.state.unavailable = true c.state.lastCredentialLoadError = err.Error() c.state.accountType = "" c.state.rateLimitTier = "" shouldInterrupt := c.checkTransitionLocked() + isAvailable, newWeight := c.statusAggregateStateLocked() + shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight c.stateAccess.Unlock() if shouldInterrupt && hadCredentials { c.interruptConnections() } + if shouldEmit { + c.emitStatusUpdate() + } return err } diff --git a/service/ccm/credential_status_test.go b/service/ccm/credential_status_test.go new file mode 100644 index 0000000000..e675f73cc4 --- /dev/null +++ b/service/ccm/credential_status_test.go @@ -0,0 +1,280 @@ +package ccm + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common/observable" + + "github.com/hashicorp/yamux" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) { + return f(request) +} + +func drainStatusEvents(subscription observable.Subscription[struct{}]) int { + var count int + for { + select { + case <-subscription: + count++ + default: + return count + } + } +} + +func newTestLogger() log.ContextLogger { + return log.NewNOPFactory().Logger() +} + +func newTestCCMExternalCredential(t *testing.T, body string, headers http.Header) (*externalCredential, observable.Subscription[struct{}]) { + t.Helper() + subscriber := observable.NewSubscriber[struct{}](8) + subscription, _ := subscriber.Subscription() + credential := &externalCredential{ + tag: "test", + baseURL: "http://example.com", + token: "token", + pollInterval: 25 * time.Millisecond, + forwardHTTPClient: &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) { + if request.URL.String() != "http://example.com/ccm/v1/status?watch=true" { + t.Fatalf("unexpected request URL: %s", request.URL.String()) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: headers.Clone(), + Body: io.NopCloser(strings.NewReader(body)), + }, nil + })}, + logger: newTestLogger(), + statusSubscriber: subscriber, + } + return credential, subscription +} + +func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) { + t.Helper() + clientConn, serverConn := net.Pipe() + clientSession, err := yamux.Client(clientConn, defaultYamuxConfig) + if err != nil { + t.Fatalf("create yamux client: %v", err) + } + serverSession, err := yamux.Server(serverConn, defaultYamuxConfig) + if err != nil { + clientSession.Close() + t.Fatalf("create yamux server: %v", err) + } + t.Cleanup(func() { + clientSession.Close() + serverSession.Close() + }) + return clientSession, serverSession +} + +func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) { + credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil) + oldTime := time.Unix(123, 0) + credential.stateAccess.Lock() + credential.state.lastUpdated = oldTime + credential.stateAccess.Unlock() + + result, err := credential.connectStatusStream(context.Background()) + if err != io.EOF { + t.Fatalf("expected EOF, got %v", err) + } + if !result.oneShot { + t.Fatal("expected one-shot result") + } + if result.frames != 1 { + t.Fatalf("expected 1 frame, got %d", result.frames) + } + if !credential.lastUpdatedTime().Equal(oldTime) { + t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime()) + } + if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 { + t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event, got %d", count) + } + + failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) + if !oneShot { + t.Fatal("expected one-shot backoff branch") + } + if failures != 0 { + t.Fatalf("expected failures reset, got %d", failures) + } + if backoff != credential.pollInterval { + t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff) + } +} + +func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) { + headers := make(http.Header) + headers.Set(statusStreamHeader, "true") + credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers) + oldTime := time.Unix(123, 0) + credential.stateAccess.Lock() + credential.state.lastUpdated = oldTime + credential.stateAccess.Unlock() + + result, err := credential.connectStatusStream(context.Background()) + if err != io.EOF { + t.Fatalf("expected EOF, got %v", err) + } + if result.oneShot { + t.Fatal("did not expect one-shot result") + } + if result.frames != 1 { + t.Fatalf("expected 1 frame, got %d", result.frames) + } + if credential.lastUpdatedTime().Equal(oldTime) { + t.Fatal("expected lastUpdated to remain refreshed") + } + if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 { + t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event, got %d", count) + } + + failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) + if oneShot { + t.Fatal("did not expect one-shot backoff branch") + } + if failures != 4 { + t.Fatalf("expected failures incremented to 4, got %d", failures) + } + if backoff < 16*time.Second || backoff >= 24*time.Second { + t.Fatalf("expected connector backoff in [16s, 24s), got %v", backoff) + } +} + +func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *testing.T) { + credential, subscription := newTestCCMExternalCredential(t, strings.Join([]string{ + "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}", + "{\"five_hour_utilization\":13,\"weekly_utilization\":35,\"plan_weight\":3}", + }, "\n"), nil) + oldTime := time.Unix(123, 0) + credential.stateAccess.Lock() + credential.state.lastUpdated = oldTime + credential.stateAccess.Unlock() + + result, err := credential.connectStatusStream(context.Background()) + if err != io.EOF { + t.Fatalf("expected EOF, got %v", err) + } + if result.oneShot { + t.Fatal("did not expect one-shot result") + } + if result.frames != 2 { + t.Fatalf("expected 2 frames, got %d", result.frames) + } + if credential.lastUpdatedTime().Equal(oldTime) { + t.Fatal("expected lastUpdated to remain refreshed") + } + if credential.fiveHourUtilization() != 13 || credential.weeklyUtilization() != 35 { + t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) + } + if count := drainStatusEvents(subscription); count != 2 { + t.Fatalf("expected 2 status events, got %d", count) + } +} + +func TestDefaultCredentialStatusChangesEmitStatus(t *testing.T) { + credentialPath := filepath.Join(t.TempDir(), "credentials.json") + err := os.WriteFile(credentialPath, []byte("{\"claudeAiOauth\":{\"accessToken\":\"token\",\"refreshToken\":\"\",\"expiresAt\":0,\"subscriptionType\":\"max\"}}\n"), 0o600) + if err != nil { + t.Fatalf("write credential file: %v", err) + } + + subscriber := observable.NewSubscriber[struct{}](8) + subscription, _ := subscriber.Subscription() + credential := &defaultCredential{ + tag: "test", + credentialPath: credentialPath, + logger: newTestLogger(), + statusSubscriber: subscriber, + } + + err = credential.markCredentialsUnavailable(errors.New("boom")) + if err == nil { + t.Fatal("expected error from markCredentialsUnavailable") + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event after unavailable transition, got %d", count) + } + + err = credential.reloadCredentials(true) + if err != nil { + t.Fatalf("reload credentials: %v", err) + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event after recovery, got %d", count) + } + if weight := credential.planWeight(); weight != 5 { + t.Fatalf("expected initial max weight 5, got %v", weight) + } + + profileClient := &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader( + "{\"organization\":{\"organization_type\":\"claude_max\",\"rate_limit_tier\":\"default_claude_max_20x\"}}", + )), + }, nil + })} + credential.fetchProfile(context.Background(), profileClient, "token") + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event after weight change, got %d", count) + } + if weight := credential.planWeight(); weight != 10 { + t.Fatalf("expected upgraded max weight 10, got %v", weight) + } +} + +func TestExternalCredentialReverseSessionChangesEmitStatus(t *testing.T) { + subscriber := observable.NewSubscriber[struct{}](8) + subscription, _ := subscriber.Subscription() + credential := &externalCredential{ + tag: "receiver", + baseURL: reverseProxyBaseURL, + pollInterval: time.Minute, + logger: newTestLogger(), + statusSubscriber: subscriber, + } + + clientSession, _ := newTestYamuxSessionPair(t) + if !credential.setReverseSession(clientSession) { + t.Fatal("expected reverse session to be accepted") + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event after reverse session up, got %d", count) + } + if !credential.isAvailable() { + t.Fatal("expected receiver credential to become available") + } + + credential.clearReverseSession(clientSession) + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event after reverse session down, got %d", count) + } + if credential.isAvailable() { + t.Fatal("expected receiver credential to become unavailable") + } +} diff --git a/service/ccm/service.go b/service/ccm/service.go index b1c637d141..8a9d8f17f0 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -6,7 +6,6 @@ import ( "net/http" "strings" - "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" "github.com/sagernet/sing-box/common/listener" @@ -17,6 +16,7 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/observable" aTLS "github.com/sagernet/sing/common/tls" "github.com/go-chi/chi/v5" @@ -150,18 +150,21 @@ func isAPIKeyHeader(header string) bool { type Service struct { boxService.Adapter - ctx context.Context - logger log.ContextLogger - options option.CCMServiceOptions - httpHeaders http.Header - listener *listener.Listener - tlsConfig tls.ServerConfig - httpServer *http.Server + ctx context.Context + logger log.ContextLogger + options option.CCMServiceOptions + httpHeaders http.Header + listener *listener.Listener + tlsConfig tls.ServerConfig + httpServer *http.Server userManager *UserManager providers map[string]credentialProvider allCredentials []Credential userConfigMap map[string]*option.CCMUser + + statusSubscriber *observable.Subscriber[struct{}] + statusObserver *observable.Observer[struct{}] } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) { @@ -195,6 +198,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio tokenMap: make(map[string]string), } + statusSubscriber := observable.NewSubscriber[struct{}](16) service := &Service{ Adapter: boxService.NewAdapter(C.TypeCCM, tag), ctx: ctx, @@ -207,7 +211,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio Network: []string{N.NetworkTCP}, Listen: options.ListenOptions, }), - userManager: userManager, + userManager: userManager, + statusSubscriber: statusSubscriber, + statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8), } providers, allCredentials, err := buildCredentialProviders(ctx, options, logger) @@ -242,6 +248,7 @@ func (s *Service) Start(stage adapter.StartStage) error { s.userManager.UpdateUsers(s.options.Users) for _, credential := range s.allCredentials { + credential.setStatusSubscriber(s.statusSubscriber) if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil { external.reverseService = s } @@ -300,6 +307,7 @@ func (s *Service) InterfaceUpdated() { } func (s *Service) Close() error { + s.statusObserver.Close() err := common.Close( common.PtrOrNil(s.httpServer), common.PtrOrNil(s.listener), diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index 20c4780e53..82e3d5a0d4 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -1,6 +1,7 @@ package ccm import ( + "bytes" "encoding/json" "net/http" "strconv" @@ -55,6 +56,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { return } + if r.URL.Query().Get("watch") == "true" { + s.handleStatusStream(w, r, provider, userConfig) + return + } + provider.pollIfStale(r.Context()) avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) @@ -67,6 +73,76 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { }) } +func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.CCMUser) { + flusher, ok := w.(http.Flusher) + if !ok { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "streaming not supported") + return + } + + subscription, done, err := s.statusObserver.Subscribe() + if err != nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "service closing") + return + } + defer s.statusObserver.UnSubscribe(subscription) + + provider.pollIfStale(r.Context()) + + w.Header().Set("Content-Type", "application/json") + w.Header().Set(statusStreamHeader, "true") + w.WriteHeader(http.StatusOK) + + lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig) + buf := &bytes.Buffer{} + json.NewEncoder(buf).Encode(map[string]float64{ + "five_hour_utilization": lastFiveHour, + "weekly_utilization": lastWeekly, + "plan_weight": lastWeight, + }) + _, writeErr := w.Write(buf.Bytes()) + if writeErr != nil { + return + } + flusher.Flush() + + for { + select { + case <-r.Context().Done(): + return + case <-done: + return + case <-subscription: + for { + select { + case <-subscription: + default: + goto drained + } + } + drained: + fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig) + if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight { + continue + } + lastFiveHour = fiveHour + lastWeekly = weekly + lastWeight = weight + buf.Reset() + json.NewEncoder(buf).Encode(map[string]float64{ + "five_hour_utilization": fiveHour, + "weekly_utilization": weekly, + "plan_weight": weight, + }) + _, writeErr = w.Write(buf.Bytes()) + if writeErr != nil { + return + } + flusher.Flush() + } + } +} + func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) { var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 for _, credential := range provider.allCredentials() { diff --git a/service/ocm/credential.go b/service/ocm/credential.go index 8a56bfcec9..c777ea5c9b 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -9,6 +9,7 @@ import ( "time" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/observable" ) const ( @@ -118,6 +119,7 @@ type Credential interface { interruptConnections() setOnBecameUnusable(fn func()) + setStatusSubscriber(*observable.Subscriber[struct{}]) start() error pollUsage(ctx context.Context) lastUpdatedTime() time.Time diff --git a/service/ocm/credential_builder.go b/service/ocm/credential_builder.go index c800e039d6..fb69c62c93 100644 --- a/service/ocm/credential_builder.go +++ b/service/ocm/credential_builder.go @@ -192,4 +192,3 @@ func credentialForUser( } return provider, nil } - diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index ed4214fab7..0bfeebdc15 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -22,6 +22,7 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/sing/common/observable" ) type defaultCredential struct { @@ -45,6 +46,8 @@ type defaultCredential struct { watcher *fswatch.Watcher watcherRetryAt time.Time + statusSubscriber *observable.Subscriber[struct{}] + // Connection interruption onBecameUnusable func() interrupted bool @@ -147,6 +150,16 @@ func (c *defaultCredential) setOnBecameUnusable(fn func()) { c.onBecameUnusable = fn } +func (c *defaultCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) { + c.statusSubscriber = subscriber +} + +func (c *defaultCredential) emitStatusUpdate() { + if c.statusSubscriber != nil { + c.statusSubscriber.Emit(struct{}{}) + } +} + func (c *defaultCredential) tagName() string { return c.tag } @@ -202,11 +215,16 @@ func (c *defaultCredential) getAccessToken() (string, error) { if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { c.credentials = latestCredentials c.stateAccess.Lock() + wasAvailable := !c.state.unavailable c.state.unavailable = false c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" c.checkTransitionLocked() + shouldEmit := wasAvailable != !c.state.unavailable c.stateAccess.Unlock() + if shouldEmit { + c.emitStatusUpdate() + } if !latestCredentials.needsRefresh() { return latestCredentials.getAccessToken(), nil } @@ -215,11 +233,16 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.credentials = newCredentials c.stateAccess.Lock() + wasAvailable := !c.state.unavailable c.state.unavailable = false c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" c.checkTransitionLocked() + shouldEmit := wasAvailable != !c.state.unavailable c.stateAccess.Unlock() + if shouldEmit { + c.emitStatusUpdate() + } err = platformWriteCredentials(newCredentials, c.credentialPath) if err != nil { @@ -329,6 +352,9 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { if shouldInterrupt { c.interruptConnections() } + if hadData { + c.emitStatusUpdate() + } } func (c *defaultCredential) markRateLimited(resetAt time.Time) { @@ -341,6 +367,7 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) { if shouldInterrupt { c.interruptConnections() } + c.emitStatusUpdate() } func (c *defaultCredential) isUsable() bool { @@ -692,6 +719,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { if shouldInterrupt { c.interruptConnections() } + c.emitStatusUpdate() } func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) { diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index d924ae4ee1..a4835df7aa 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -23,11 +23,15 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/sing/common/observable" "github.com/hashicorp/yamux" ) -const reverseProxyBaseURL = "http://reverse-proxy" +const ( + reverseProxyBaseURL = "http://reverse-proxy" + statusStreamHeader = "X-OCM-Status-Stream" +) type externalCredential struct { tag string @@ -42,6 +46,7 @@ type externalCredential struct { usageTracker *AggregatedUsage logger log.ContextLogger + statusSubscriber *observable.Subscriber[struct{}] onBecameUnusable func() interrupted bool requestContext context.Context @@ -69,6 +74,12 @@ type reverseSessionDialer struct { credential *externalCredential } +type statusStreamResult struct { + duration time.Duration + frames int + oneShot bool +} + func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { if N.NetworkName(network) != N.NetworkTCP { return nil, os.ErrInvalid @@ -249,6 +260,8 @@ func (c *externalCredential) start() error { } if c.reverse && c.connectorURL != nil { go c.connectorLoop() + } else { + go c.statusStreamLoop() } return nil } @@ -257,6 +270,16 @@ func (c *externalCredential) setOnBecameUnusable(fn func()) { c.onBecameUnusable = fn } +func (c *externalCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) { + c.statusSubscriber = subscriber +} + +func (c *externalCredential) emitStatusUpdate() { + if c.statusSubscriber != nil { + c.statusSubscriber.Emit(struct{}{}) + } +} + func (c *externalCredential) tagName() string { return c.tag } @@ -342,6 +365,7 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) { if shouldInterrupt { c.interruptConnections() } + c.emitStatusUpdate() } func (c *externalCredential) earliestReset() time.Time { @@ -498,6 +522,9 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { if shouldInterrupt { c.interruptConnections() } + if hadData { + c.emitStatusUpdate() + } } func (c *externalCredential) checkTransitionLocked() bool { @@ -638,6 +665,142 @@ func (c *externalCredential) pollUsage(ctx context.Context) { if shouldInterrupt { c.interruptConnections() } + c.emitStatusUpdate() +} + +func (c *externalCredential) statusStreamLoop() { + var consecutiveFailures int + ctx := c.getReverseContext() + for { + select { + case <-ctx.Done(): + return + default: + } + + result, err := c.connectStatusStream(ctx) + if ctx.Err() != nil { + return + } + var backoff time.Duration + var oneShot bool + consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures) + if oneShot { + c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff) + } else { + c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) + } + timer := time.NewTimer(backoff) + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return + } + } +} + +func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStreamResult, error) { + startTime := time.Now() + result := statusStreamResult{} + response, err := c.doStreamStatusRequest(ctx) + if err != nil { + result.duration = time.Since(startTime) + return result, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + result.duration = time.Since(startTime) + return result, E.New("status ", response.StatusCode, " ", string(body)) + } + + decoder := json.NewDecoder(response.Body) + isStatusStream := response.Header.Get(statusStreamHeader) == "true" + previousLastUpdated := c.lastUpdatedTime() + var firstFrameUpdatedAt time.Time + for { + var statusResponse struct { + FiveHourUtilization float64 `json:"five_hour_utilization"` + WeeklyUtilization float64 `json:"weekly_utilization"` + PlanWeight float64 `json:"plan_weight"` + } + err = decoder.Decode(&statusResponse) + if err != nil { + result.duration = time.Since(startTime) + if result.frames == 1 && err == io.EOF && !isStatusStream { + result.oneShot = true + c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated) + } + return result, err + } + + c.stateAccess.Lock() + c.state.consecutivePollFailures = 0 + c.state.fiveHourUtilization = statusResponse.FiveHourUtilization + c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.PlanWeight > 0 { + c.state.remotePlanWeight = statusResponse.PlanWeight + } + if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + result.frames++ + updatedAt := c.markUsageStreamUpdated() + if result.frames == 1 { + firstFrameUpdatedAt = updatedAt + } + c.emitStatusUpdate() + } +} + +func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) { + if result.oneShot { + return 0, c.pollInterval, true + } + if result.duration >= connectorBackoffResetThreshold { + consecutiveFailures = 0 + } + consecutiveFailures++ + return consecutiveFailures, connectorBackoff(consecutiveFailures), false +} + +func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) { + buildRequest := func(baseURL string) (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ocm/v1/status?watch=true", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+c.token) + return request, nil + } + if c.reverseHTTPClient != nil { + session := c.getReverseSession() + if session != nil && !session.IsClosed() { + request, err := buildRequest(reverseProxyBaseURL) + if err != nil { + return nil, err + } + response, err := c.reverseHTTPClient.Do(request) + if err == nil { + return response, nil + } + } + } + if c.forwardHTTPClient != nil { + request, err := buildRequest(c.baseURL) + if err != nil { + return nil, err + } + return c.forwardHTTPClient.Do(request) + } + return nil, E.New("no transport available") } func (c *externalCredential) lastUpdatedTime() time.Time { @@ -646,6 +809,25 @@ func (c *externalCredential) lastUpdatedTime() time.Time { return c.state.lastUpdated } +func (c *externalCredential) markUsageStreamUpdated() time.Time { + c.stateAccess.Lock() + defer c.stateAccess.Unlock() + now := time.Now() + c.state.lastUpdated = now + return now +} + +func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) { + if expectedCurrent.IsZero() { + return + } + c.stateAccess.Lock() + defer c.stateAccess.Unlock() + if c.state.lastUpdated.Equal(expectedCurrent) { + c.state.lastUpdated = previous + } +} + func (c *externalCredential) markUsagePollAttempted() { c.stateAccess.Lock() defer c.stateAccess.Unlock() @@ -736,26 +918,40 @@ func (c *externalCredential) getReverseSession() *yamux.Session { } func (c *externalCredential) setReverseSession(session *yamux.Session) bool { + var emitStatus bool c.reverseAccess.Lock() if c.closed { c.reverseAccess.Unlock() return false } + wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed() old := c.reverseSession c.reverseSession = session + isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed() + emitStatus = wasAvailable != isAvailable c.reverseAccess.Unlock() if old != nil { old.Close() } + if emitStatus { + c.emitStatusUpdate() + } return true } func (c *externalCredential) clearReverseSession(session *yamux.Session) { + var emitStatus bool c.reverseAccess.Lock() + wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed() if c.reverseSession == session { c.reverseSession = nil } + isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed() + emitStatus = wasAvailable != isAvailable c.reverseAccess.Unlock() + if emitStatus { + c.emitStatusUpdate() + } } func (c *externalCredential) getReverseContext() context.Context { diff --git a/service/ocm/credential_file.go b/service/ocm/credential_file.go index 861dbdb864..b15417a46f 100644 --- a/service/ocm/credential_file.go +++ b/service/ocm/credential_file.go @@ -111,10 +111,15 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.access.Unlock() c.stateAccess.Lock() + wasAvailable := !c.state.unavailable c.state.unavailable = false c.state.lastCredentialLoadError = "" c.checkTransitionLocked() + shouldEmit := wasAvailable != !c.state.unavailable c.stateAccess.Unlock() + if shouldEmit { + c.emitStatusUpdate() + } return nil } @@ -126,14 +131,19 @@ func (c *defaultCredential) markCredentialsUnavailable(err error) error { c.access.Unlock() c.stateAccess.Lock() + wasAvailable := !c.state.unavailable c.state.unavailable = true c.state.lastCredentialLoadError = err.Error() shouldInterrupt := c.checkTransitionLocked() + shouldEmit := wasAvailable != !c.state.unavailable c.stateAccess.Unlock() if shouldInterrupt && hadCredentials { c.interruptConnections() } + if shouldEmit { + c.emitStatusUpdate() + } return err } diff --git a/service/ocm/credential_status_test.go b/service/ocm/credential_status_test.go new file mode 100644 index 0000000000..d45fdebf03 --- /dev/null +++ b/service/ocm/credential_status_test.go @@ -0,0 +1,263 @@ +package ocm + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common/observable" + + "github.com/hashicorp/yamux" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) { + return f(request) +} + +func drainStatusEvents(subscription observable.Subscription[struct{}]) int { + var count int + for { + select { + case <-subscription: + count++ + default: + return count + } + } +} + +func newTestLogger() log.ContextLogger { + return log.NewNOPFactory().Logger() +} + +func newTestOCMExternalCredential(t *testing.T, body string, headers http.Header) (*externalCredential, observable.Subscription[struct{}]) { + t.Helper() + subscriber := observable.NewSubscriber[struct{}](8) + subscription, _ := subscriber.Subscription() + credential := &externalCredential{ + tag: "test", + baseURL: "http://example.com", + token: "token", + pollInterval: 25 * time.Millisecond, + forwardHTTPClient: &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) { + if request.URL.String() != "http://example.com/ocm/v1/status?watch=true" { + t.Fatalf("unexpected request URL: %s", request.URL.String()) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: headers.Clone(), + Body: io.NopCloser(strings.NewReader(body)), + }, nil + })}, + logger: newTestLogger(), + statusSubscriber: subscriber, + } + return credential, subscription +} + +func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) { + t.Helper() + clientConn, serverConn := net.Pipe() + clientSession, err := yamux.Client(clientConn, defaultYamuxConfig) + if err != nil { + t.Fatalf("create yamux client: %v", err) + } + serverSession, err := yamux.Server(serverConn, defaultYamuxConfig) + if err != nil { + clientSession.Close() + t.Fatalf("create yamux server: %v", err) + } + t.Cleanup(func() { + clientSession.Close() + serverSession.Close() + }) + return clientSession, serverSession +} + +func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) { + credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil) + oldTime := time.Unix(123, 0) + credential.stateAccess.Lock() + credential.state.lastUpdated = oldTime + credential.stateAccess.Unlock() + + result, err := credential.connectStatusStream(context.Background()) + if err != io.EOF { + t.Fatalf("expected EOF, got %v", err) + } + if !result.oneShot { + t.Fatal("expected one-shot result") + } + if result.frames != 1 { + t.Fatalf("expected 1 frame, got %d", result.frames) + } + if !credential.lastUpdatedTime().Equal(oldTime) { + t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime()) + } + if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 { + t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event, got %d", count) + } + + failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) + if !oneShot { + t.Fatal("expected one-shot backoff branch") + } + if failures != 0 { + t.Fatalf("expected failures reset, got %d", failures) + } + if backoff != credential.pollInterval { + t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff) + } +} + +func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) { + headers := make(http.Header) + headers.Set(statusStreamHeader, "true") + credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers) + oldTime := time.Unix(123, 0) + credential.stateAccess.Lock() + credential.state.lastUpdated = oldTime + credential.stateAccess.Unlock() + + result, err := credential.connectStatusStream(context.Background()) + if err != io.EOF { + t.Fatalf("expected EOF, got %v", err) + } + if result.oneShot { + t.Fatal("did not expect one-shot result") + } + if result.frames != 1 { + t.Fatalf("expected 1 frame, got %d", result.frames) + } + if credential.lastUpdatedTime().Equal(oldTime) { + t.Fatal("expected lastUpdated to remain refreshed") + } + if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 { + t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event, got %d", count) + } + + failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) + if oneShot { + t.Fatal("did not expect one-shot backoff branch") + } + if failures != 4 { + t.Fatalf("expected failures incremented to 4, got %d", failures) + } + if backoff < 16*time.Second || backoff >= 24*time.Second { + t.Fatalf("expected connector backoff in [16s, 24s), got %v", backoff) + } +} + +func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *testing.T) { + credential, subscription := newTestOCMExternalCredential(t, strings.Join([]string{ + "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}", + "{\"five_hour_utilization\":13,\"weekly_utilization\":35,\"plan_weight\":3}", + }, "\n"), nil) + oldTime := time.Unix(123, 0) + credential.stateAccess.Lock() + credential.state.lastUpdated = oldTime + credential.stateAccess.Unlock() + + result, err := credential.connectStatusStream(context.Background()) + if err != io.EOF { + t.Fatalf("expected EOF, got %v", err) + } + if result.oneShot { + t.Fatal("did not expect one-shot result") + } + if result.frames != 2 { + t.Fatalf("expected 2 frames, got %d", result.frames) + } + if credential.lastUpdatedTime().Equal(oldTime) { + t.Fatal("expected lastUpdated to remain refreshed") + } + if credential.fiveHourUtilization() != 13 || credential.weeklyUtilization() != 35 { + t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) + } + if count := drainStatusEvents(subscription); count != 2 { + t.Fatalf("expected 2 status events, got %d", count) + } +} + +func TestDefaultCredentialAvailabilityChangesEmitStatus(t *testing.T) { + credentialPath := filepath.Join(t.TempDir(), "auth.json") + err := os.WriteFile(credentialPath, []byte("{\"OPENAI_API_KEY\":\"sk-test\"}\n"), 0o600) + if err != nil { + t.Fatalf("write credential file: %v", err) + } + + subscriber := observable.NewSubscriber[struct{}](8) + subscription, _ := subscriber.Subscription() + credential := &defaultCredential{ + tag: "test", + credentialPath: credentialPath, + logger: newTestLogger(), + statusSubscriber: subscriber, + } + + err = credential.markCredentialsUnavailable(errors.New("boom")) + if err == nil { + t.Fatal("expected error from markCredentialsUnavailable") + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event after unavailable transition, got %d", count) + } + + err = credential.reloadCredentials(true) + if err != nil { + t.Fatalf("reload credentials: %v", err) + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event after recovery, got %d", count) + } + if !credential.isAvailable() { + t.Fatal("expected credential to become available") + } +} + +func TestExternalCredentialReverseSessionChangesEmitStatus(t *testing.T) { + subscriber := observable.NewSubscriber[struct{}](8) + subscription, _ := subscriber.Subscription() + credential := &externalCredential{ + tag: "receiver", + baseURL: reverseProxyBaseURL, + pollInterval: time.Minute, + logger: newTestLogger(), + statusSubscriber: subscriber, + } + + clientSession, _ := newTestYamuxSessionPair(t) + if !credential.setReverseSession(clientSession) { + t.Fatal("expected reverse session to be accepted") + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event after reverse session up, got %d", count) + } + if !credential.isAvailable() { + t.Fatal("expected receiver credential to become available") + } + + credential.clearReverseSession(clientSession) + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event after reverse session down, got %d", count) + } + if credential.isAvailable() { + t.Fatal("expected receiver credential to become unavailable") + } +} diff --git a/service/ocm/service.go b/service/ocm/service.go index 641872e5db..7c54115a3b 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -18,6 +18,7 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/observable" aTLS "github.com/sagernet/sing/common/tls" "github.com/go-chi/chi/v5" @@ -176,9 +177,11 @@ type Service struct { webSocketConns map[*webSocketSession]struct{} shuttingDown bool - providers map[string]credentialProvider - allCredentials []Credential - userConfigMap map[string]*option.OCMUser + providers map[string]credentialProvider + allCredentials []Credential + userConfigMap map[string]*option.OCMUser + statusSubscriber *observable.Subscriber[struct{}] + statusObserver *observable.Observer[struct{}] } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) { @@ -210,6 +213,8 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio tokenMap: make(map[string]string), } + statusSubscriber := observable.NewSubscriber[struct{}](16) + service := &Service{ Adapter: boxService.NewAdapter(C.TypeOCM, tag), ctx: ctx, @@ -222,8 +227,10 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio Network: []string{N.NetworkTCP}, Listen: options.ListenOptions, }), - userManager: userManager, - webSocketConns: make(map[*webSocketSession]struct{}), + userManager: userManager, + statusSubscriber: statusSubscriber, + statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8), + webSocketConns: make(map[*webSocketSession]struct{}), } providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger) @@ -258,6 +265,7 @@ func (s *Service) Start(stage adapter.StartStage) error { s.userManager.UpdateUsers(s.options.Users) for _, credential := range s.allCredentials { + credential.setStatusSubscriber(s.statusSubscriber) if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil { external.reverseService = s } @@ -324,6 +332,7 @@ func (s *Service) InterfaceUpdated() { } func (s *Service) Close() error { + s.statusObserver.Close() webSocketSessions := s.startWebSocketShutdown() err := common.Close( diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index 7057b60f72..773862979b 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -1,6 +1,7 @@ package ocm import ( + "bytes" "encoding/json" "net/http" "strconv" @@ -55,6 +56,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { return } + if r.URL.Query().Get("watch") == "true" { + s.handleStatusStream(w, r, provider, userConfig) + return + } + provider.pollIfStale(r.Context()) avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) @@ -67,6 +73,76 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { }) } +func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.OCMUser) { + flusher, ok := w.(http.Flusher) + if !ok { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "streaming not supported") + return + } + + subscription, done, err := s.statusObserver.Subscribe() + if err != nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "service closing") + return + } + defer s.statusObserver.UnSubscribe(subscription) + + provider.pollIfStale(r.Context()) + + w.Header().Set("Content-Type", "application/json") + w.Header().Set(statusStreamHeader, "true") + w.WriteHeader(http.StatusOK) + + lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig) + buf := &bytes.Buffer{} + json.NewEncoder(buf).Encode(map[string]float64{ + "five_hour_utilization": lastFiveHour, + "weekly_utilization": lastWeekly, + "plan_weight": lastWeight, + }) + _, writeErr := w.Write(buf.Bytes()) + if writeErr != nil { + return + } + flusher.Flush() + + for { + select { + case <-r.Context().Done(): + return + case <-done: + return + case <-subscription: + for { + select { + case <-subscription: + default: + goto drained + } + } + drained: + fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig) + if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight { + continue + } + lastFiveHour = fiveHour + lastWeekly = weekly + lastWeight = weight + buf.Reset() + json.NewEncoder(buf).Encode(map[string]float64{ + "five_hour_utilization": fiveHour, + "weekly_utilization": weekly, + "plan_weight": weight, + }) + _, writeErr = w.Write(buf.Bytes()) + if writeErr != nil { + return + } + flusher.Flush() + } + } +} + func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) { var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 for _, credential := range provider.allCredentials() { From 4a6a211775fb913d0c338f0c1b14ab8af81f3a79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 16:10:59 +0800 Subject: [PATCH 57/96] ccm,ocm: reduce status emission noise, simplify emit-guard pattern Guard updateStateFromHeaders emission with value-change detection to avoid unnecessary computeAggregatedUtilization scans on every proxied response. Replace statusAggregateStateLocked two-value return with comparable statusSnapshot struct. Define statusPayload type for the status wire format, replacing anonymous structs and map literals. --- service/ccm/credential_default.go | 29 ++++++++++++++++------------- service/ccm/credential_external.go | 10 ++++------ service/ccm/credential_file.go | 10 ++++------ service/ccm/service_status.go | 30 ++++++++++++++++++------------ service/ocm/credential_default.go | 3 ++- service/ocm/credential_external.go | 10 ++++------ service/ocm/service_status.go | 30 ++++++++++++++++++------------ 7 files changed, 66 insertions(+), 56 deletions(-) diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 4a94050738..f23ba3f652 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -152,11 +152,16 @@ func (c *defaultCredential) emitStatusUpdate() { } } -func (c *defaultCredential) statusAggregateStateLocked() (bool, float64) { +type statusSnapshot struct { + available bool + weight float64 +} + +func (c *defaultCredential) statusSnapshotLocked() statusSnapshot { if c.state.unavailable { - return false, 0 + return statusSnapshot{} } - return true, ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) + return statusSnapshot{true, ccmPlanWeight(c.state.accountType, c.state.rateLimitTier)} } func (c *defaultCredential) getAccessToken() (string, error) { @@ -206,15 +211,14 @@ func (c *defaultCredential) getAccessToken() (string, error) { if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { c.credentials = latestCredentials c.stateAccess.Lock() - wasAvailable, oldWeight := c.statusAggregateStateLocked() + before := c.statusSnapshotLocked() c.state.unavailable = false c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" c.state.accountType = latestCredentials.SubscriptionType c.state.rateLimitTier = latestCredentials.RateLimitTier c.checkTransitionLocked() - isAvailable, newWeight := c.statusAggregateStateLocked() - shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight + shouldEmit := before != c.statusSnapshotLocked() c.stateAccess.Unlock() if shouldEmit { c.emitStatusUpdate() @@ -227,15 +231,14 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.credentials = newCredentials c.stateAccess.Lock() - wasAvailable, oldWeight := c.statusAggregateStateLocked() + before := c.statusSnapshotLocked() c.state.unavailable = false c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" c.state.accountType = newCredentials.SubscriptionType c.state.rateLimitTier = newCredentials.RateLimitTier c.checkTransitionLocked() - isAvailable, newWeight := c.statusAggregateStateLocked() - shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight + shouldEmit := before != c.statusSnapshotLocked() c.stateAccess.Unlock() if shouldEmit { c.emitStatusUpdate() @@ -304,12 +307,13 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { } c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } + shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } - if hadData { + if shouldEmit { c.emitStatusUpdate() } } @@ -673,7 +677,7 @@ func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.C rateLimitTier := profileResponse.Organization.RateLimitTier c.stateAccess.Lock() - wasAvailable, oldWeight := c.statusAggregateStateLocked() + before := c.statusSnapshotLocked() if accountType != "" && c.state.accountType == "" { c.state.accountType = accountType } @@ -681,8 +685,7 @@ func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.C c.state.rateLimitTier = rateLimitTier } resolvedAccountType := c.state.accountType - isAvailable, newWeight := c.statusAggregateStateLocked() - shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight + shouldEmit := before != c.statusSnapshotLocked() c.stateAccess.Unlock() if shouldEmit { c.emitStatusUpdate() diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index c66a54e8b3..fbd82e60ac 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -436,6 +436,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization + oldPlanWeight := c.state.remotePlanWeight hadData := false if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { @@ -478,12 +479,13 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } + shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly || c.state.remotePlanWeight != oldPlanWeight) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } - if hadData { + if shouldEmit { c.emitStatusUpdate() } } @@ -679,11 +681,7 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr previousLastUpdated := c.lastUpdatedTime() var firstFrameUpdatedAt time.Time for { - var statusResponse struct { - FiveHourUtilization float64 `json:"five_hour_utilization"` - WeeklyUtilization float64 `json:"weekly_utilization"` - PlanWeight float64 `json:"plan_weight"` - } + var statusResponse statusPayload err = decoder.Decode(&statusResponse) if err != nil { result.duration = time.Since(startTime) diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index c0ada93abe..4a65314712 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -111,14 +111,13 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.access.Unlock() c.stateAccess.Lock() - wasAvailable, oldWeight := c.statusAggregateStateLocked() + before := c.statusSnapshotLocked() c.state.unavailable = false c.state.lastCredentialLoadError = "" c.state.accountType = credentials.SubscriptionType c.state.rateLimitTier = credentials.RateLimitTier c.checkTransitionLocked() - isAvailable, newWeight := c.statusAggregateStateLocked() - shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight + shouldEmit := before != c.statusSnapshotLocked() c.stateAccess.Unlock() if shouldEmit { c.emitStatusUpdate() @@ -134,14 +133,13 @@ func (c *defaultCredential) markCredentialsUnavailable(err error) error { c.access.Unlock() c.stateAccess.Lock() - wasAvailable, oldWeight := c.statusAggregateStateLocked() + before := c.statusSnapshotLocked() c.state.unavailable = true c.state.lastCredentialLoadError = err.Error() c.state.accountType = "" c.state.rateLimitTier = "" shouldInterrupt := c.checkTransitionLocked() - isAvailable, newWeight := c.statusAggregateStateLocked() - shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight + shouldEmit := before != c.statusSnapshotLocked() c.stateAccess.Unlock() if shouldInterrupt && hadCredentials { diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index 82e3d5a0d4..b5771d4b84 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -10,6 +10,12 @@ import ( "github.com/sagernet/sing-box/option" ) +type statusPayload struct { + FiveHourUtilization float64 `json:"five_hour_utilization"` + WeeklyUtilization float64 `json:"weekly_utilization"` + PlanWeight float64 `json:"plan_weight"` +} + func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") @@ -66,10 +72,10 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]float64{ - "five_hour_utilization": avgFiveHour, - "weekly_utilization": avgWeekly, - "plan_weight": totalWeight, + json.NewEncoder(w).Encode(statusPayload{ + FiveHourUtilization: avgFiveHour, + WeeklyUtilization: avgWeekly, + PlanWeight: totalWeight, }) } @@ -95,10 +101,10 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig) buf := &bytes.Buffer{} - json.NewEncoder(buf).Encode(map[string]float64{ - "five_hour_utilization": lastFiveHour, - "weekly_utilization": lastWeekly, - "plan_weight": lastWeight, + json.NewEncoder(buf).Encode(statusPayload{ + FiveHourUtilization: lastFiveHour, + WeeklyUtilization: lastWeekly, + PlanWeight: lastWeight, }) _, writeErr := w.Write(buf.Bytes()) if writeErr != nil { @@ -129,10 +135,10 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro lastWeekly = weekly lastWeight = weight buf.Reset() - json.NewEncoder(buf).Encode(map[string]float64{ - "five_hour_utilization": fiveHour, - "weekly_utilization": weekly, - "plan_weight": weight, + json.NewEncoder(buf).Encode(statusPayload{ + FiveHourUtilization: fiveHour, + WeeklyUtilization: weekly, + PlanWeight: weight, }) _, writeErr = w.Write(buf.Bytes()) if writeErr != nil { diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index 0bfeebdc15..89daf56e79 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -347,12 +347,13 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { } c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } + shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } - if hadData { + if shouldEmit { c.emitStatusUpdate() } } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index a4835df7aa..f4a3889f8b 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -460,6 +460,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization + oldPlanWeight := c.state.remotePlanWeight hadData := false activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) @@ -517,12 +518,13 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } + shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly || c.state.remotePlanWeight != oldPlanWeight) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } - if hadData { + if shouldEmit { c.emitStatusUpdate() } } @@ -721,11 +723,7 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr previousLastUpdated := c.lastUpdatedTime() var firstFrameUpdatedAt time.Time for { - var statusResponse struct { - FiveHourUtilization float64 `json:"five_hour_utilization"` - WeeklyUtilization float64 `json:"weekly_utilization"` - PlanWeight float64 `json:"plan_weight"` - } + var statusResponse statusPayload err = decoder.Decode(&statusResponse) if err != nil { result.duration = time.Since(startTime) diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index 773862979b..bc167a65ca 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -10,6 +10,12 @@ import ( "github.com/sagernet/sing-box/option" ) +type statusPayload struct { + FiveHourUtilization float64 `json:"five_hour_utilization"` + WeeklyUtilization float64 `json:"weekly_utilization"` + PlanWeight float64 `json:"plan_weight"` +} + func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") @@ -66,10 +72,10 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]float64{ - "five_hour_utilization": avgFiveHour, - "weekly_utilization": avgWeekly, - "plan_weight": totalWeight, + json.NewEncoder(w).Encode(statusPayload{ + FiveHourUtilization: avgFiveHour, + WeeklyUtilization: avgWeekly, + PlanWeight: totalWeight, }) } @@ -95,10 +101,10 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig) buf := &bytes.Buffer{} - json.NewEncoder(buf).Encode(map[string]float64{ - "five_hour_utilization": lastFiveHour, - "weekly_utilization": lastWeekly, - "plan_weight": lastWeight, + json.NewEncoder(buf).Encode(statusPayload{ + FiveHourUtilization: lastFiveHour, + WeeklyUtilization: lastWeekly, + PlanWeight: lastWeight, }) _, writeErr := w.Write(buf.Bytes()) if writeErr != nil { @@ -129,10 +135,10 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro lastWeekly = weekly lastWeight = weight buf.Reset() - json.NewEncoder(buf).Encode(map[string]float64{ - "five_hour_utilization": fiveHour, - "weekly_utilization": weekly, - "plan_weight": weight, + json.NewEncoder(buf).Encode(statusPayload{ + FiveHourUtilization: fiveHour, + WeeklyUtilization: weekly, + PlanWeight: weight, }) _, writeErr = w.Write(buf.Bytes()) if writeErr != nil { From cf2d677043767e19736b9cda96569bbdae59f7f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 16:32:03 +0800 Subject: [PATCH 58/96] ocm: emit status updates for plan-weight-only changes --- service/ocm/credential_external.go | 4 +++- service/ocm/credential_status_test.go | 33 +++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index f4a3889f8b..dd13aca60d 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -518,7 +518,9 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } - shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly || c.state.remotePlanWeight != oldPlanWeight) + utilizationChanged := c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly + planWeightChanged := c.state.remotePlanWeight != oldPlanWeight + shouldEmit := (hadData && utilizationChanged) || planWeightChanged shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { diff --git a/service/ocm/credential_status_test.go b/service/ocm/credential_status_test.go index d45fdebf03..955338fce5 100644 --- a/service/ocm/credential_status_test.go +++ b/service/ocm/credential_status_test.go @@ -195,6 +195,39 @@ func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *test } } +func TestExternalCredentialPlanWeightOnlyRateLimitsEventEmitsStatus(t *testing.T) { + subscriber := observable.NewSubscriber[struct{}](8) + subscription, _ := subscriber.Subscription() + credential := &externalCredential{ + tag: "test", + logger: newTestLogger(), + statusSubscriber: subscriber, + } + credential.stateAccess.Lock() + credential.state.remotePlanWeight = 2 + oldTime := time.Unix(123, 0) + credential.state.lastUpdated = oldTime + credential.stateAccess.Unlock() + + (&Service{}).handleWebSocketRateLimitsEvent([]byte(`{"plan_weight":3}`), credential) + + if weight := credential.planWeight(); weight != 3 { + t.Fatalf("expected plan weight 3, got %v", weight) + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event, got %d", count) + } + if !credential.lastUpdatedTime().Equal(oldTime) { + t.Fatalf("expected lastUpdated to stay %v, got %v", oldTime, credential.lastUpdatedTime()) + } + + (&Service{}).handleWebSocketRateLimitsEvent([]byte(`{"plan_weight":3}`), credential) + + if count := drainStatusEvents(subscription); count != 0 { + t.Fatalf("expected no status event for unchanged plan weight, got %d", count) + } +} + func TestDefaultCredentialAvailabilityChangesEmitStatus(t *testing.T) { credentialPath := filepath.Join(t.TempDir(), "auth.json") err := os.WriteFile(credentialPath, []byte("{\"OPENAI_API_KEY\":\"sk-test\"}\n"), 0o600) From 7d15d9d282ef533db1ccc2045198320bc2dd9c21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 16:46:54 +0800 Subject: [PATCH 59/96] ccm: emit status updates for plan-weight-only changes --- service/ccm/credential_external.go | 4 ++- service/ccm/credential_status_test.go | 35 +++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index fbd82e60ac..3d39a88ca2 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -479,7 +479,9 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } - shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly || c.state.remotePlanWeight != oldPlanWeight) + utilizationChanged := c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly + planWeightChanged := c.state.remotePlanWeight != oldPlanWeight + shouldEmit := (hadData && utilizationChanged) || planWeightChanged shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { diff --git a/service/ccm/credential_status_test.go b/service/ccm/credential_status_test.go index e675f73cc4..f92b27e85c 100644 --- a/service/ccm/credential_status_test.go +++ b/service/ccm/credential_status_test.go @@ -195,6 +195,41 @@ func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *test } } +func TestExternalCredentialPlanWeightOnlyHeaderEmitsStatus(t *testing.T) { + subscriber := observable.NewSubscriber[struct{}](8) + subscription, _ := subscriber.Subscription() + credential := &externalCredential{ + tag: "test", + logger: newTestLogger(), + statusSubscriber: subscriber, + } + credential.stateAccess.Lock() + credential.state.remotePlanWeight = 2 + oldTime := time.Unix(123, 0) + credential.state.lastUpdated = oldTime + credential.stateAccess.Unlock() + + headers := make(http.Header) + headers.Set("X-CCM-Plan-Weight", "3") + credential.updateStateFromHeaders(headers) + + if weight := credential.planWeight(); weight != 3 { + t.Fatalf("expected plan weight 3, got %v", weight) + } + if count := drainStatusEvents(subscription); count != 1 { + t.Fatalf("expected 1 status event, got %d", count) + } + if !credential.lastUpdatedTime().Equal(oldTime) { + t.Fatalf("expected lastUpdated to stay %v, got %v", oldTime, credential.lastUpdatedTime()) + } + + credential.updateStateFromHeaders(headers) + + if count := drainStatusEvents(subscription); count != 0 { + t.Fatalf("expected no status event for unchanged plan weight, got %d", count) + } +} + func TestDefaultCredentialStatusChangesEmitStatus(t *testing.T) { credentialPath := filepath.Join(t.TempDir(), "credentials.json") err := os.WriteFile(credentialPath, []byte("{\"claudeAiOauth\":{\"accessToken\":\"token\",\"refreshToken\":\"\",\"expiresAt\":0,\"subscriptionType\":\"max\"}}\n"), 0o600) From 0a054b9aa4fd3c5f60e47c3c139295166cb66227 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 18:13:54 +0800 Subject: [PATCH 60/96] ccm,ocm: propagate reset times, rewrite headers for all users, add WS status push - Add fiveHourReset/weeklyReset to statusPayload and aggregatedStatus with weight-averaged reset time aggregation across credential pools - Rewrite response headers (utilization + reset times) for all users, not just external credential users - Rewrite WebSocket rate_limits events for all users with aggregated values - Add proactive WebSocket status push: synthetic codex.rate_limits events sent on connection start and on status changes via statusObserver - Remove one-shot stream forward compatibility (statusStreamHeader, restoreLastUpdatedIfUnchanged, oneShot detection) --- service/ccm/credential.go | 1 + service/ccm/credential_default.go | 6 ++ service/ccm/credential_external.go | 71 ++++++-------- service/ccm/credential_status_test.go | 54 +---------- service/ccm/service_handler.go | 5 +- service/ccm/service_status.go | 132 ++++++++++++++++++------- service/ocm/credential.go | 1 + service/ocm/credential_default.go | 6 ++ service/ocm/credential_external.go | 71 ++++++-------- service/ocm/credential_status_test.go | 54 +---------- service/ocm/service_handler.go | 5 +- service/ocm/service_status.go | 133 +++++++++++++++++++------- service/ocm/service_websocket.go | 129 +++++++++++++++++++++---- 13 files changed, 381 insertions(+), 287 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 9defed434c..6f41ba1283 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -105,6 +105,7 @@ type Credential interface { fiveHourCap() float64 weeklyCap() float64 planWeight() float64 + fiveHourResetTime() time.Time weeklyResetTime() time.Time markRateLimited(resetAt time.Time) earliestReset() time.Time diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index f23ba3f652..bf88fc836c 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -421,6 +421,12 @@ func (c *defaultCredential) planWeight() float64 { return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) } +func (c *defaultCredential) fiveHourResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourReset +} + func (c *defaultCredential) weeklyResetTime() time.Time { c.stateAccess.RLock() defer c.stateAccess.RUnlock() diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 3d39a88ca2..cdbbc48449 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -27,10 +27,7 @@ import ( "github.com/hashicorp/yamux" ) -const ( - reverseProxyBaseURL = "http://reverse-proxy" - statusStreamHeader = "X-CCM-Status-Stream" -) +const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { tag string @@ -70,7 +67,6 @@ type externalCredential struct { type statusStreamResult struct { duration time.Duration frames int - oneShot bool } func externalCredentialURLPort(parsedURL *url.URL) uint16 { @@ -325,6 +321,12 @@ func (c *externalCredential) planWeight() float64 { return 10 } +func (c *externalCredential) fiveHourResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourReset +} + func (c *externalCredential) weeklyResetTime() time.Time { c.stateAccess.RLock() defer c.stateAccess.RUnlock() @@ -592,7 +594,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) { var statusResponse struct { FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` } err = json.NewDecoder(response.Body).Decode(&statusResponse) @@ -612,6 +616,12 @@ func (c *externalCredential) pollUsage(ctx context.Context) { if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) + } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } @@ -645,13 +655,8 @@ func (c *externalCredential) statusStreamLoop() { return } var backoff time.Duration - var oneShot bool - consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures) - if oneShot { - c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff) - } else { - c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) - } + consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures) + c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) timer := time.NewTimer(backoff) select { case <-timer.C: @@ -679,18 +684,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr } decoder := json.NewDecoder(response.Body) - isStatusStream := response.Header.Get(statusStreamHeader) == "true" - previousLastUpdated := c.lastUpdatedTime() - var firstFrameUpdatedAt time.Time for { var statusResponse statusPayload err = decoder.Decode(&statusResponse) if err != nil { result.duration = time.Since(startTime) - if result.frames == 1 && err == io.EOF && !isStatusStream { - result.oneShot = true - c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated) - } return result, err } @@ -701,6 +699,12 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) + } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } @@ -710,23 +714,17 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr c.interruptConnections() } result.frames++ - updatedAt := c.markUsageStreamUpdated() - if result.frames == 1 { - firstFrameUpdatedAt = updatedAt - } + c.markUsageStreamUpdated() c.emitStatusUpdate() } } -func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) { - if result.oneShot { - return 0, c.pollInterval, true - } +func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) { if result.duration >= connectorBackoffResetThreshold { consecutiveFailures = 0 } consecutiveFailures++ - return consecutiveFailures, connectorBackoff(consecutiveFailures), false + return consecutiveFailures, connectorBackoff(consecutiveFailures) } func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) { @@ -767,23 +765,10 @@ func (c *externalCredential) lastUpdatedTime() time.Time { return c.state.lastUpdated } -func (c *externalCredential) markUsageStreamUpdated() time.Time { +func (c *externalCredential) markUsageStreamUpdated() { c.stateAccess.Lock() defer c.stateAccess.Unlock() - now := time.Now() - c.state.lastUpdated = now - return now -} - -func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) { - if expectedCurrent.IsZero() { - return - } - c.stateAccess.Lock() - defer c.stateAccess.Unlock() - if c.state.lastUpdated.Equal(expectedCurrent) { - c.state.lastUpdated = previous - } + c.state.lastUpdated = time.Now() } func (c *externalCredential) markUsagePollAttempted() { diff --git a/service/ccm/credential_status_test.go b/service/ccm/credential_status_test.go index f92b27e85c..9353f1d836 100644 --- a/service/ccm/credential_status_test.go +++ b/service/ccm/credential_status_test.go @@ -84,49 +84,8 @@ func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) { return clientSession, serverSession } -func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) { - credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil) - oldTime := time.Unix(123, 0) - credential.stateAccess.Lock() - credential.state.lastUpdated = oldTime - credential.stateAccess.Unlock() - - result, err := credential.connectStatusStream(context.Background()) - if err != io.EOF { - t.Fatalf("expected EOF, got %v", err) - } - if !result.oneShot { - t.Fatal("expected one-shot result") - } - if result.frames != 1 { - t.Fatalf("expected 1 frame, got %d", result.frames) - } - if !credential.lastUpdatedTime().Equal(oldTime) { - t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime()) - } - if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 { - t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event, got %d", count) - } - - failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) - if !oneShot { - t.Fatal("expected one-shot backoff branch") - } - if failures != 0 { - t.Fatalf("expected failures reset, got %d", failures) - } - if backoff != credential.pollInterval { - t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff) - } -} - func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) { - headers := make(http.Header) - headers.Set(statusStreamHeader, "true") - credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers) + credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil) oldTime := time.Unix(123, 0) credential.stateAccess.Lock() credential.state.lastUpdated = oldTime @@ -136,9 +95,6 @@ func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *tes if err != io.EOF { t.Fatalf("expected EOF, got %v", err) } - if result.oneShot { - t.Fatal("did not expect one-shot result") - } if result.frames != 1 { t.Fatalf("expected 1 frame, got %d", result.frames) } @@ -152,10 +108,7 @@ func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *tes t.Fatalf("expected 1 status event, got %d", count) } - failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) - if oneShot { - t.Fatal("did not expect one-shot backoff branch") - } + failures, backoff := credential.nextStatusStreamBackoff(result, 3) if failures != 4 { t.Fatalf("expected failures incremented to 4, got %d", failures) } @@ -178,9 +131,6 @@ func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *test if err != io.EOF { t.Fatalf("expected EOF, got %v", err) } - if result.oneShot { - t.Fatal("did not expect one-shot result") - } if result.frames != 2 { t.Fatalf("expected 2 frames, got %d", result.frames) } diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 33d6317de7..1ccbd83ff4 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -311,10 +311,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Rewrite response headers for external users - if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(response.Header, provider, userConfig) - } + s.rewriteResponseHeaders(response.Header, provider, userConfig) for key, values := range response.Header { if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index b5771d4b84..e3aa43e307 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -6,16 +6,52 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/sagernet/sing-box/option" ) type statusPayload struct { FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` } +type aggregatedStatus struct { + fiveHourUtilization float64 + weeklyUtilization float64 + totalWeight float64 + fiveHourReset time.Time + weeklyReset time.Time +} + +func resetToEpoch(t time.Time) int64 { + if t.IsZero() { + return 0 + } + return t.Unix() +} + +func (s aggregatedStatus) equal(other aggregatedStatus) bool { + return s.fiveHourUtilization == other.fiveHourUtilization && + s.weeklyUtilization == other.weeklyUtilization && + s.totalWeight == other.totalWeight && + resetToEpoch(s.fiveHourReset) == resetToEpoch(other.fiveHourReset) && + resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset) +} + +func (s aggregatedStatus) toPayload() statusPayload { + return statusPayload{ + FiveHourUtilization: s.fiveHourUtilization, + FiveHourReset: resetToEpoch(s.fiveHourReset), + WeeklyUtilization: s.weeklyUtilization, + WeeklyReset: resetToEpoch(s.weeklyReset), + PlanWeight: s.totalWeight, + } +} + func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") @@ -68,15 +104,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + status := s.computeAggregatedUtilization(provider, userConfig) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(statusPayload{ - FiveHourUtilization: avgFiveHour, - WeeklyUtilization: avgWeekly, - PlanWeight: totalWeight, - }) + json.NewEncoder(w).Encode(status.toPayload()) } func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.CCMUser) { @@ -96,16 +128,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro provider.pollIfStale(r.Context()) w.Header().Set("Content-Type", "application/json") - w.Header().Set(statusStreamHeader, "true") w.WriteHeader(http.StatusOK) - lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig) + last := s.computeAggregatedUtilization(provider, userConfig) buf := &bytes.Buffer{} - json.NewEncoder(buf).Encode(statusPayload{ - FiveHourUtilization: lastFiveHour, - WeeklyUtilization: lastWeekly, - PlanWeight: lastWeight, - }) + json.NewEncoder(buf).Encode(last.toPayload()) _, writeErr := w.Write(buf.Bytes()) if writeErr != nil { return @@ -127,19 +154,13 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } } drained: - fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig) - if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight { + current := s.computeAggregatedUtilization(provider, userConfig) + if current.equal(last) { continue } - lastFiveHour = fiveHour - lastWeekly = weekly - lastWeight = weight + last = current buf.Reset() - json.NewEncoder(buf).Encode(statusPayload{ - FiveHourUtilization: fiveHour, - WeeklyUtilization: weekly, - PlanWeight: weight, - }) + json.NewEncoder(buf).Encode(current.toPayload()) _, writeErr = w.Write(buf.Bytes()) if writeErr != nil { return @@ -149,8 +170,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } } -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) { +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) aggregatedStatus { var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 + now := time.Now() + var totalWeightedHoursUntil5hReset, total5hResetWeight float64 + var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64 for _, credential := range provider.allCredentials() { if !credential.isAvailable() { continue @@ -173,21 +197,59 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user totalWeightedRemaining5h += remaining5h * weight totalWeightedRemainingWeekly += remainingWeekly * weight totalWeight += weight + + fiveHourReset := credential.fiveHourResetTime() + if !fiveHourReset.IsZero() { + hours := fiveHourReset.Sub(now).Hours() + if hours < 0 { + hours = 0 + } + totalWeightedHoursUntil5hReset += hours * weight + total5hResetWeight += weight + } + weeklyReset := credential.weeklyResetTime() + if !weeklyReset.IsZero() { + hours := weeklyReset.Sub(now).Hours() + if hours < 0 { + hours = 0 + } + totalWeightedHoursUntilWeeklyReset += hours * weight + totalWeeklyResetWeight += weight + } } if totalWeight == 0 { - return 100, 100, 0 + return aggregatedStatus{ + fiveHourUtilization: 100, + weeklyUtilization: 100, + } + } + result := aggregatedStatus{ + fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight, + weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight: totalWeight, } - return 100 - totalWeightedRemaining5h/totalWeight, - 100 - totalWeightedRemainingWeekly/totalWeight, - totalWeight + if total5hResetWeight > 0 { + avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight + result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour))) + } + if totalWeeklyResetWeight > 0 { + avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight + result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour))) + } + return result } -func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) { - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - - headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64)) - headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) - if totalWeight > 0 { - headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) +func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) { + status := s.computeAggregatedUtilization(provider, userConfig) + headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(status.fiveHourUtilization/100, 'f', 6, 64)) + headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(status.weeklyUtilization/100, 'f', 6, 64)) + if !status.fiveHourReset.IsZero() { + headers.Set("anthropic-ratelimit-unified-5h-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) + } + if !status.weeklyReset.IsZero() { + headers.Set("anthropic-ratelimit-unified-7d-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10)) + } + if status.totalWeight > 0 { + headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) } } diff --git a/service/ocm/credential.go b/service/ocm/credential.go index c777ea5c9b..1478f5f193 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -107,6 +107,7 @@ type Credential interface { weeklyCap() float64 planWeight() float64 weeklyResetTime() time.Time + fiveHourResetTime() time.Time markRateLimited(resetAt time.Time) earliestReset() time.Time unavailableError() error diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index 89daf56e79..1e0af9847a 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -476,6 +476,12 @@ func (c *defaultCredential) weeklyResetTime() time.Time { return c.state.weeklyReset } +func (c *defaultCredential) fiveHourResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourReset +} + func (c *defaultCredential) isAvailable() bool { c.retryCredentialReloadIfNeeded() diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index dd13aca60d..b796ff0bb4 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -28,10 +28,7 @@ import ( "github.com/hashicorp/yamux" ) -const ( - reverseProxyBaseURL = "http://reverse-proxy" - statusStreamHeader = "X-OCM-Status-Stream" -) +const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { tag string @@ -77,7 +74,6 @@ type reverseSessionDialer struct { type statusStreamResult struct { duration time.Duration frames int - oneShot bool } func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { @@ -355,6 +351,12 @@ func (c *externalCredential) weeklyResetTime() time.Time { return c.state.weeklyReset } +func (c *externalCredential) fiveHourResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourReset +} + func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) c.stateAccess.Lock() @@ -634,7 +636,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) { var statusResponse struct { FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` } err = json.NewDecoder(response.Body).Decode(&statusResponse) @@ -651,6 +655,12 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.state.consecutivePollFailures = 0 c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) + } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } @@ -687,13 +697,8 @@ func (c *externalCredential) statusStreamLoop() { return } var backoff time.Duration - var oneShot bool - consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures) - if oneShot { - c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff) - } else { - c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) - } + consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures) + c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) timer := time.NewTimer(backoff) select { case <-timer.C: @@ -721,18 +726,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr } decoder := json.NewDecoder(response.Body) - isStatusStream := response.Header.Get(statusStreamHeader) == "true" - previousLastUpdated := c.lastUpdatedTime() - var firstFrameUpdatedAt time.Time for { var statusResponse statusPayload err = decoder.Decode(&statusResponse) if err != nil { result.duration = time.Since(startTime) - if result.frames == 1 && err == io.EOF && !isStatusStream { - result.oneShot = true - c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated) - } return result, err } @@ -740,6 +738,12 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr c.state.consecutivePollFailures = 0 c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) + } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } @@ -752,23 +756,17 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr c.interruptConnections() } result.frames++ - updatedAt := c.markUsageStreamUpdated() - if result.frames == 1 { - firstFrameUpdatedAt = updatedAt - } + c.markUsageStreamUpdated() c.emitStatusUpdate() } } -func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) { - if result.oneShot { - return 0, c.pollInterval, true - } +func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) { if result.duration >= connectorBackoffResetThreshold { consecutiveFailures = 0 } consecutiveFailures++ - return consecutiveFailures, connectorBackoff(consecutiveFailures), false + return consecutiveFailures, connectorBackoff(consecutiveFailures) } func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) { @@ -809,23 +807,10 @@ func (c *externalCredential) lastUpdatedTime() time.Time { return c.state.lastUpdated } -func (c *externalCredential) markUsageStreamUpdated() time.Time { +func (c *externalCredential) markUsageStreamUpdated() { c.stateAccess.Lock() defer c.stateAccess.Unlock() - now := time.Now() - c.state.lastUpdated = now - return now -} - -func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) { - if expectedCurrent.IsZero() { - return - } - c.stateAccess.Lock() - defer c.stateAccess.Unlock() - if c.state.lastUpdated.Equal(expectedCurrent) { - c.state.lastUpdated = previous - } + c.state.lastUpdated = time.Now() } func (c *externalCredential) markUsagePollAttempted() { diff --git a/service/ocm/credential_status_test.go b/service/ocm/credential_status_test.go index 955338fce5..2865a23808 100644 --- a/service/ocm/credential_status_test.go +++ b/service/ocm/credential_status_test.go @@ -84,49 +84,8 @@ func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) { return clientSession, serverSession } -func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) { - credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil) - oldTime := time.Unix(123, 0) - credential.stateAccess.Lock() - credential.state.lastUpdated = oldTime - credential.stateAccess.Unlock() - - result, err := credential.connectStatusStream(context.Background()) - if err != io.EOF { - t.Fatalf("expected EOF, got %v", err) - } - if !result.oneShot { - t.Fatal("expected one-shot result") - } - if result.frames != 1 { - t.Fatalf("expected 1 frame, got %d", result.frames) - } - if !credential.lastUpdatedTime().Equal(oldTime) { - t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime()) - } - if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 { - t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event, got %d", count) - } - - failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) - if !oneShot { - t.Fatal("expected one-shot backoff branch") - } - if failures != 0 { - t.Fatalf("expected failures reset, got %d", failures) - } - if backoff != credential.pollInterval { - t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff) - } -} - func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) { - headers := make(http.Header) - headers.Set(statusStreamHeader, "true") - credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers) + credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil) oldTime := time.Unix(123, 0) credential.stateAccess.Lock() credential.state.lastUpdated = oldTime @@ -136,9 +95,6 @@ func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *tes if err != io.EOF { t.Fatalf("expected EOF, got %v", err) } - if result.oneShot { - t.Fatal("did not expect one-shot result") - } if result.frames != 1 { t.Fatalf("expected 1 frame, got %d", result.frames) } @@ -152,10 +108,7 @@ func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *tes t.Fatalf("expected 1 status event, got %d", count) } - failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) - if oneShot { - t.Fatal("did not expect one-shot backoff branch") - } + failures, backoff := credential.nextStatusStreamBackoff(result, 3) if failures != 4 { t.Fatalf("expected failures incremented to 4, got %d", failures) } @@ -178,9 +131,6 @@ func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *test if err != io.EOF { t.Fatalf("expected EOF, got %v", err) } - if result.oneShot { - t.Fatal("did not expect one-shot result") - } if result.frames != 2 { t.Fatalf("expected 2 frames, got %d", result.frames) } diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index 0a7698d005..52e35f39b9 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -291,10 +291,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Rewrite response headers for external users - if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(response.Header, provider, userConfig) - } + s.rewriteResponseHeaders(response.Header, provider, userConfig) for key, values := range response.Header { if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index bc167a65ca..f66b469c2a 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -6,16 +6,52 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/sagernet/sing-box/option" ) type statusPayload struct { FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` } +type aggregatedStatus struct { + fiveHourUtilization float64 + weeklyUtilization float64 + totalWeight float64 + fiveHourReset time.Time + weeklyReset time.Time +} + +func resetToEpoch(t time.Time) int64 { + if t.IsZero() { + return 0 + } + return t.Unix() +} + +func (s aggregatedStatus) equal(other aggregatedStatus) bool { + return s.fiveHourUtilization == other.fiveHourUtilization && + s.weeklyUtilization == other.weeklyUtilization && + s.totalWeight == other.totalWeight && + resetToEpoch(s.fiveHourReset) == resetToEpoch(other.fiveHourReset) && + resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset) +} + +func (s aggregatedStatus) toPayload() statusPayload { + return statusPayload{ + FiveHourUtilization: s.fiveHourUtilization, + FiveHourReset: resetToEpoch(s.fiveHourReset), + WeeklyUtilization: s.weeklyUtilization, + WeeklyReset: resetToEpoch(s.weeklyReset), + PlanWeight: s.totalWeight, + } +} + func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") @@ -68,15 +104,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + status := s.computeAggregatedUtilization(provider, userConfig) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(statusPayload{ - FiveHourUtilization: avgFiveHour, - WeeklyUtilization: avgWeekly, - PlanWeight: totalWeight, - }) + json.NewEncoder(w).Encode(status.toPayload()) } func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.OCMUser) { @@ -96,16 +128,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro provider.pollIfStale(r.Context()) w.Header().Set("Content-Type", "application/json") - w.Header().Set(statusStreamHeader, "true") w.WriteHeader(http.StatusOK) - lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig) + last := s.computeAggregatedUtilization(provider, userConfig) buf := &bytes.Buffer{} - json.NewEncoder(buf).Encode(statusPayload{ - FiveHourUtilization: lastFiveHour, - WeeklyUtilization: lastWeekly, - PlanWeight: lastWeight, - }) + json.NewEncoder(buf).Encode(last.toPayload()) _, writeErr := w.Write(buf.Bytes()) if writeErr != nil { return @@ -127,19 +154,13 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } } drained: - fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig) - if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight { + current := s.computeAggregatedUtilization(provider, userConfig) + if current.equal(last) { continue } - lastFiveHour = fiveHour - lastWeekly = weekly - lastWeight = weight + last = current buf.Reset() - json.NewEncoder(buf).Encode(statusPayload{ - FiveHourUtilization: fiveHour, - WeeklyUtilization: weekly, - PlanWeight: weight, - }) + json.NewEncoder(buf).Encode(current.toPayload()) _, writeErr = w.Write(buf.Bytes()) if writeErr != nil { return @@ -149,8 +170,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } } -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) { +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) aggregatedStatus { var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 + now := time.Now() + var totalWeightedHoursUntil5hReset, total5hResetWeight float64 + var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64 for _, credential := range provider.allCredentials() { if !credential.isAvailable() { continue @@ -173,26 +197,63 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user totalWeightedRemaining5h += remaining5h * weight totalWeightedRemainingWeekly += remainingWeekly * weight totalWeight += weight + + fiveHourReset := credential.fiveHourResetTime() + if !fiveHourReset.IsZero() { + hours := fiveHourReset.Sub(now).Hours() + if hours < 0 { + hours = 0 + } + totalWeightedHoursUntil5hReset += hours * weight + total5hResetWeight += weight + } + weeklyReset := credential.weeklyResetTime() + if !weeklyReset.IsZero() { + hours := weeklyReset.Sub(now).Hours() + if hours < 0 { + hours = 0 + } + totalWeightedHoursUntilWeeklyReset += hours * weight + totalWeeklyResetWeight += weight + } } if totalWeight == 0 { - return 100, 100, 0 + return aggregatedStatus{ + fiveHourUtilization: 100, + weeklyUtilization: 100, + } + } + result := aggregatedStatus{ + fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight, + weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight: totalWeight, + } + if total5hResetWeight > 0 { + avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight + result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour))) } - return 100 - totalWeightedRemaining5h/totalWeight, - 100 - totalWeightedRemainingWeekly/totalWeight, - totalWeight + if totalWeeklyResetWeight > 0 { + avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight + result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour))) + } + return result } -func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) { - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - +func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) { + status := s.computeAggregatedUtilization(provider, userConfig) activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) if activeLimitIdentifier == "" { activeLimitIdentifier = "codex" } - - headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64)) - headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) - if totalWeight > 0 { - headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) + headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(status.fiveHourUtilization, 'f', 2, 64)) + headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64)) + if !status.fiveHourReset.IsZero() { + headers.Set("x-"+activeLimitIdentifier+"-primary-reset-at", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) + } + if !status.weeklyReset.IsZero() { + headers.Set("x-"+activeLimitIdentifier+"-secondary-reset-at", strconv.FormatInt(status.weeklyReset.Unix(), 10)) + } + if status.totalWeight > 0 { + headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) } } diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 066692ace7..bb9640d545 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -252,9 +252,7 @@ func (s *Service) handleWebSocket( clientResponseHeaders[key] = append([]string(nil), values...) } } - if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, provider, userConfig) - } + s.rewriteResponseHeaders(clientResponseHeaders, provider, userConfig) clientUpgrader := ws.HTTPUpgrader{ Header: clientResponseHeaders, @@ -292,10 +290,16 @@ func (s *Service) handleWebSocket( upstreamReadWriter = upstreamConn } + rateLimitIdentifier := normalizeRateLimitIdentifier(upstreamResponseHeaders.Get("x-codex-active-limit")) + if rateLimitIdentifier == "" { + rateLimitIdentifier = "codex" + } + + var clientWriteAccess sync.Mutex modelChannel := make(chan string, 1) var waitGroup sync.WaitGroup - waitGroup.Add(2) + waitGroup.Add(3) go func() { defer waitGroup.Done() defer session.Close() @@ -304,7 +308,12 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) + s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) + }() + go func() { + defer waitGroup.Done() + defer session.Close() + s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, provider, userConfig, rateLimitIdentifier) }() waitGroup.Wait() } @@ -363,7 +372,7 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn } } -func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential Credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { usageTracker := selectedCredential.usageTrackerOrNil() var requestModel string for { @@ -384,11 +393,9 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe switch event.Type { case "codex.rate_limits": s.handleWebSocketRateLimitsEvent(data, selectedCredential) - if userConfig != nil && userConfig.ExternalCredential != "" { - rewritten, rewriteErr := s.rewriteWebSocketRateLimitsForExternalUser(data, provider, userConfig) - if rewriteErr == nil { - data = rewritten - } + rewritten, rewriteErr := s.rewriteWebSocketRateLimits(data, provider, userConfig) + if rewriteErr == nil { + data = rewritten } case "error": if event.StatusCode == http.StatusTooManyRequests { @@ -407,7 +414,9 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe } } + clientWriteAccess.Lock() err = wsutil.WriteServerMessage(clientConn, opCode, data) + clientWriteAccess.Unlock() if err != nil { if !E.IsClosedOrCanceled(err) { s.logger.DebugContext(ctx, "write client websocket: ", err) @@ -483,7 +492,7 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia selectedCredential.markRateLimited(resetAt) } -func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) { +func (s *Service) rewriteWebSocketRateLimits(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) { var event map[string]json.RawMessage err := json.Unmarshal(data, &event) if err != nil { @@ -501,13 +510,13 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide return nil, err } - averageFiveHour, averageWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + status := s.computeAggregatedUtilization(provider, userConfig) - if totalWeight > 0 { - event["plan_weight"], _ = json.Marshal(totalWeight) + if status.totalWeight > 0 { + event["plan_weight"], _ = json.Marshal(status.totalWeight) } - primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], averageFiveHour) + primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], status.fiveHourUtilization, resetToEpoch(status.fiveHourReset)) if err != nil { return nil, err } @@ -515,7 +524,7 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide rateLimits["primary"] = primaryData } - secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], averageWeekly) + secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], status.weeklyUtilization, resetToEpoch(status.weeklyReset)) if err != nil { return nil, err } @@ -531,7 +540,7 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide return json.Marshal(event) } -func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64) (json.RawMessage, error) { +func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64, resetAt int64) (json.RawMessage, error) { if len(data) == 0 || string(data) == "null" { return nil, nil } @@ -547,9 +556,93 @@ func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64) return nil, err } + if resetAt > 0 { + window["reset_at"], err = json.Marshal(resetAt) + if err != nil { + return nil, err + } + } + return json.Marshal(window) } +func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, provider credentialProvider, userConfig *option.OCMUser, rateLimitIdentifier string) { + subscription, done, err := s.statusObserver.Subscribe() + if err != nil { + return + } + defer s.statusObserver.UnSubscribe(subscription) + + last := s.computeAggregatedUtilization(provider, userConfig) + data := buildSyntheticRateLimitsEvent(rateLimitIdentifier, last) + clientWriteAccess.Lock() + err = wsutil.WriteServerMessage(clientConn, ws.OpText, data) + clientWriteAccess.Unlock() + if err != nil { + return + } + + for { + select { + case <-ctx.Done(): + return + case <-done: + return + case <-subscription: + for { + select { + case <-subscription: + default: + goto drained + } + } + drained: + current := s.computeAggregatedUtilization(provider, userConfig) + if current.equal(last) { + continue + } + last = current + data = buildSyntheticRateLimitsEvent(rateLimitIdentifier, current) + clientWriteAccess.Lock() + err = wsutil.WriteServerMessage(clientConn, ws.OpText, data) + clientWriteAccess.Unlock() + if err != nil { + return + } + } + } +} + +func buildSyntheticRateLimitsEvent(identifier string, status aggregatedStatus) []byte { + type rateLimitWindow struct { + UsedPercent float64 `json:"used_percent"` + ResetAt int64 `json:"reset_at,omitempty"` + } + event := struct { + Type string `json:"type"` + RateLimits struct { + Primary *rateLimitWindow `json:"primary,omitempty"` + Secondary *rateLimitWindow `json:"secondary,omitempty"` + } `json:"rate_limits"` + LimitName string `json:"limit_name"` + PlanWeight float64 `json:"plan_weight,omitempty"` + }{ + Type: "codex.rate_limits", + LimitName: identifier, + PlanWeight: status.totalWeight, + } + event.RateLimits.Primary = &rateLimitWindow{ + UsedPercent: status.fiveHourUtilization, + ResetAt: resetToEpoch(status.fiveHourReset), + } + event.RateLimits.Secondary = &rateLimitWindow{ + UsedPercent: status.weeklyUtilization, + ResetAt: resetToEpoch(status.weeklyReset), + } + data, _ := json.Marshal(event) + return data +} + func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) { var streamEvent responses.ResponseStreamEventUnion if json.Unmarshal(data, &streamEvent) != nil { From f57eff33bbda859e9202ccaf1561e0654f79cd25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 20:00:54 +0800 Subject: [PATCH 61/96] ccm,ocm: fix WS push lifecycle, deduplicate rate_limits, stabilize reset aggregation - Add closed channel to webSocketSession for push goroutine shutdown on connection close, preventing session leak and Service.Close() hang - Intercept upstream codex.rate_limits events instead of forwarding; push goroutine is now the sole sender of aggregated rate_limits - Emit status updates on reset-only changes (fiveHourResetChanged, weeklyResetChanged) so push goroutine picks up reset advances - Skip expired resets (hours <= 0) in aggregation instead of clamping to now, avoiding unstable reset_at output and spurious status ticks - Delete stale upstream reset headers when aggregated reset is zero - Hardcode "codex" identifier everywhere: handleWebSocketRateLimitsEvent, buildSyntheticRateLimitsEvent, rewriteResponseHeaders - Remove rewriteWebSocketRateLimits, rewriteWebSocketRateLimitWindow, identifier tracking (TypedValue), and unused imports --- service/ccm/credential_default.go | 2 +- service/ccm/credential_external.go | 5 +- service/ccm/service_status.go | 18 ++-- service/ocm/credential_default.go | 2 +- service/ocm/credential_external.go | 5 +- service/ocm/service_status.go | 18 ++-- service/ocm/service_websocket.go | 127 +++++------------------------ 7 files changed, 50 insertions(+), 127 deletions(-) diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index bf88fc836c..c004d1e844 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -307,7 +307,7 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { } c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } - shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly) + shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly || fiveHourResetChanged || weeklyResetChanged) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index cdbbc48449..1731dd5d09 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -439,6 +439,8 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization oldPlanWeight := c.state.remotePlanWeight + oldFiveHourReset := c.state.fiveHourReset + oldWeeklyReset := c.state.weeklyReset hadData := false if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { @@ -483,7 +485,8 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } utilizationChanged := c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly planWeightChanged := c.state.remotePlanWeight != oldPlanWeight - shouldEmit := (hadData && utilizationChanged) || planWeightChanged + resetChanged := c.state.fiveHourReset != oldFiveHourReset || c.state.weeklyReset != oldWeeklyReset + shouldEmit := (hadData && (utilizationChanged || resetChanged)) || planWeightChanged shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index e3aa43e307..50ba7ffe45 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -201,20 +201,18 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user fiveHourReset := credential.fiveHourResetTime() if !fiveHourReset.IsZero() { hours := fiveHourReset.Sub(now).Hours() - if hours < 0 { - hours = 0 + if hours > 0 { + totalWeightedHoursUntil5hReset += hours * weight + total5hResetWeight += weight } - totalWeightedHoursUntil5hReset += hours * weight - total5hResetWeight += weight } weeklyReset := credential.weeklyResetTime() if !weeklyReset.IsZero() { hours := weeklyReset.Sub(now).Hours() - if hours < 0 { - hours = 0 + if hours > 0 { + totalWeightedHoursUntilWeeklyReset += hours * weight + totalWeeklyResetWeight += weight } - totalWeightedHoursUntilWeeklyReset += hours * weight - totalWeeklyResetWeight += weight } } if totalWeight == 0 { @@ -245,9 +243,13 @@ func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentia headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(status.weeklyUtilization/100, 'f', 6, 64)) if !status.fiveHourReset.IsZero() { headers.Set("anthropic-ratelimit-unified-5h-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) + } else { + headers.Del("anthropic-ratelimit-unified-5h-reset") } if !status.weeklyReset.IsZero() { headers.Set("anthropic-ratelimit-unified-7d-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10)) + } else { + headers.Del("anthropic-ratelimit-unified-7d-reset") } if status.totalWeight > 0 { headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index 1e0af9847a..dd610e88de 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -347,7 +347,7 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { } c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } - shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly) + shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly || fiveHourResetChanged || weeklyResetChanged) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index b796ff0bb4..531738bcc2 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -463,6 +463,8 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization oldPlanWeight := c.state.remotePlanWeight + oldFiveHourReset := c.state.fiveHourReset + oldWeeklyReset := c.state.weeklyReset hadData := false activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) @@ -522,7 +524,8 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } utilizationChanged := c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly planWeightChanged := c.state.remotePlanWeight != oldPlanWeight - shouldEmit := (hadData && utilizationChanged) || planWeightChanged + resetChanged := c.state.fiveHourReset != oldFiveHourReset || c.state.weeklyReset != oldWeeklyReset + shouldEmit := (hadData && (utilizationChanged || resetChanged)) || planWeightChanged shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index f66b469c2a..cb23f42d73 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -201,20 +201,18 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user fiveHourReset := credential.fiveHourResetTime() if !fiveHourReset.IsZero() { hours := fiveHourReset.Sub(now).Hours() - if hours < 0 { - hours = 0 + if hours > 0 { + totalWeightedHoursUntil5hReset += hours * weight + total5hResetWeight += weight } - totalWeightedHoursUntil5hReset += hours * weight - total5hResetWeight += weight } weeklyReset := credential.weeklyResetTime() if !weeklyReset.IsZero() { hours := weeklyReset.Sub(now).Hours() - if hours < 0 { - hours = 0 + if hours > 0 { + totalWeightedHoursUntilWeeklyReset += hours * weight + totalWeeklyResetWeight += weight } - totalWeightedHoursUntilWeeklyReset += hours * weight - totalWeeklyResetWeight += weight } } if totalWeight == 0 { @@ -249,9 +247,13 @@ func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentia headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64)) if !status.fiveHourReset.IsZero() { headers.Set("x-"+activeLimitIdentifier+"-primary-reset-at", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) + } else { + headers.Del("x-" + activeLimitIdentifier + "-primary-reset-at") } if !status.weeklyReset.IsZero() { headers.Set("x-"+activeLimitIdentifier+"-secondary-reset-at", strconv.FormatInt(status.weeklyReset.Unix(), 10)) + } else { + headers.Del("x-" + activeLimitIdentifier + "-secondary-reset-at") } if status.totalWeight > 0 { headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index bb9640d545..520b6c58b5 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -31,10 +31,12 @@ type webSocketSession struct { credentialTag string releaseProviderInterrupt func() closeOnce sync.Once + closed chan struct{} } func (s *webSocketSession) Close() { s.closeOnce.Do(func() { + close(s.closed) if s.releaseProviderInterrupt != nil { s.releaseProviderInterrupt() } @@ -273,6 +275,7 @@ func (s *Service) handleWebSocket( upstreamConn: upstreamConn, credentialTag: selectedCredential.tagName(), releaseProviderInterrupt: requestContext.releaseCredentialInterrupt, + closed: make(chan struct{}), } if !s.registerWebSocketSession(session) { session.Close() @@ -290,11 +293,6 @@ func (s *Service) handleWebSocket( upstreamReadWriter = upstreamConn } - rateLimitIdentifier := normalizeRateLimitIdentifier(upstreamResponseHeaders.Get("x-codex-active-limit")) - if rateLimitIdentifier == "" { - rateLimitIdentifier = "codex" - } - var clientWriteAccess sync.Mutex modelChannel := make(chan string, 1) var waitGroup sync.WaitGroup @@ -308,12 +306,12 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) + s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, modelChannel, username, weeklyCycleHint) }() go func() { defer waitGroup.Done() defer session.Close() - s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, provider, userConfig, rateLimitIdentifier) + s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, session.closed, provider, userConfig) }() waitGroup.Wait() } @@ -372,7 +370,7 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn } } -func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { usageTracker := selectedCredential.usageTrackerOrNil() var requestModel string for { @@ -393,10 +391,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe switch event.Type { case "codex.rate_limits": s.handleWebSocketRateLimitsEvent(data, selectedCredential) - rewritten, rewriteErr := s.rewriteWebSocketRateLimits(data, provider, userConfig) - if rewriteErr == nil { - data = rewritten - } + continue case "error": if event.StatusCode == http.StatusTooManyRequests { s.handleWebSocketErrorRateLimited(data, selectedCredential) @@ -438,35 +433,25 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential ResetAt int64 `json:"reset_at"` } `json:"secondary"` } `json:"rate_limits"` - LimitName string `json:"limit_name"` - MeteredLimitName string `json:"metered_limit_name"` - PlanWeight float64 `json:"plan_weight"` + PlanWeight float64 `json:"plan_weight"` } err := json.Unmarshal(data, &rateLimitsEvent) if err != nil { return } - identifier := rateLimitsEvent.MeteredLimitName - if identifier == "" { - identifier = rateLimitsEvent.LimitName - } - if identifier == "" { - identifier = "codex" - } - identifier = normalizeRateLimitIdentifier(identifier) headers := make(http.Header) - headers.Set("x-codex-active-limit", identifier) + headers.Set("x-codex-active-limit", "codex") if w := rateLimitsEvent.RateLimits.Primary; w != nil { - headers.Set("x-"+identifier+"-primary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64)) + headers.Set("x-codex-primary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64)) if w.ResetAt > 0 { - headers.Set("x-"+identifier+"-primary-reset-at", strconv.FormatInt(w.ResetAt, 10)) + headers.Set("x-codex-primary-reset-at", strconv.FormatInt(w.ResetAt, 10)) } } if w := rateLimitsEvent.RateLimits.Secondary; w != nil { - headers.Set("x-"+identifier+"-secondary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64)) + headers.Set("x-codex-secondary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64)) if w.ResetAt > 0 { - headers.Set("x-"+identifier+"-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10)) + headers.Set("x-codex-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10)) } } if rateLimitsEvent.PlanWeight > 0 { @@ -492,81 +477,7 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia selectedCredential.markRateLimited(resetAt) } -func (s *Service) rewriteWebSocketRateLimits(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) { - var event map[string]json.RawMessage - err := json.Unmarshal(data, &event) - if err != nil { - return nil, err - } - - rateLimitsData, exists := event["rate_limits"] - if !exists || len(rateLimitsData) == 0 || string(rateLimitsData) == "null" { - return data, nil - } - - var rateLimits map[string]json.RawMessage - err = json.Unmarshal(rateLimitsData, &rateLimits) - if err != nil { - return nil, err - } - - status := s.computeAggregatedUtilization(provider, userConfig) - - if status.totalWeight > 0 { - event["plan_weight"], _ = json.Marshal(status.totalWeight) - } - - primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], status.fiveHourUtilization, resetToEpoch(status.fiveHourReset)) - if err != nil { - return nil, err - } - if primaryData != nil { - rateLimits["primary"] = primaryData - } - - secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], status.weeklyUtilization, resetToEpoch(status.weeklyReset)) - if err != nil { - return nil, err - } - if secondaryData != nil { - rateLimits["secondary"] = secondaryData - } - - event["rate_limits"], err = json.Marshal(rateLimits) - if err != nil { - return nil, err - } - - return json.Marshal(event) -} - -func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64, resetAt int64) (json.RawMessage, error) { - if len(data) == 0 || string(data) == "null" { - return nil, nil - } - - var window map[string]json.RawMessage - err := json.Unmarshal(data, &window) - if err != nil { - return nil, err - } - - window["used_percent"], err = json.Marshal(usedPercent) - if err != nil { - return nil, err - } - - if resetAt > 0 { - window["reset_at"], err = json.Marshal(resetAt) - if err != nil { - return nil, err - } - } - - return json.Marshal(window) -} - -func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, provider credentialProvider, userConfig *option.OCMUser, rateLimitIdentifier string) { +func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, sessionClosed <-chan struct{}, provider credentialProvider, userConfig *option.OCMUser) { subscription, done, err := s.statusObserver.Subscribe() if err != nil { return @@ -574,7 +485,7 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn defer s.statusObserver.UnSubscribe(subscription) last := s.computeAggregatedUtilization(provider, userConfig) - data := buildSyntheticRateLimitsEvent(rateLimitIdentifier, last) + data := buildSyntheticRateLimitsEvent(last) clientWriteAccess.Lock() err = wsutil.WriteServerMessage(clientConn, ws.OpText, data) clientWriteAccess.Unlock() @@ -588,6 +499,8 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn return case <-done: return + case <-sessionClosed: + return case <-subscription: for { select { @@ -602,7 +515,7 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn continue } last = current - data = buildSyntheticRateLimitsEvent(rateLimitIdentifier, current) + data = buildSyntheticRateLimitsEvent(current) clientWriteAccess.Lock() err = wsutil.WriteServerMessage(clientConn, ws.OpText, data) clientWriteAccess.Unlock() @@ -613,7 +526,7 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn } } -func buildSyntheticRateLimitsEvent(identifier string, status aggregatedStatus) []byte { +func buildSyntheticRateLimitsEvent(status aggregatedStatus) []byte { type rateLimitWindow struct { UsedPercent float64 `json:"used_percent"` ResetAt int64 `json:"reset_at,omitempty"` @@ -628,7 +541,7 @@ func buildSyntheticRateLimitsEvent(identifier string, status aggregatedStatus) [ PlanWeight float64 `json:"plan_weight,omitempty"` }{ Type: "codex.rate_limits", - LimitName: identifier, + LimitName: "codex", PlanWeight: status.totalWeight, } event.RateLimits.Primary = &rateLimitWindow{ From 969defeef07a8191a798a4f5509107367f2bbd3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 20:17:56 +0800 Subject: [PATCH 62/96] ccm,ocm: validate external status response fields --- service/ccm/credential_external.go | 44 ++++++++++++++++++++++++++++-- service/ocm/credential_external.go | 44 ++++++++++++++++++++++++++++-- 2 files changed, 82 insertions(+), 6 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 1731dd5d09..2445a8509e 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -595,6 +595,26 @@ func (c *externalCredential) pollUsage(ctx context.Context) { return } + body, err := io.ReadAll(response.Body) + if err != nil { + c.logger.Debug("poll usage for ", c.tag, ": read body: ", err) + c.clearPollFailures() + return + } + var rawFields map[string]json.RawMessage + err = json.Unmarshal(body, &rawFields) + if err != nil { + c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + c.clearPollFailures() + return + } + if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || + rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil || + rawFields["plan_weight"] == nil { + c.logger.Error("poll usage for ", c.tag, ": invalid response") + c.clearPollFailures() + return + } var statusResponse struct { FiveHourUtilization float64 `json:"five_hour_utilization"` FiveHourReset int64 `json:"five_hour_reset"` @@ -602,7 +622,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` } - err = json.NewDecoder(response.Body).Decode(&statusResponse) + err = json.Unmarshal(body, &statusResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) c.clearPollFailures() @@ -688,12 +708,30 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr decoder := json.NewDecoder(response.Body) for { - var statusResponse statusPayload - err = decoder.Decode(&statusResponse) + var rawMessage json.RawMessage + err = decoder.Decode(&rawMessage) if err != nil { result.duration = time.Since(startTime) return result, err } + var rawFields map[string]json.RawMessage + err = json.Unmarshal(rawMessage, &rawFields) + if err != nil { + result.duration = time.Since(startTime) + return result, E.Cause(err, "decode status frame") + } + if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || + rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil || + rawFields["plan_weight"] == nil { + result.duration = time.Since(startTime) + return result, E.New("invalid response") + } + var statusResponse statusPayload + err = json.Unmarshal(rawMessage, &statusResponse) + if err != nil { + result.duration = time.Since(startTime) + return result, E.Cause(err, "decode status frame") + } c.stateAccess.Lock() c.state.consecutivePollFailures = 0 diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 531738bcc2..27342aa72a 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -637,6 +637,26 @@ func (c *externalCredential) pollUsage(ctx context.Context) { return } + body, err := io.ReadAll(response.Body) + if err != nil { + c.logger.Debug("poll usage for ", c.tag, ": read body: ", err) + c.clearPollFailures() + return + } + var rawFields map[string]json.RawMessage + err = json.Unmarshal(body, &rawFields) + if err != nil { + c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + c.clearPollFailures() + return + } + if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || + rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil || + rawFields["plan_weight"] == nil { + c.logger.Error("poll usage for ", c.tag, ": invalid response") + c.clearPollFailures() + return + } var statusResponse struct { FiveHourUtilization float64 `json:"five_hour_utilization"` FiveHourReset int64 `json:"five_hour_reset"` @@ -644,7 +664,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` } - err = json.NewDecoder(response.Body).Decode(&statusResponse) + err = json.Unmarshal(body, &statusResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) c.clearPollFailures() @@ -730,12 +750,30 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr decoder := json.NewDecoder(response.Body) for { - var statusResponse statusPayload - err = decoder.Decode(&statusResponse) + var rawMessage json.RawMessage + err = decoder.Decode(&rawMessage) if err != nil { result.duration = time.Since(startTime) return result, err } + var rawFields map[string]json.RawMessage + err = json.Unmarshal(rawMessage, &rawFields) + if err != nil { + result.duration = time.Since(startTime) + return result, E.Cause(err, "decode status frame") + } + if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || + rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil || + rawFields["plan_weight"] == nil { + result.duration = time.Since(startTime) + return result, E.New("invalid response") + } + var statusResponse statusPayload + err = json.Unmarshal(rawMessage, &statusResponse) + if err != nil { + result.duration = time.Since(startTime) + return result, E.Cause(err, "decode status frame") + } c.stateAccess.Lock() c.state.consecutivePollFailures = 0 From 99e19e70330585ebfcfaf3e72f994cc9f39fc55f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 20:47:42 +0800 Subject: [PATCH 63/96] service: stop retrying fatal watch status errors --- service/ccm/credential_external.go | 9 + service/ccm/credential_status_test.go | 265 -------------------------- service/ocm/credential_external.go | 9 + service/ocm/credential_status_test.go | 246 ------------------------ 4 files changed, 18 insertions(+), 511 deletions(-) delete mode 100644 service/ccm/credential_status_test.go delete mode 100644 service/ocm/credential_status_test.go diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 2445a8509e..ba42ad64eb 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -5,6 +5,7 @@ import ( "context" stdTLS "crypto/tls" "encoding/json" + "errors" "io" "net" "net/http" @@ -677,6 +678,10 @@ func (c *externalCredential) statusStreamLoop() { if ctx.Err() != nil { return } + if !shouldRetryStatusStreamError(err) { + c.logger.Warn("status stream for ", c.tag, " disconnected: ", err, ", not retrying") + return + } var backoff time.Duration consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures) c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) @@ -760,6 +765,10 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr } } +func shouldRetryStatusStreamError(err error) bool { + return errors.Is(err, io.ErrUnexpectedEOF) || E.IsClosedOrCanceled(err) +} + func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) { if result.duration >= connectorBackoffResetThreshold { consecutiveFailures = 0 diff --git a/service/ccm/credential_status_test.go b/service/ccm/credential_status_test.go deleted file mode 100644 index 9353f1d836..0000000000 --- a/service/ccm/credential_status_test.go +++ /dev/null @@ -1,265 +0,0 @@ -package ccm - -import ( - "context" - "errors" - "io" - "net" - "net/http" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing/common/observable" - - "github.com/hashicorp/yamux" -) - -type roundTripperFunc func(*http.Request) (*http.Response, error) - -func (f roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) { - return f(request) -} - -func drainStatusEvents(subscription observable.Subscription[struct{}]) int { - var count int - for { - select { - case <-subscription: - count++ - default: - return count - } - } -} - -func newTestLogger() log.ContextLogger { - return log.NewNOPFactory().Logger() -} - -func newTestCCMExternalCredential(t *testing.T, body string, headers http.Header) (*externalCredential, observable.Subscription[struct{}]) { - t.Helper() - subscriber := observable.NewSubscriber[struct{}](8) - subscription, _ := subscriber.Subscription() - credential := &externalCredential{ - tag: "test", - baseURL: "http://example.com", - token: "token", - pollInterval: 25 * time.Millisecond, - forwardHTTPClient: &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) { - if request.URL.String() != "http://example.com/ccm/v1/status?watch=true" { - t.Fatalf("unexpected request URL: %s", request.URL.String()) - } - return &http.Response{ - StatusCode: http.StatusOK, - Header: headers.Clone(), - Body: io.NopCloser(strings.NewReader(body)), - }, nil - })}, - logger: newTestLogger(), - statusSubscriber: subscriber, - } - return credential, subscription -} - -func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) { - t.Helper() - clientConn, serverConn := net.Pipe() - clientSession, err := yamux.Client(clientConn, defaultYamuxConfig) - if err != nil { - t.Fatalf("create yamux client: %v", err) - } - serverSession, err := yamux.Server(serverConn, defaultYamuxConfig) - if err != nil { - clientSession.Close() - t.Fatalf("create yamux server: %v", err) - } - t.Cleanup(func() { - clientSession.Close() - serverSession.Close() - }) - return clientSession, serverSession -} - -func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) { - credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil) - oldTime := time.Unix(123, 0) - credential.stateAccess.Lock() - credential.state.lastUpdated = oldTime - credential.stateAccess.Unlock() - - result, err := credential.connectStatusStream(context.Background()) - if err != io.EOF { - t.Fatalf("expected EOF, got %v", err) - } - if result.frames != 1 { - t.Fatalf("expected 1 frame, got %d", result.frames) - } - if credential.lastUpdatedTime().Equal(oldTime) { - t.Fatal("expected lastUpdated to remain refreshed") - } - if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 { - t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event, got %d", count) - } - - failures, backoff := credential.nextStatusStreamBackoff(result, 3) - if failures != 4 { - t.Fatalf("expected failures incremented to 4, got %d", failures) - } - if backoff < 16*time.Second || backoff >= 24*time.Second { - t.Fatalf("expected connector backoff in [16s, 24s), got %v", backoff) - } -} - -func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *testing.T) { - credential, subscription := newTestCCMExternalCredential(t, strings.Join([]string{ - "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}", - "{\"five_hour_utilization\":13,\"weekly_utilization\":35,\"plan_weight\":3}", - }, "\n"), nil) - oldTime := time.Unix(123, 0) - credential.stateAccess.Lock() - credential.state.lastUpdated = oldTime - credential.stateAccess.Unlock() - - result, err := credential.connectStatusStream(context.Background()) - if err != io.EOF { - t.Fatalf("expected EOF, got %v", err) - } - if result.frames != 2 { - t.Fatalf("expected 2 frames, got %d", result.frames) - } - if credential.lastUpdatedTime().Equal(oldTime) { - t.Fatal("expected lastUpdated to remain refreshed") - } - if credential.fiveHourUtilization() != 13 || credential.weeklyUtilization() != 35 { - t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) - } - if count := drainStatusEvents(subscription); count != 2 { - t.Fatalf("expected 2 status events, got %d", count) - } -} - -func TestExternalCredentialPlanWeightOnlyHeaderEmitsStatus(t *testing.T) { - subscriber := observable.NewSubscriber[struct{}](8) - subscription, _ := subscriber.Subscription() - credential := &externalCredential{ - tag: "test", - logger: newTestLogger(), - statusSubscriber: subscriber, - } - credential.stateAccess.Lock() - credential.state.remotePlanWeight = 2 - oldTime := time.Unix(123, 0) - credential.state.lastUpdated = oldTime - credential.stateAccess.Unlock() - - headers := make(http.Header) - headers.Set("X-CCM-Plan-Weight", "3") - credential.updateStateFromHeaders(headers) - - if weight := credential.planWeight(); weight != 3 { - t.Fatalf("expected plan weight 3, got %v", weight) - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event, got %d", count) - } - if !credential.lastUpdatedTime().Equal(oldTime) { - t.Fatalf("expected lastUpdated to stay %v, got %v", oldTime, credential.lastUpdatedTime()) - } - - credential.updateStateFromHeaders(headers) - - if count := drainStatusEvents(subscription); count != 0 { - t.Fatalf("expected no status event for unchanged plan weight, got %d", count) - } -} - -func TestDefaultCredentialStatusChangesEmitStatus(t *testing.T) { - credentialPath := filepath.Join(t.TempDir(), "credentials.json") - err := os.WriteFile(credentialPath, []byte("{\"claudeAiOauth\":{\"accessToken\":\"token\",\"refreshToken\":\"\",\"expiresAt\":0,\"subscriptionType\":\"max\"}}\n"), 0o600) - if err != nil { - t.Fatalf("write credential file: %v", err) - } - - subscriber := observable.NewSubscriber[struct{}](8) - subscription, _ := subscriber.Subscription() - credential := &defaultCredential{ - tag: "test", - credentialPath: credentialPath, - logger: newTestLogger(), - statusSubscriber: subscriber, - } - - err = credential.markCredentialsUnavailable(errors.New("boom")) - if err == nil { - t.Fatal("expected error from markCredentialsUnavailable") - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event after unavailable transition, got %d", count) - } - - err = credential.reloadCredentials(true) - if err != nil { - t.Fatalf("reload credentials: %v", err) - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event after recovery, got %d", count) - } - if weight := credential.planWeight(); weight != 5 { - t.Fatalf("expected initial max weight 5, got %v", weight) - } - - profileClient := &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Header: make(http.Header), - Body: io.NopCloser(strings.NewReader( - "{\"organization\":{\"organization_type\":\"claude_max\",\"rate_limit_tier\":\"default_claude_max_20x\"}}", - )), - }, nil - })} - credential.fetchProfile(context.Background(), profileClient, "token") - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event after weight change, got %d", count) - } - if weight := credential.planWeight(); weight != 10 { - t.Fatalf("expected upgraded max weight 10, got %v", weight) - } -} - -func TestExternalCredentialReverseSessionChangesEmitStatus(t *testing.T) { - subscriber := observable.NewSubscriber[struct{}](8) - subscription, _ := subscriber.Subscription() - credential := &externalCredential{ - tag: "receiver", - baseURL: reverseProxyBaseURL, - pollInterval: time.Minute, - logger: newTestLogger(), - statusSubscriber: subscriber, - } - - clientSession, _ := newTestYamuxSessionPair(t) - if !credential.setReverseSession(clientSession) { - t.Fatal("expected reverse session to be accepted") - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event after reverse session up, got %d", count) - } - if !credential.isAvailable() { - t.Fatal("expected receiver credential to become available") - } - - credential.clearReverseSession(clientSession) - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event after reverse session down, got %d", count) - } - if credential.isAvailable() { - t.Fatal("expected receiver credential to become unavailable") - } -} diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 27342aa72a..b06171d03b 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -5,6 +5,7 @@ import ( "context" stdTLS "crypto/tls" "encoding/json" + "errors" "io" "net" "net/http" @@ -719,6 +720,10 @@ func (c *externalCredential) statusStreamLoop() { if ctx.Err() != nil { return } + if !shouldRetryStatusStreamError(err) { + c.logger.Warn("status stream for ", c.tag, " disconnected: ", err, ", not retrying") + return + } var backoff time.Duration consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures) c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) @@ -802,6 +807,10 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr } } +func shouldRetryStatusStreamError(err error) bool { + return errors.Is(err, io.ErrUnexpectedEOF) || E.IsClosedOrCanceled(err) +} + func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) { if result.duration >= connectorBackoffResetThreshold { consecutiveFailures = 0 diff --git a/service/ocm/credential_status_test.go b/service/ocm/credential_status_test.go deleted file mode 100644 index 2865a23808..0000000000 --- a/service/ocm/credential_status_test.go +++ /dev/null @@ -1,246 +0,0 @@ -package ocm - -import ( - "context" - "errors" - "io" - "net" - "net/http" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing/common/observable" - - "github.com/hashicorp/yamux" -) - -type roundTripperFunc func(*http.Request) (*http.Response, error) - -func (f roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) { - return f(request) -} - -func drainStatusEvents(subscription observable.Subscription[struct{}]) int { - var count int - for { - select { - case <-subscription: - count++ - default: - return count - } - } -} - -func newTestLogger() log.ContextLogger { - return log.NewNOPFactory().Logger() -} - -func newTestOCMExternalCredential(t *testing.T, body string, headers http.Header) (*externalCredential, observable.Subscription[struct{}]) { - t.Helper() - subscriber := observable.NewSubscriber[struct{}](8) - subscription, _ := subscriber.Subscription() - credential := &externalCredential{ - tag: "test", - baseURL: "http://example.com", - token: "token", - pollInterval: 25 * time.Millisecond, - forwardHTTPClient: &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) { - if request.URL.String() != "http://example.com/ocm/v1/status?watch=true" { - t.Fatalf("unexpected request URL: %s", request.URL.String()) - } - return &http.Response{ - StatusCode: http.StatusOK, - Header: headers.Clone(), - Body: io.NopCloser(strings.NewReader(body)), - }, nil - })}, - logger: newTestLogger(), - statusSubscriber: subscriber, - } - return credential, subscription -} - -func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) { - t.Helper() - clientConn, serverConn := net.Pipe() - clientSession, err := yamux.Client(clientConn, defaultYamuxConfig) - if err != nil { - t.Fatalf("create yamux client: %v", err) - } - serverSession, err := yamux.Server(serverConn, defaultYamuxConfig) - if err != nil { - clientSession.Close() - t.Fatalf("create yamux server: %v", err) - } - t.Cleanup(func() { - clientSession.Close() - serverSession.Close() - }) - return clientSession, serverSession -} - -func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) { - credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil) - oldTime := time.Unix(123, 0) - credential.stateAccess.Lock() - credential.state.lastUpdated = oldTime - credential.stateAccess.Unlock() - - result, err := credential.connectStatusStream(context.Background()) - if err != io.EOF { - t.Fatalf("expected EOF, got %v", err) - } - if result.frames != 1 { - t.Fatalf("expected 1 frame, got %d", result.frames) - } - if credential.lastUpdatedTime().Equal(oldTime) { - t.Fatal("expected lastUpdated to remain refreshed") - } - if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 { - t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event, got %d", count) - } - - failures, backoff := credential.nextStatusStreamBackoff(result, 3) - if failures != 4 { - t.Fatalf("expected failures incremented to 4, got %d", failures) - } - if backoff < 16*time.Second || backoff >= 24*time.Second { - t.Fatalf("expected connector backoff in [16s, 24s), got %v", backoff) - } -} - -func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *testing.T) { - credential, subscription := newTestOCMExternalCredential(t, strings.Join([]string{ - "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}", - "{\"five_hour_utilization\":13,\"weekly_utilization\":35,\"plan_weight\":3}", - }, "\n"), nil) - oldTime := time.Unix(123, 0) - credential.stateAccess.Lock() - credential.state.lastUpdated = oldTime - credential.stateAccess.Unlock() - - result, err := credential.connectStatusStream(context.Background()) - if err != io.EOF { - t.Fatalf("expected EOF, got %v", err) - } - if result.frames != 2 { - t.Fatalf("expected 2 frames, got %d", result.frames) - } - if credential.lastUpdatedTime().Equal(oldTime) { - t.Fatal("expected lastUpdated to remain refreshed") - } - if credential.fiveHourUtilization() != 13 || credential.weeklyUtilization() != 35 { - t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) - } - if count := drainStatusEvents(subscription); count != 2 { - t.Fatalf("expected 2 status events, got %d", count) - } -} - -func TestExternalCredentialPlanWeightOnlyRateLimitsEventEmitsStatus(t *testing.T) { - subscriber := observable.NewSubscriber[struct{}](8) - subscription, _ := subscriber.Subscription() - credential := &externalCredential{ - tag: "test", - logger: newTestLogger(), - statusSubscriber: subscriber, - } - credential.stateAccess.Lock() - credential.state.remotePlanWeight = 2 - oldTime := time.Unix(123, 0) - credential.state.lastUpdated = oldTime - credential.stateAccess.Unlock() - - (&Service{}).handleWebSocketRateLimitsEvent([]byte(`{"plan_weight":3}`), credential) - - if weight := credential.planWeight(); weight != 3 { - t.Fatalf("expected plan weight 3, got %v", weight) - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event, got %d", count) - } - if !credential.lastUpdatedTime().Equal(oldTime) { - t.Fatalf("expected lastUpdated to stay %v, got %v", oldTime, credential.lastUpdatedTime()) - } - - (&Service{}).handleWebSocketRateLimitsEvent([]byte(`{"plan_weight":3}`), credential) - - if count := drainStatusEvents(subscription); count != 0 { - t.Fatalf("expected no status event for unchanged plan weight, got %d", count) - } -} - -func TestDefaultCredentialAvailabilityChangesEmitStatus(t *testing.T) { - credentialPath := filepath.Join(t.TempDir(), "auth.json") - err := os.WriteFile(credentialPath, []byte("{\"OPENAI_API_KEY\":\"sk-test\"}\n"), 0o600) - if err != nil { - t.Fatalf("write credential file: %v", err) - } - - subscriber := observable.NewSubscriber[struct{}](8) - subscription, _ := subscriber.Subscription() - credential := &defaultCredential{ - tag: "test", - credentialPath: credentialPath, - logger: newTestLogger(), - statusSubscriber: subscriber, - } - - err = credential.markCredentialsUnavailable(errors.New("boom")) - if err == nil { - t.Fatal("expected error from markCredentialsUnavailable") - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event after unavailable transition, got %d", count) - } - - err = credential.reloadCredentials(true) - if err != nil { - t.Fatalf("reload credentials: %v", err) - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event after recovery, got %d", count) - } - if !credential.isAvailable() { - t.Fatal("expected credential to become available") - } -} - -func TestExternalCredentialReverseSessionChangesEmitStatus(t *testing.T) { - subscriber := observable.NewSubscriber[struct{}](8) - subscription, _ := subscriber.Subscription() - credential := &externalCredential{ - tag: "receiver", - baseURL: reverseProxyBaseURL, - pollInterval: time.Minute, - logger: newTestLogger(), - statusSubscriber: subscriber, - } - - clientSession, _ := newTestYamuxSessionPair(t) - if !credential.setReverseSession(clientSession) { - t.Fatal("expected reverse session to be accepted") - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event after reverse session up, got %d", count) - } - if !credential.isAvailable() { - t.Fatal("expected receiver credential to become available") - } - - credential.clearReverseSession(clientSession) - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event after reverse session down, got %d", count) - } - if credential.isAvailable() { - t.Fatal("expected receiver credential to become unavailable") - } -} From a2d6cf9715a25a5922ebe6a2addddeaa171167f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 21:14:14 +0800 Subject: [PATCH 64/96] fix(ocm): defer initial websocket rate-limit push --- service/ocm/service_websocket.go | 89 +++++++++++++++++++++++--------- 1 file changed, 65 insertions(+), 24 deletions(-) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 520b6c58b5..6568f4dea2 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -49,6 +49,34 @@ func (s *webSocketSession) Close() { }) } +type webSocketResponseCreateRequest struct { + Type string `json:"type"` + Model string `json:"model"` + ServiceTier string `json:"service_tier"` + Generate *bool `json:"generate"` +} + +func parseWebSocketResponseCreateRequest(data []byte) (webSocketResponseCreateRequest, bool) { + var request webSocketResponseCreateRequest + if json.Unmarshal(data, &request) != nil { + return webSocketResponseCreateRequest{}, false + } + if request.Type != "response.create" || request.Model == "" { + return webSocketResponseCreateRequest{}, false + } + return request, true +} + +func (r webSocketResponseCreateRequest) isWarmup() bool { + return r.Generate != nil && !*r.Generate +} + +func signalWebSocketReady(channel chan struct{}, once *sync.Once) { + once.Do(func() { + close(channel) + }) +} + func buildUpstreamWebSocketURL(baseURL string, proxyPath string) string { upstreamURL := baseURL if strings.HasPrefix(upstreamURL, "https://") { @@ -295,13 +323,15 @@ func (s *Service) handleWebSocket( var clientWriteAccess sync.Mutex modelChannel := make(chan string, 1) + firstRealRequest := make(chan struct{}) + var firstRealRequestOnce sync.Once var waitGroup sync.WaitGroup waitGroup.Add(3) go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID) + s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, firstRealRequest, &firstRealRequestOnce, isNew, username, sessionID) }() go func() { defer waitGroup.Done() @@ -311,12 +341,12 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, session.closed, provider, userConfig) + s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, session.closed, firstRealRequest, provider, userConfig) }() waitGroup.Wait() } -func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential Credential, modelChannel chan<- string, isNew bool, username string, sessionID string) { +func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential Credential, modelChannel chan<- string, firstRealRequest chan struct{}, firstRealRequestOnce *sync.Once, isNew bool, username string, sessionID string) { logged := false for { data, opCode, err := wsutil.ReadClientData(clientConn) @@ -327,15 +357,10 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn return } + shouldSignalFirstRealRequest := false if opCode == ws.OpText { - var request struct { - Type string `json:"type"` - Model string `json:"model"` - ServiceTier string `json:"service_tier"` - Generate *bool `json:"generate"` - } - if json.Unmarshal(data, &request) == nil && request.Type == "response.create" && request.Model != "" { - isWarmup := request.Generate != nil && !*request.Generate + if request, ok := parseWebSocketResponseCreateRequest(data); ok { + isWarmup := request.isWarmup() if !isWarmup && isNew && !logged { logged = true logParts := []any{"assigned credential ", selectedCredential.tagName()} @@ -357,6 +382,9 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn default: } } + if !isWarmup { + shouldSignalFirstRealRequest = true + } } } @@ -367,6 +395,9 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn } return } + if shouldSignalFirstRealRequest { + signalWebSocketReady(firstRealRequest, firstRealRequestOnce) + } } } @@ -477,21 +508,22 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia selectedCredential.markRateLimited(resetAt) } -func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, sessionClosed <-chan struct{}, provider credentialProvider, userConfig *option.OCMUser) { +func writeWebSocketAggregatedStatus(clientConn net.Conn, clientWriteAccess *sync.Mutex, status aggregatedStatus) error { + data := buildSyntheticRateLimitsEvent(status) + clientWriteAccess.Lock() + defer clientWriteAccess.Unlock() + return wsutil.WriteServerMessage(clientConn, ws.OpText, data) +} + +func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, sessionClosed <-chan struct{}, firstRealRequest <-chan struct{}, provider credentialProvider, userConfig *option.OCMUser) { subscription, done, err := s.statusObserver.Subscribe() if err != nil { return } defer s.statusObserver.UnSubscribe(subscription) - last := s.computeAggregatedUtilization(provider, userConfig) - data := buildSyntheticRateLimitsEvent(last) - clientWriteAccess.Lock() - err = wsutil.WriteServerMessage(clientConn, ws.OpText, data) - clientWriteAccess.Unlock() - if err != nil { - return - } + var last aggregatedStatus + hasLast := false for { select { @@ -501,6 +533,15 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn return case <-sessionClosed: return + case <-firstRealRequest: + current := s.computeAggregatedUtilization(provider, userConfig) + err = writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, current) + if err != nil { + return + } + last = current + hasLast = true + firstRealRequest = nil case <-subscription: for { select { @@ -510,15 +551,15 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn } } drained: + if !hasLast { + continue + } current := s.computeAggregatedUtilization(provider, userConfig) if current.equal(last) { continue } last = current - data = buildSyntheticRateLimitsEvent(current) - clientWriteAccess.Lock() - err = wsutil.WriteServerMessage(clientConn, ws.OpText, data) - clientWriteAccess.Unlock() + err = writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, current) if err != nil { return } From b3429ef1f3a1b660eaef53eac3bfaade5650bcc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 21:39:31 +0800 Subject: [PATCH 65/96] fix(ocm): strip non-active rate-limit headers from forwarded responses --- service/ocm/service_status.go | 38 ++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index cb23f42d73..ebbc9ceaa6 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -239,23 +239,41 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) { status := s.computeAggregatedUtilization(provider, userConfig) - activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) - if activeLimitIdentifier == "" { - activeLimitIdentifier = "codex" - } - headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(status.fiveHourUtilization, 'f', 2, 64)) - headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64)) + headers.Set("x-codex-primary-used-percent", strconv.FormatFloat(status.fiveHourUtilization, 'f', 2, 64)) + headers.Set("x-codex-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64)) if !status.fiveHourReset.IsZero() { - headers.Set("x-"+activeLimitIdentifier+"-primary-reset-at", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) + headers.Set("x-codex-primary-reset-at", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) } else { - headers.Del("x-" + activeLimitIdentifier + "-primary-reset-at") + headers.Del("x-codex-primary-reset-at") } if !status.weeklyReset.IsZero() { - headers.Set("x-"+activeLimitIdentifier+"-secondary-reset-at", strconv.FormatInt(status.weeklyReset.Unix(), 10)) + headers.Set("x-codex-secondary-reset-at", strconv.FormatInt(status.weeklyReset.Unix(), 10)) } else { - headers.Del("x-" + activeLimitIdentifier + "-secondary-reset-at") + headers.Del("x-codex-secondary-reset-at") } if status.totalWeight > 0 { headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) } + rateLimitSuffixes := [...]string{ + "-primary-used-percent", + "-primary-reset-at", + "-secondary-used-percent", + "-secondary-reset-at", + "-secondary-window-minutes", + "-limit-name", + } + for key := range headers { + lowerKey := strings.ToLower(key) + if !strings.HasPrefix(lowerKey, "x-") { + continue + } + for _, suffix := range rateLimitSuffixes { + if strings.HasSuffix(lowerKey, suffix) { + if strings.TrimSuffix(lowerKey, suffix) != "x-codex" { + headers.Del(key) + } + break + } + } + } } From 6b8838d323a17c7821b9043ec585ba9aed310f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 21:54:09 +0800 Subject: [PATCH 66/96] fix(ccm,ocm): restart status stream when receiver gets reverse session statusStreamLoop started on start() before any reverse session existed, got a non-retryable error, and exited permanently. Restart it when setReverseSession transitions receiver credentials to available. --- service/ccm/credential_external.go | 15 +++++++++++++++ service/ocm/credential_external.go | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index ba42ad64eb..98332e41d0 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -884,6 +884,8 @@ func (c *externalCredential) getReverseSession() *yamux.Session { func (c *externalCredential) setReverseSession(session *yamux.Session) bool { var emitStatus bool + var restartStatusStream bool + var triggerUsageRefresh bool c.reverseAccess.Lock() if c.closed { c.reverseAccess.Unlock() @@ -894,10 +896,23 @@ func (c *externalCredential) setReverseSession(session *yamux.Session) bool { c.reverseSession = session isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed() emitStatus = wasAvailable != isAvailable + if isAvailable && !wasAvailable { + c.reverseCancel() + c.reverseContext, c.reverseCancel = context.WithCancel(context.Background()) + restartStatusStream = true + triggerUsageRefresh = true + } c.reverseAccess.Unlock() if old != nil { old.Close() } + if restartStatusStream { + c.logger.Debug("poll usage for ", c.tag, ": reverse session ready, restarting status stream") + go c.statusStreamLoop() + } + if triggerUsageRefresh { + go c.pollUsage(c.getReverseContext()) + } if emitStatus { c.emitStatusUpdate() } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index b06171d03b..47b6c0d5e4 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -954,6 +954,8 @@ func (c *externalCredential) getReverseSession() *yamux.Session { func (c *externalCredential) setReverseSession(session *yamux.Session) bool { var emitStatus bool + var restartStatusStream bool + var triggerUsageRefresh bool c.reverseAccess.Lock() if c.closed { c.reverseAccess.Unlock() @@ -964,10 +966,23 @@ func (c *externalCredential) setReverseSession(session *yamux.Session) bool { c.reverseSession = session isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed() emitStatus = wasAvailable != isAvailable + if isAvailable && !wasAvailable { + c.reverseCancel() + c.reverseContext, c.reverseCancel = context.WithCancel(context.Background()) + restartStatusStream = true + triggerUsageRefresh = true + } c.reverseAccess.Unlock() if old != nil { old.Close() } + if restartStatusStream { + c.logger.Debug("poll usage for ", c.tag, ": reverse session ready, restarting status stream") + go c.statusStreamLoop() + } + if triggerUsageRefresh { + go c.pollUsage(c.getReverseContext()) + } if emitStatus { c.emitStatusUpdate() } From b119d08764fe793753d8dae3beeea651a6f87d25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 22:37:38 +0800 Subject: [PATCH 67/96] fix(ccm,ocm): add usage logging to status stream, remove redundant isFirstUpdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit connectStatusStream updated credential state silently — no log on first frame or value changes. After restart, external credentials get usage via stream before any request, so pollIfStale skips them and no usage log ever appears. Add the same change-detection log to connectStatusStream. Also remove redundant isFirstUpdate guards from pollUsage and updateStateFromHeaders: when old values are zero, any non-zero new value already satisfies the integer-percent comparison. --- service/ccm/credential_external.go | 12 ++++++++++-- service/ocm/credential_external.go | 15 +++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 98332e41d0..f512052038 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -631,7 +631,6 @@ func (c *externalCredential) pollUsage(ctx context.Context) { } c.stateAccess.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 @@ -649,7 +648,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + if int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) @@ -739,6 +738,8 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr } c.stateAccess.Lock() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization @@ -754,6 +755,13 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } + if int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) + } shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 47b6c0d5e4..3bcae1391f 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -460,7 +460,6 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.stateAccess.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization oldPlanWeight := c.state.remotePlanWeight @@ -516,7 +515,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.consecutivePollFailures = 0 c.state.lastUpdated = time.Now() } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + if int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) @@ -673,7 +672,6 @@ func (c *externalCredential) pollUsage(ctx context.Context) { } c.stateAccess.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 @@ -691,7 +689,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + if int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) @@ -781,6 +779,8 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr } c.stateAccess.Lock() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization @@ -796,6 +796,13 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } + if int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) + } shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { From 3bcfdd54558013347dd610b81d69f6a88c1ff7cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 18 Mar 2026 00:54:01 +0800 Subject: [PATCH 68/96] fix(ccm,ocm): remove external context from pollUsage/pollIfStale MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pollUsage(ctx) accepted caller context, and service_status.go passed r.Context() which gets canceled on client disconnect or service shutdown. This caused incrementPollFailures → interruptConnections on transient cancellations. Each implementation now uses its own persistent context: defaultCredential uses serviceContext, externalCredential uses getReverseContext(). --- service/ccm/credential.go | 2 +- service/ccm/credential_default.go | 9 ++++++--- service/ccm/credential_external.go | 5 +++-- service/ccm/credential_provider.go | 10 +++++----- service/ccm/service_handler.go | 4 ++-- service/ccm/service_status.go | 4 ++-- service/ocm/credential.go | 2 +- service/ocm/credential_default.go | 4 +++- service/ocm/credential_external.go | 5 +++-- service/ocm/credential_provider.go | 10 +++++----- service/ocm/service_handler.go | 4 ++-- service/ocm/service_status.go | 4 ++-- 12 files changed, 35 insertions(+), 28 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 6f41ba1283..89ad5bb97d 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -120,7 +120,7 @@ type Credential interface { setStatusSubscriber(*observable.Subscriber[struct{}]) start() error - pollUsage(ctx context.Context) + pollUsage() lastUpdatedTime() time.Time pollBackoff(base time.Duration) time.Duration usageTrackerOrNil() *AggregatedUsage diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index c004d1e844..1003c7b36f 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -139,6 +139,7 @@ func (c *defaultCredential) start() error { c.logger.Warn("load usage statistics for ", c.tag, ": ", err) } } + go c.pollUsage() return nil } @@ -516,7 +517,7 @@ func (c *defaultCredential) earliestReset() time.Time { return earliest } -func (c *defaultCredential) pollUsage(ctx context.Context) { +func (c *defaultCredential) pollUsage() { if !c.pollAccess.TryLock() { return } @@ -537,6 +538,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { return } + ctx := c.serviceContext httpClient := &http.Client{ Transport: c.forwardHTTPClient.Transport, Timeout: 5 * time.Second, @@ -633,11 +635,12 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { c.emitStatusUpdate() if needsProfileFetch { - c.fetchProfile(ctx, httpClient, accessToken) + c.fetchProfile(httpClient, accessToken) } } -func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.Client, accessToken string) { +func (c *defaultCredential) fetchProfile(httpClient *http.Client, accessToken string) { + ctx := c.serviceContext response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil) if err != nil { diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index f512052038..40d2c3677b 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -574,13 +574,14 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp return nil, E.New("no transport available") } -func (c *externalCredential) pollUsage(ctx context.Context) { +func (c *externalCredential) pollUsage() { if !c.pollAccess.TryLock() { return } defer c.pollAccess.Unlock() defer c.markUsagePollAttempted() + ctx := c.getReverseContext() response, err := c.doPollUsageRequest(ctx) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": ", err) @@ -919,7 +920,7 @@ func (c *externalCredential) setReverseSession(session *yamux.Session) bool { go c.statusStreamLoop() } if triggerUsageRefresh { - go c.pollUsage(c.getReverseContext()) + go c.pollUsage() } if emitStatus { c.emitStatusUpdate() diff --git a/service/ccm/credential_provider.go b/service/ccm/credential_provider.go index 8d993c6cae..9fac91b2c6 100644 --- a/service/ccm/credential_provider.go +++ b/service/ccm/credential_provider.go @@ -16,7 +16,7 @@ type credentialProvider interface { selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool - pollIfStale(ctx context.Context) + pollIfStale() allCredentials() []Credential close() } @@ -58,7 +58,7 @@ func (p *singleCredentialProvider) onRateLimited(_ string, credential Credential return nil } -func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { +func (p *singleCredentialProvider) pollIfStale() { now := time.Now() p.sessionAccess.Lock() for id, createdAt := range p.sessions { @@ -69,7 +69,7 @@ func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { p.sessionAccess.Unlock() if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) { - p.credential.pollUsage(ctx) + p.credential.pollUsage() } } @@ -357,7 +357,7 @@ func (p *balancerProvider) pickRandom(filter func(Credential) bool) Credential { return usable[rand.IntN(len(usable))] } -func (p *balancerProvider) pollIfStale(ctx context.Context) { +func (p *balancerProvider) pollIfStale() { now := time.Now() p.sessionAccess.Lock() for id, entry := range p.sessions { @@ -377,7 +377,7 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) { for _, credential := range p.credentials { if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { - credential.pollUsage(ctx) + credential.pollUsage() } } } diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 1ccbd83ff4..6bface7383 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -182,7 +182,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - provider.pollIfStale(s.ctx) + provider.pollIfStale() anthropicBetaHeader := r.Header.Get("anthropic-beta") if isFastModeRequest(anthropicBetaHeader) { @@ -305,7 +305,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) - go selectedCredential.pollUsage(s.ctx) + go selectedCredential.pollUsage() writeJSONError(w, r, http.StatusInternalServerError, "api_error", "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) return diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index 50ba7ffe45..424afa1279 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -103,7 +103,7 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { return } - provider.pollIfStale(r.Context()) + provider.pollIfStale() status := s.computeAggregatedUtilization(provider, userConfig) w.Header().Set("Content-Type", "application/json") @@ -125,7 +125,7 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } defer s.statusObserver.UnSubscribe(subscription) - provider.pollIfStale(r.Context()) + provider.pollIfStale() w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) diff --git a/service/ocm/credential.go b/service/ocm/credential.go index 1478f5f193..b0226f2d07 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -122,7 +122,7 @@ type Credential interface { setOnBecameUnusable(fn func()) setStatusSubscriber(*observable.Subscriber[struct{}]) start() error - pollUsage(ctx context.Context) + pollUsage() lastUpdatedTime() time.Time pollBackoff(base time.Duration) time.Duration usageTrackerOrNil() *AggregatedUsage diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index dd610e88de..3622ac8ffa 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -143,6 +143,7 @@ func (c *defaultCredential) start() error { c.logger.Warn("load usage statistics for ", c.tag, ": ", err) } } + go c.pollUsage() return nil } @@ -597,7 +598,7 @@ func (c *defaultCredential) ocmGetBaseURL() string { return c.getBaseURL() } -func (c *defaultCredential) pollUsage(ctx context.Context) { +func (c *defaultCredential) pollUsage() { if !c.pollAccess.TryLock() { return } @@ -621,6 +622,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { return } + ctx := c.serviceContext usageURL := strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage" accountID := c.getAccountID() diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 3bcae1391f..0ee1595d9b 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -615,13 +615,14 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp return nil, E.New("no transport available") } -func (c *externalCredential) pollUsage(ctx context.Context) { +func (c *externalCredential) pollUsage() { if !c.pollAccess.TryLock() { return } defer c.pollAccess.Unlock() defer c.markUsagePollAttempted() + ctx := c.getReverseContext() response, err := c.doPollUsageRequest(ctx) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": ", err) @@ -988,7 +989,7 @@ func (c *externalCredential) setReverseSession(session *yamux.Session) bool { go c.statusStreamLoop() } if triggerUsageRefresh { - go c.pollUsage(c.getReverseContext()) + go c.pollUsage() } if emitStatus { c.emitStatusUpdate() diff --git a/service/ocm/credential_provider.go b/service/ocm/credential_provider.go index 421258cd63..5d67eb0326 100644 --- a/service/ocm/credential_provider.go +++ b/service/ocm/credential_provider.go @@ -16,7 +16,7 @@ type credentialProvider interface { selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool - pollIfStale(ctx context.Context) + pollIfStale() allCredentials() []Credential close() } @@ -58,7 +58,7 @@ func (p *singleCredentialProvider) onRateLimited(_ string, credential Credential return nil } -func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { +func (p *singleCredentialProvider) pollIfStale() { now := time.Now() p.sessionAccess.Lock() for id, createdAt := range p.sessions { @@ -69,7 +69,7 @@ func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { p.sessionAccess.Unlock() if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) { - p.credential.pollUsage(ctx) + p.credential.pollUsage() } } @@ -384,7 +384,7 @@ func (p *balancerProvider) pickRandom(filter func(Credential) bool) Credential { return usable[rand.IntN(len(usable))] } -func (p *balancerProvider) pollIfStale(ctx context.Context) { +func (p *balancerProvider) pollIfStale() { now := time.Now() p.sessionAccess.Lock() for id, entry := range p.sessions { @@ -404,7 +404,7 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) { for _, credential := range p.credentials { if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { - credential.pollUsage(ctx) + credential.pollUsage() } } } diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index 52e35f39b9..8b50f748a1 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -131,7 +131,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - provider.pollIfStale(s.ctx) + provider.pollIfStale() selection := credentialSelectionForUser(userConfig) @@ -285,7 +285,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) - go selectedCredential.pollUsage(s.ctx) + go selectedCredential.pollUsage() writeJSONError(w, r, http.StatusInternalServerError, "api_error", "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) return diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index ebbc9ceaa6..3e7e2ff65f 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -103,7 +103,7 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { return } - provider.pollIfStale(r.Context()) + provider.pollIfStale() status := s.computeAggregatedUtilization(provider, userConfig) w.Header().Set("Content-Type", "application/json") @@ -125,7 +125,7 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } defer s.statusObserver.UnSubscribe(subscription) - provider.pollIfStale(r.Context()) + provider.pollIfStale() w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) From 2fe1e37b179d258165cee8d31b134076f397b69a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 18 Mar 2026 01:00:55 +0800 Subject: [PATCH 69/96] fix(ccm,ocm): add missing isFirstUpdate to external credential usage logging --- service/ccm/credential_external.go | 6 ++++-- service/ocm/credential_external.go | 9 ++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 40d2c3677b..186d6d9d69 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -632,6 +632,7 @@ func (c *externalCredential) pollUsage() { } c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 @@ -649,7 +650,7 @@ func (c *externalCredential) pollUsage() { if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } - if int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) @@ -739,6 +740,7 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr } c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 @@ -756,7 +758,7 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } - if int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 0ee1595d9b..39dab378ee 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -460,6 +460,7 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization oldPlanWeight := c.state.remotePlanWeight @@ -515,7 +516,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.consecutivePollFailures = 0 c.state.lastUpdated = time.Now() } - if int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) @@ -673,6 +674,7 @@ func (c *externalCredential) pollUsage() { } c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 @@ -690,7 +692,7 @@ func (c *externalCredential) pollUsage() { if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } - if int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) @@ -780,6 +782,7 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr } c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 @@ -797,7 +800,7 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } - if int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) From 7acba747555ad5112374b3adec1d3f23d5b9a5b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 18 Mar 2026 15:53:36 +0800 Subject: [PATCH 70/96] fix(ccm): forward 529 upstream overloaded response transparently --- service/ccm/service_handler.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 6bface7383..87f943cca9 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -302,6 +302,18 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { selectedCredential.updateStateFromHeaders(response.Header) + if response.StatusCode == 529 { + s.logger.WarnContext(ctx, "upstream overloaded from ", selectedCredential.tagName()) + for key, values := range response.Header { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { + w.Header()[key] = values + } + } + w.WriteHeader(response.StatusCode) + io.Copy(w, response.Body) + return + } + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) From 608b7e7fa2de3e40c9677cdc0eb40909db8996aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 21 Mar 2026 09:23:58 +0800 Subject: [PATCH 71/96] fix(ccm,ocm): stop cascading 429 retry storm on token refresh When the access token expires and refreshToken() gets 429, getAccessToken() returned the error but left credentials unchanged with no cooldown. Every subsequent request re-attempted the refresh, creating a burst that overwhelmed the token endpoint. - refreshToken() now returns Retry-After duration from 429 response headers (-1 when no header present, meaning permanently blocked) - getAccessToken() caches the 429 and blocks further refresh attempts until Retry-After expires (or permanently if no header) - reloadCredentials() clears the block when new credentials are loaded from file - Remove go pollUsage() on upstream errors (unrelated to usage state) --- service/ccm/credential_default.go | 24 +++++++++++++++++++++++- service/ccm/credential_file.go | 3 +++ service/ccm/credential_oauth.go | 24 ++++++++++++++++-------- service/ccm/service_handler.go | 1 - service/ocm/credential_default.go | 24 +++++++++++++++++++++++- service/ocm/credential_file.go | 3 +++ service/ocm/credential_oauth.go | 24 ++++++++++++++++-------- service/ocm/service_handler.go | 1 - 8 files changed, 84 insertions(+), 20 deletions(-) diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 1003c7b36f..24549a82a1 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -46,6 +46,11 @@ type defaultCredential struct { statusSubscriber *observable.Subscriber[struct{}] + // Refresh rate-limit cooldown (protected by access mutex) + refreshRetryAt time.Time + refreshRetryError error + refreshBlocked bool + // Connection interruption interrupted bool requestContext context.Context @@ -197,16 +202,33 @@ func (c *defaultCredential) getAccessToken() (string, error) { return c.credentials.AccessToken, nil } + if c.refreshBlocked { + return "", c.refreshRetryError + } + if !c.refreshRetryAt.IsZero() && time.Now().Before(c.refreshRetryAt) { + return "", c.refreshRetryError + } + err = platformCanWriteCredentials(c.credentialPath) if err != nil { return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") } baseCredentials := cloneCredentials(c.credentials) - newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials) + newCredentials, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials) if err != nil { + if retryDelay < 0 { + c.refreshBlocked = true + c.refreshRetryError = err + } else if retryDelay > 0 { + c.refreshRetryAt = time.Now().Add(retryDelay) + c.refreshRetryError = err + } return "", err } + c.refreshRetryAt = time.Time{} + c.refreshRetryError = nil + c.refreshBlocked = false latestCredentials, latestErr := platformReadCredentials(c.credentialPath) if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index 4a65314712..3f67eaf13e 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -108,6 +108,9 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.access.Lock() c.credentials = credentials + c.refreshRetryAt = time.Time{} + c.refreshRetryError = nil + c.refreshBlocked = false c.access.Unlock() c.stateAccess.Lock() diff --git a/service/ccm/credential_oauth.go b/service/ccm/credential_oauth.go index da559c173d..114f87d3e3 100644 --- a/service/ccm/credential_oauth.go +++ b/service/ccm/credential_oauth.go @@ -11,6 +11,7 @@ import ( "path/filepath" "runtime" "slices" + "strconv" "sync" "time" @@ -144,9 +145,9 @@ func (c *oauthCredentials) needsRefresh() bool { return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs } -func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, time.Duration, error) { if credentials.RefreshToken == "" { - return nil, E.New("refresh token is empty") + return nil, 0, E.New("refresh token is empty") } requestBody, err := json.Marshal(map[string]string{ @@ -155,7 +156,7 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau "client_id": oauth2ClientID, }) if err != nil { - return nil, E.Cause(err, "marshal request") + return nil, 0, E.Cause(err, "marshal request") } response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { @@ -168,17 +169,24 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau return request, nil }) if err != nil { - return nil, err + return nil, 0, err } defer response.Body.Close() if response.StatusCode == http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + retryDelay := time.Duration(-1) + if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" { + seconds, parseErr := strconv.ParseInt(retryAfter, 10, 64) + if parseErr == nil && seconds > 0 { + retryDelay = time.Duration(seconds) * time.Second + } + } + return nil, retryDelay, E.New("refresh rate limited: ", response.Status, " ", string(body)) } if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh failed: ", response.Status, " ", string(body)) + return nil, 0, E.New("refresh failed: ", response.Status, " ", string(body)) } var tokenResponse struct { @@ -188,7 +196,7 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau } err = json.NewDecoder(response.Body).Decode(&tokenResponse) if err != nil { - return nil, E.Cause(err, "decode response") + return nil, 0, E.Cause(err, "decode response") } newCredentials := *credentials @@ -198,7 +206,7 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau } newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000 - return &newCredentials, nil + return &newCredentials, 0, nil } func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 87f943cca9..e034dc0412 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -317,7 +317,6 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) - go selectedCredential.pollUsage() writeJSONError(w, r, http.StatusInternalServerError, "api_error", "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) return diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index 3622ac8ffa..c3e8335bba 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -48,6 +48,11 @@ type defaultCredential struct { statusSubscriber *observable.Subscriber[struct{}] + // Refresh rate-limit cooldown (protected by access mutex) + refreshRetryAt time.Time + refreshRetryError error + refreshBlocked bool + // Connection interruption onBecameUnusable func() interrupted bool @@ -201,16 +206,33 @@ func (c *defaultCredential) getAccessToken() (string, error) { return c.credentials.getAccessToken(), nil } + if c.refreshBlocked { + return "", c.refreshRetryError + } + if !c.refreshRetryAt.IsZero() && time.Now().Before(c.refreshRetryAt) { + return "", c.refreshRetryError + } + err = platformCanWriteCredentials(c.credentialPath) if err != nil { return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") } baseCredentials := cloneCredentials(c.credentials) - newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials) + newCredentials, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials) if err != nil { + if retryDelay < 0 { + c.refreshBlocked = true + c.refreshRetryError = err + } else if retryDelay > 0 { + c.refreshRetryAt = time.Now().Add(retryDelay) + c.refreshRetryError = err + } return "", err } + c.refreshRetryAt = time.Time{} + c.refreshRetryError = nil + c.refreshBlocked = false latestCredentials, latestErr := platformReadCredentials(c.credentialPath) if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { diff --git a/service/ocm/credential_file.go b/service/ocm/credential_file.go index b15417a46f..d5f23a7e23 100644 --- a/service/ocm/credential_file.go +++ b/service/ocm/credential_file.go @@ -108,6 +108,9 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.access.Lock() c.credentials = credentials + c.refreshRetryAt = time.Time{} + c.refreshRetryError = nil + c.refreshBlocked = false c.access.Unlock() c.stateAccess.Lock() diff --git a/service/ocm/credential_oauth.go b/service/ocm/credential_oauth.go index bb240b5aba..fd4692998b 100644 --- a/service/ocm/credential_oauth.go +++ b/service/ocm/credential_oauth.go @@ -9,6 +9,7 @@ import ( "os" "os/user" "path/filepath" + "strconv" "time" E "github.com/sagernet/sing/common/exceptions" @@ -119,9 +120,9 @@ func (c *oauthCredentials) needsRefresh() bool { return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour } -func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, time.Duration, error) { if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" { - return nil, E.New("refresh token is empty") + return nil, 0, E.New("refresh token is empty") } requestBody, err := json.Marshal(map[string]string{ @@ -131,7 +132,7 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau "scope": "openid profile email", }) if err != nil { - return nil, E.Cause(err, "marshal request") + return nil, 0, E.Cause(err, "marshal request") } response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { @@ -144,17 +145,24 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau return request, nil }) if err != nil { - return nil, err + return nil, 0, err } defer response.Body.Close() if response.StatusCode == http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + retryDelay := time.Duration(-1) + if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" { + seconds, parseErr := strconv.ParseInt(retryAfter, 10, 64) + if parseErr == nil && seconds > 0 { + retryDelay = time.Duration(seconds) * time.Second + } + } + return nil, retryDelay, E.New("refresh rate limited: ", response.Status, " ", string(body)) } if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh failed: ", response.Status, " ", string(body)) + return nil, 0, E.New("refresh failed: ", response.Status, " ", string(body)) } var tokenResponse struct { @@ -164,7 +172,7 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau } err = json.NewDecoder(response.Body).Decode(&tokenResponse) if err != nil { - return nil, E.Cause(err, "decode response") + return nil, 0, E.Cause(err, "decode response") } newCredentials := *credentials @@ -183,7 +191,7 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau now := time.Now() newCredentials.LastRefresh = &now - return &newCredentials, nil + return &newCredentials, 0, nil } func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index 8b50f748a1..d4a04457a6 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -285,7 +285,6 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) - go selectedCredential.pollUsage() writeJSONError(w, r, http.StatusInternalServerError, "api_error", "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) return From 99d9e06dd0a18dd61372a7028e0638b1dabe575e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 21 Mar 2026 10:31:17 +0800 Subject: [PATCH 72/96] fix(ccm,ocm): handle upstream 400 by marking external credentials rejected and polling default credentials External credentials returning 400 are marked unavailable for pollInterval duration; status stream/poll success clears the rejection early. Default credentials trigger a stale poll to let the usage API detect account issues without causing 429 storms. --- service/ccm/credential.go | 2 ++ service/ccm/credential_default.go | 2 ++ service/ccm/credential_external.go | 22 +++++++++++++++++++++- service/ccm/credential_provider.go | 13 +++++++++++++ service/ccm/service_handler.go | 11 +++++++++++ service/ocm/credential.go | 2 ++ service/ocm/credential_default.go | 2 ++ service/ocm/credential_external.go | 22 +++++++++++++++++++++- service/ocm/credential_provider.go | 13 +++++++++++++ service/ocm/service_handler.go | 11 +++++++++++ service/ocm/service_websocket.go | 6 ++++++ 11 files changed, 104 insertions(+), 2 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 89ad5bb97d..ed2c713d3f 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -66,6 +66,7 @@ type credentialState struct { consecutivePollFailures int usageAPIRetryDelay time.Duration unavailable bool + upstreamRejectedUntil time.Time lastCredentialLoadAttempt time.Time lastCredentialLoadError string } @@ -108,6 +109,7 @@ type Credential interface { fiveHourResetTime() time.Time weeklyResetTime() time.Time markRateLimited(resetAt time.Time) + markUpstreamRejected() earliestReset() time.Time unavailableError() error diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 24549a82a1..68a37f1226 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -354,6 +354,8 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) { c.emitStatusUpdate() } +func (c *defaultCredential) markUpstreamRejected() {} + func (c *defaultCredential) isUsable() bool { c.retryCredentialReloadIfNeeded() diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 186d6d9d69..94f353b6a2 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -273,6 +273,10 @@ func (c *externalCredential) isUsable() bool { c.stateAccess.RUnlock() return false } + if !c.state.upstreamRejectedUntil.IsZero() && time.Now().Before(c.state.upstreamRejectedUntil) { + c.stateAccess.RUnlock() + return false + } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { c.stateAccess.RUnlock() @@ -347,6 +351,18 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) { c.emitStatusUpdate() } +func (c *externalCredential) markUpstreamRejected() { + c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(c.pollInterval)) + c.stateAccess.Lock() + c.state.upstreamRejectedUntil = time.Now().Add(c.pollInterval) + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + c.emitStatusUpdate() +} + func (c *externalCredential) earliestReset() time.Time { c.stateAccess.RLock() defer c.stateAccess.RUnlock() @@ -475,6 +491,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } if hadData { c.state.consecutivePollFailures = 0 + c.state.upstreamRejectedUntil = time.Time{} c.state.lastUpdated = time.Now() } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { @@ -499,7 +516,8 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } func (c *externalCredential) checkTransitionLocked() bool { - unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0 + upstreamRejected := !c.state.upstreamRejectedUntil.IsZero() && time.Now().Before(c.state.upstreamRejectedUntil) + unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0 || upstreamRejected if unusable && !c.interrupted { c.interrupted = true return true @@ -636,6 +654,7 @@ func (c *externalCredential) pollUsage() { oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 + c.state.upstreamRejectedUntil = time.Time{} c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization if statusResponse.PlanWeight > 0 { @@ -744,6 +763,7 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 + c.state.upstreamRejectedUntil = time.Time{} c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization if statusResponse.PlanWeight > 0 { diff --git a/service/ccm/credential_provider.go b/service/ccm/credential_provider.go index 9fac91b2c6..640ced702a 100644 --- a/service/ccm/credential_provider.go +++ b/service/ccm/credential_provider.go @@ -17,6 +17,7 @@ type credentialProvider interface { onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool pollIfStale() + pollCredentialIfStale(credential Credential) allCredentials() []Credential close() } @@ -83,6 +84,12 @@ func (p *singleCredentialProvider) linkProviderInterrupt(_ Credential, _ credent } } +func (p *singleCredentialProvider) pollCredentialIfStale(credential Credential) { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) { + credential.pollUsage() + } +} + func (p *singleCredentialProvider) close() {} type sessionEntry struct { @@ -382,6 +389,12 @@ func (p *balancerProvider) pollIfStale() { } } +func (p *balancerProvider) pollCredentialIfStale(credential Credential) { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage() + } +} + func (p *balancerProvider) allCredentials() []Credential { return p.credentials } diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index e034dc0412..8e6d65e244 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -314,6 +314,17 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if response.StatusCode == http.StatusBadRequest { + if selectedCredential.isExternal() { + selectedCredential.markUpstreamRejected() + } else { + provider.pollCredentialIfStale(selectedCredential) + } + s.logger.ErrorContext(ctx, "upstream rejected from ", selectedCredential.tagName(), ": status ", response.StatusCode) + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential") + return + } + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) diff --git a/service/ocm/credential.go b/service/ocm/credential.go index b0226f2d07..0c4e56cddd 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -67,6 +67,7 @@ type credentialState struct { consecutivePollFailures int usageAPIRetryDelay time.Duration unavailable bool + upstreamRejectedUntil time.Time lastCredentialLoadAttempt time.Time lastCredentialLoadError string } @@ -109,6 +110,7 @@ type Credential interface { weeklyResetTime() time.Time fiveHourResetTime() time.Time markRateLimited(resetAt time.Time) + markUpstreamRejected() earliestReset() time.Time unavailableError() error diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index c3e8335bba..61f1314195 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -394,6 +394,8 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) { c.emitStatusUpdate() } +func (c *defaultCredential) markUpstreamRejected() {} + func (c *defaultCredential) isUsable() bool { c.retryCredentialReloadIfNeeded() diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 39dab378ee..f8a73684e3 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -298,6 +298,10 @@ func (c *externalCredential) isUsable() bool { c.stateAccess.RUnlock() return false } + if !c.state.upstreamRejectedUntil.IsZero() && time.Now().Before(c.state.upstreamRejectedUntil) { + c.stateAccess.RUnlock() + return false + } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { c.stateAccess.RUnlock() @@ -371,6 +375,18 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) { c.emitStatusUpdate() } +func (c *externalCredential) markUpstreamRejected() { + c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(c.pollInterval)) + c.stateAccess.Lock() + c.state.upstreamRejectedUntil = time.Now().Add(c.pollInterval) + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + c.emitStatusUpdate() +} + func (c *externalCredential) earliestReset() time.Time { c.stateAccess.RLock() defer c.stateAccess.RUnlock() @@ -514,6 +530,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } if hadData { c.state.consecutivePollFailures = 0 + c.state.upstreamRejectedUntil = time.Time{} c.state.lastUpdated = time.Now() } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { @@ -538,7 +555,8 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { } func (c *externalCredential) checkTransitionLocked() bool { - unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0 + upstreamRejected := !c.state.upstreamRejectedUntil.IsZero() && time.Now().Before(c.state.upstreamRejectedUntil) + unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0 || upstreamRejected if unusable && !c.interrupted { c.interrupted = true return true @@ -678,6 +696,7 @@ func (c *externalCredential) pollUsage() { oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 + c.state.upstreamRejectedUntil = time.Time{} c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization if statusResponse.FiveHourReset > 0 { @@ -786,6 +805,7 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 + c.state.upstreamRejectedUntil = time.Time{} c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization if statusResponse.FiveHourReset > 0 { diff --git a/service/ocm/credential_provider.go b/service/ocm/credential_provider.go index 5d67eb0326..714e44ab7c 100644 --- a/service/ocm/credential_provider.go +++ b/service/ocm/credential_provider.go @@ -17,6 +17,7 @@ type credentialProvider interface { onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool pollIfStale() + pollCredentialIfStale(credential Credential) allCredentials() []Credential close() } @@ -83,6 +84,12 @@ func (p *singleCredentialProvider) linkProviderInterrupt(_ Credential, _ credent } } +func (p *singleCredentialProvider) pollCredentialIfStale(credential Credential) { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) { + credential.pollUsage() + } +} + func (p *singleCredentialProvider) close() {} type sessionEntry struct { @@ -409,6 +416,12 @@ func (p *balancerProvider) pollIfStale() { } } +func (p *balancerProvider) pollCredentialIfStale(credential Credential) { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage() + } +} + func (p *balancerProvider) allCredentials() []Credential { return p.credentials } diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index d4a04457a6..c2e90a582d 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -282,6 +282,17 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { selectedCredential.updateStateFromHeaders(response.Header) + if response.StatusCode == http.StatusBadRequest { + if selectedCredential.isExternal() { + selectedCredential.markUpstreamRejected() + } else { + provider.pollCredentialIfStale(selectedCredential) + } + s.logger.ErrorContext(ctx, "upstream rejected from ", selectedCredential.tagName(), ": status ", response.StatusCode) + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential") + return + } + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 6568f4dea2..96d7e58d3a 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -264,6 +264,12 @@ func (s *Service) handleWebSocket( selectedCredential = nextCredential continue } + if statusCode == http.StatusBadRequest && selectedCredential.isExternal() { + selectedCredential.markUpstreamRejected() + s.logger.ErrorContext(ctx, "upstream rejected websocket from ", selectedCredential.tagName(), ": status ", statusCode) + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential") + return + } if statusCode > 0 && statusResponseBody != "" { s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody) } else { From 53f832330dded92fc1a552970f9436c09ad80e9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 21 Mar 2026 10:45:24 +0800 Subject: [PATCH 73/96] fix(ccm): adapt to Claude Code v2.1.78 metadata format, separate state from credentials Claude Code v2.1.78 changed metadata.user_id from a template literal (`user_${id}_account_${uuid}_session_${sid}`) to a JSON-encoded object (`JSON.stringify({device_id, account_uuid, session_id})`), breaking session ID extraction via `_session_` substring match. - Fix extractCCMSessionID to try JSON parse first, fallback to legacy - Remove subscriptionType/rateLimitTier/isMax from oauthCredentials (profile state does not belong in auth credentials) - Add state_path option for persisting profile state across restarts - Parse account.uuid from /api/oauth/profile response - Inject account_uuid into forwarded requests when client sends it empty (happens when using ANTHROPIC_AUTH_TOKEN instead of Claude AI OAuth) --- option/ccm.go | 1 + service/ccm/credential.go | 1 + service/ccm/credential_default.go | 77 ++++++++++++++++++++++++++-- service/ccm/credential_file.go | 2 - service/ccm/credential_oauth.go | 24 +++++---- service/ccm/credential_state_file.go | 64 +++++++++++++++++++++++ service/ccm/service_handler.go | 23 +++++++++ 7 files changed, 175 insertions(+), 17 deletions(-) create mode 100644 service/ccm/credential_state_file.go diff --git a/option/ccm.go b/option/ccm.go index 7a4f0709f3..96200248a9 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -77,6 +77,7 @@ func (c *CCMCredential) UnmarshalJSON(bytes []byte) error { type CCMDefaultCredentialOptions struct { CredentialPath string `json:"credential_path,omitempty"` + StatePath string `json:"state_path,omitempty"` UsagesPath string `json:"usages_path,omitempty"` Detour string `json:"detour,omitempty"` Reserve5h uint8 `json:"reserve_5h"` diff --git a/service/ccm/credential.go b/service/ccm/credential.go index ed2c713d3f..f2d4003f12 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -59,6 +59,7 @@ type credentialState struct { weeklyReset time.Time hardRateLimited bool rateLimitResetAt time.Time + accountUUID string accountType string rateLimitTier string remotePlanWeight float64 diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 68a37f1226..9c19a2cc91 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -29,6 +29,7 @@ type defaultCredential struct { serviceContext context.Context credentialPath string credentialFilePath string + statePath string credentials *oauthCredentials access sync.RWMutex state credentialState @@ -106,6 +107,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef tag: tag, serviceContext: ctx, credentialPath: options.CredentialPath, + statePath: options.StatePath, cap5h: cap5h, capWeekly: capWeekly, forwardHTTPClient: httpClient, @@ -130,6 +132,7 @@ func (c *defaultCredential) start() error { return E.Cause(err, "resolve credential path for ", c.tag) } c.credentialFilePath = credentialFilePath + c.loadPersistedState() err = c.ensureCredentialWatcher() if err != nil { c.logger.Debug("start credential watcher for ", c.tag, ": ", err) @@ -238,8 +241,6 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.state.unavailable = false c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" - c.state.accountType = latestCredentials.SubscriptionType - c.state.rateLimitTier = latestCredentials.RateLimitTier c.checkTransitionLocked() shouldEmit := before != c.statusSnapshotLocked() c.stateAccess.Unlock() @@ -258,8 +259,6 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.state.unavailable = false c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" - c.state.accountType = newCredentials.SubscriptionType - c.state.rateLimitTier = newCredentials.RateLimitTier c.checkTransitionLocked() shouldEmit := before != c.statusSnapshotLocked() c.stateAccess.Unlock() @@ -663,6 +662,12 @@ func (c *defaultCredential) pollUsage() { } } +// fetchProfile calls GET /api/oauth/profile to retrieve account and organization info. +// Same endpoint used by Claude Code (@anthropic-ai/claude-code @2.1.81): +// +// ref: cli.js GB() — fetches profile +// ref: cli.js AH8() / fetchProfileInfo — parses organization_type, rate_limit_tier +// ref: cli.js EX1() / populateOAuthAccountInfoIfNeeded — stores account.uuid func (c *defaultCredential) fetchProfile(httpClient *http.Client, accessToken string) { ctx := c.serviceContext response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { @@ -686,6 +691,9 @@ func (c *defaultCredential) fetchProfile(httpClient *http.Client, accessToken st } var profileResponse struct { + Account *struct { + UUID string `json:"uuid"` + } `json:"account"` Organization *struct { OrganizationType string `json:"organization_type"` RateLimitTier string `json:"rate_limit_tier"` @@ -711,6 +719,9 @@ func (c *defaultCredential) fetchProfile(httpClient *http.Client, accessToken st c.stateAccess.Lock() before := c.statusSnapshotLocked() + if profileResponse.Account != nil && profileResponse.Account.UUID != "" { + c.state.accountUUID = profileResponse.Account.UUID + } if accountType != "" && c.state.accountType == "" { c.state.accountType = accountType } @@ -723,6 +734,7 @@ func (c *defaultCredential) fetchProfile(httpClient *http.Client, accessToken st if shouldEmit { c.emitStatusUpdate() } + c.savePersistedState() c.logger.Info("fetched profile for ", c.tag, ": type=", resolvedAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(resolvedAccountType, rateLimitTier)) } @@ -781,6 +793,7 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt proxyURL := claudeAPIBaseURL + original.URL.RequestURI() var body io.Reader if bodyBytes != nil { + bodyBytes = c.injectAccountUUID(bodyBytes) body = bytes.NewReader(bodyBytes) } else { body = original.Body @@ -816,3 +829,59 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt return proxyRequest, nil } + +// injectAccountUUID fills in the account_uuid field in metadata.user_id +// when the client sends it empty (e.g. using ANTHROPIC_AUTH_TOKEN). +// +// Claude Code >= 2.1.78 (@anthropic-ai/claude-code) sets metadata as: +// +// {user_id: JSON.stringify({device_id, account_uuid, session_id})} +// +// ref: cli.js L66() — metadata constructor +// +// account_uuid is populated from oauthAccount.accountUuid which comes from +// the /api/oauth/profile endpoint (ref: cli.js EX1() → fP6()). +// When the client uses ANTHROPIC_AUTH_TOKEN instead of Claude AI OAuth, +// account_uuid is empty. We inject it from the fetchProfile result. +func (c *defaultCredential) injectAccountUUID(bodyBytes []byte) []byte { + c.stateAccess.RLock() + accountUUID := c.state.accountUUID + c.stateAccess.RUnlock() + if accountUUID == "" { + return bodyBytes + } + + var body struct { + Metadata struct { + UserID string `json:"user_id"` + } `json:"metadata"` + } + if json.Unmarshal(bodyBytes, &body) != nil || body.Metadata.UserID == "" { + return bodyBytes + } + + var userIDObject map[string]any + if json.Unmarshal([]byte(body.Metadata.UserID), &userIDObject) != nil { + return bodyBytes + } + existing, _ := userIDObject["account_uuid"].(string) + if existing != "" { + return bodyBytes + } + userIDObject["account_uuid"] = accountUUID + newUserID, err := json.Marshal(userIDObject) + if err != nil { + return bodyBytes + } + + newUserIDStr := string(newUserID) + oldUserIDJSON, err := json.Marshal(body.Metadata.UserID) + if err != nil { + return bodyBytes + } + newUserIDJSON, err := json.Marshal(newUserIDStr) + if err != nil { + return bodyBytes + } + return bytes.Replace(bodyBytes, oldUserIDJSON, newUserIDJSON, 1) +} diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index 3f67eaf13e..23e71d5442 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -117,8 +117,6 @@ func (c *defaultCredential) reloadCredentials(force bool) error { before := c.statusSnapshotLocked() c.state.unavailable = false c.state.lastCredentialLoadError = "" - c.state.accountType = credentials.SubscriptionType - c.state.rateLimitTier = credentials.RateLimitTier c.checkTransitionLocked() shouldEmit := before != c.statusSnapshotLocked() c.stateAccess.Unlock() diff --git a/service/ccm/credential_oauth.go b/service/ccm/credential_oauth.go index 114f87d3e3..9d0b5146e0 100644 --- a/service/ccm/credential_oauth.go +++ b/service/ccm/credential_oauth.go @@ -128,14 +128,19 @@ func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) err return os.WriteFile(path, data, 0o600) } +// oauthCredentials mirrors the claudeAiOauth object in Claude Code's +// credential file ($CLAUDE_CONFIG_DIR/.credentials.json). +// +// ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6() / refreshOAuthToken +// +// Note: subscriptionType, rateLimitTier, and isMax were removed from this +// struct — they are profile state, not auth credentials. Claude Code also +// stores them here, but we persist them separately via state_path instead. type oauthCredentials struct { - AccessToken string `json:"accessToken"` - RefreshToken string `json:"refreshToken"` - ExpiresAt int64 `json:"expiresAt"` - Scopes []string `json:"scopes,omitempty"` - SubscriptionType string `json:"subscriptionType,omitempty"` - RateLimitTier string `json:"rateLimitTier,omitempty"` - IsMax bool `json:"isMax,omitempty"` + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresAt int64 `json:"expiresAt"` + Scopes []string `json:"scopes,omitempty"` } func (c *oauthCredentials) needsRefresh() bool { @@ -225,8 +230,5 @@ func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { return left.AccessToken == right.AccessToken && left.RefreshToken == right.RefreshToken && left.ExpiresAt == right.ExpiresAt && - slices.Equal(left.Scopes, right.Scopes) && - left.SubscriptionType == right.SubscriptionType && - left.RateLimitTier == right.RateLimitTier && - left.IsMax == right.IsMax + slices.Equal(left.Scopes, right.Scopes) } diff --git a/service/ccm/credential_state_file.go b/service/ccm/credential_state_file.go new file mode 100644 index 0000000000..630e261887 --- /dev/null +++ b/service/ccm/credential_state_file.go @@ -0,0 +1,64 @@ +package ccm + +import ( + "encoding/json" + "os" +) + +// persistedState holds profile data fetched from /api/oauth/profile, +// persisted to state_path so it survives restarts without re-fetching. +// +// Claude Code (@anthropic-ai/claude-code @2.1.81) stores equivalent data in +// its config file (~/.claude/.config.json) under the oauthAccount key: +// +// ref: cli.js fP6() / storeOAuthAccountInfo — writes accountUuid, billingType, etc. +// ref: cli.js P8() — reads config from $CLAUDE_CONFIG_DIR/.config.json +type persistedState struct { + AccountUUID string `json:"account_uuid,omitempty"` + AccountType string `json:"account_type,omitempty"` + RateLimitTier string `json:"rate_limit_tier,omitempty"` +} + +func (c *defaultCredential) loadPersistedState() { + if c.statePath == "" { + return + } + data, err := os.ReadFile(c.statePath) + if err != nil { + return + } + var state persistedState + err = json.Unmarshal(data, &state) + if err != nil { + return + } + c.stateAccess.Lock() + if state.AccountUUID != "" { + c.state.accountUUID = state.AccountUUID + } + if state.AccountType != "" { + c.state.accountType = state.AccountType + } + if state.RateLimitTier != "" { + c.state.rateLimitTier = state.RateLimitTier + } + c.stateAccess.Unlock() +} + +func (c *defaultCredential) savePersistedState() { + if c.statePath == "" { + return + } + c.stateAccess.RLock() + state := persistedState{ + AccountUUID: c.state.accountUUID, + AccountType: c.state.accountType, + RateLimitTier: c.state.rateLimitTier, + } + c.stateAccess.RUnlock() + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return + } + os.WriteFile(c.statePath, data, 0o600) +} diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 8e6d65e244..07b0556d46 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -69,6 +69,19 @@ func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { } } +// extractCCMSessionID extracts the session ID from the request body's metadata.user_id field. +// +// Claude Code >= 2.1.78 (@anthropic-ai/claude-code) encodes user_id as: +// +// JSON.stringify({device_id, account_uuid, session_id, ...extras}) +// +// ref: cli.js L66() — metadata constructor +// +// Claude Code < 2.1.78 used a template literal: +// +// `user_${deviceId}_account_${accountUuid}_session_${sessionId}` +// +// ref: cli.js qs() — old metadata constructor func extractCCMSessionID(bodyBytes []byte) string { var body struct { Metadata struct { @@ -80,6 +93,16 @@ func extractCCMSessionID(bodyBytes []byte) string { return "" } userID := body.Metadata.UserID + + // v2.1.78+ JSON object format + var userIDObject struct { + SessionID string `json:"session_id"` + } + if json.Unmarshal([]byte(userID), &userIDObject) == nil && userIDObject.SessionID != "" { + return userIDObject.SessionID + } + + // legacy template literal format sessionIndex := strings.LastIndex(userID, "_session_") if sessionIndex < 0 { return "" From 29b901a8b3048ad27dda928a00ac0ff00bff9d67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 21 Mar 2026 11:00:05 +0800 Subject: [PATCH 74/96] fix(ccm): robust account UUID injection and session ID validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace bytes.Replace-based UUID injection with proper JSON unmarshal/re-marshal through map[string]json.RawMessage — the old approach silently failed when the body used non-canonical JSON escaping. Return 500 when metadata.user_id is present but in an unrecognized format, instead of silently passing through with an empty session ID. --- service/ccm/credential_default.go | 64 +++++++++++++++++++++++-------- service/ccm/service_handler.go | 28 ++++++++++---- 2 files changed, 69 insertions(+), 23 deletions(-) diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 9c19a2cc91..57e0ca912e 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -851,37 +851,71 @@ func (c *defaultCredential) injectAccountUUID(bodyBytes []byte) []byte { return bodyBytes } - var body struct { - Metadata struct { - UserID string `json:"user_id"` - } `json:"metadata"` + var body map[string]json.RawMessage + err := json.Unmarshal(bodyBytes, &body) + if err != nil { + return bodyBytes + } + metadataRaw, hasMetadata := body["metadata"] + if !hasMetadata { + return bodyBytes + } + + var metadata map[string]json.RawMessage + err = json.Unmarshal(metadataRaw, &metadata) + if err != nil { + return bodyBytes + } + userIDRaw, hasUserID := metadata["user_id"] + if !hasUserID { + return bodyBytes + } + + var userIDStr string + err = json.Unmarshal(userIDRaw, &userIDStr) + if err != nil || userIDStr == "" { + return bodyBytes } - if json.Unmarshal(bodyBytes, &body) != nil || body.Metadata.UserID == "" { + + var userIDObject map[string]json.RawMessage + err = json.Unmarshal([]byte(userIDStr), &userIDObject) + if err != nil { return bodyBytes } - var userIDObject map[string]any - if json.Unmarshal([]byte(body.Metadata.UserID), &userIDObject) != nil { + existingRaw, hasExisting := userIDObject["account_uuid"] + if hasExisting { + var existing string + if json.Unmarshal(existingRaw, &existing) == nil && existing != "" { + return bodyBytes + } + } + + accountUUIDJSON, err := json.Marshal(accountUUID) + if err != nil { return bodyBytes } - existing, _ := userIDObject["account_uuid"].(string) - if existing != "" { + userIDObject["account_uuid"] = json.RawMessage(accountUUIDJSON) + + newUserIDBytes, err := json.Marshal(userIDObject) + if err != nil { return bodyBytes } - userIDObject["account_uuid"] = accountUUID - newUserID, err := json.Marshal(userIDObject) + newUserIDRaw, err := json.Marshal(string(newUserIDBytes)) if err != nil { return bodyBytes } + metadata["user_id"] = json.RawMessage(newUserIDRaw) - newUserIDStr := string(newUserID) - oldUserIDJSON, err := json.Marshal(body.Metadata.UserID) + newMetadataBytes, err := json.Marshal(metadata) if err != nil { return bodyBytes } - newUserIDJSON, err := json.Marshal(newUserIDStr) + body["metadata"] = json.RawMessage(newMetadataBytes) + + newBodyBytes, err := json.Marshal(body) if err != nil { return bodyBytes } - return bytes.Replace(bodyBytes, oldUserIDJSON, newUserIDJSON, 1) + return newBodyBytes } diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 07b0556d46..9ce78db530 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -82,15 +82,21 @@ func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { // `user_${deviceId}_account_${accountUuid}_session_${sessionId}` // // ref: cli.js qs() — old metadata constructor -func extractCCMSessionID(bodyBytes []byte) string { +// +// Returns ("", nil) when body has no metadata.user_id (non-message endpoints). +// Returns error when user_id is present but in an unrecognized format. +func extractCCMSessionID(bodyBytes []byte) (string, error) { var body struct { - Metadata struct { + Metadata *struct { UserID string `json:"user_id"` } `json:"metadata"` } err := json.Unmarshal(bodyBytes, &body) if err != nil { - return "" + return "", nil + } + if body.Metadata == nil || body.Metadata.UserID == "" { + return "", nil } userID := body.Metadata.UserID @@ -99,15 +105,16 @@ func extractCCMSessionID(bodyBytes []byte) string { SessionID string `json:"session_id"` } if json.Unmarshal([]byte(userID), &userIDObject) == nil && userIDObject.SessionID != "" { - return userIDObject.SessionID + return userIDObject.SessionID, nil } // legacy template literal format sessionIndex := strings.LastIndex(userID, "_session_") - if sessionIndex < 0 { - return "" + if sessionIndex >= 0 { + return userID[sessionIndex+len("_session_"):], nil } - return userID[sessionIndex+len("_session_"):] + + return "", E.New("unrecognized metadata.user_id format: ", userID) } func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -181,7 +188,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { messagesCount = len(request.Messages) } - sessionID = extractCCMSessionID(bodyBytes) + sessionID, err = extractCCMSessionID(bodyBytes) + if err != nil { + s.logger.ErrorContext(ctx, "invalid metadata format: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "invalid metadata format") + return + } r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } From f172a575b7383cd0fcea2b4c347342fc89dca4c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 21 Mar 2026 11:07:43 +0800 Subject: [PATCH 75/96] fix(ccm): log assigned credential for each distinct model per session --- service/ccm/service.go | 22 ++++++++++++++++++++++ service/ccm/service_handler.go | 24 ++++++++++++++++++------ 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/service/ccm/service.go b/service/ccm/service.go index 8a9d8f17f0..d3f76381f0 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -5,6 +5,8 @@ import ( "encoding/json" "net/http" "strings" + "sync" + "time" "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" @@ -163,10 +165,29 @@ type Service struct { allCredentials []Credential userConfigMap map[string]*option.CCMUser + sessionModelAccess sync.Mutex + sessionModels map[sessionModelKey]time.Time + statusSubscriber *observable.Subscriber[struct{}] statusObserver *observable.Observer[struct{}] } +type sessionModelKey struct { + sessionID string + model string +} + +func (s *Service) cleanSessionModels() { + now := time.Now() + s.sessionModelAccess.Lock() + for key, createdAt := range s.sessionModels { + if now.Sub(createdAt) > sessionExpiry { + delete(s.sessionModels, key) + } + } + s.sessionModelAccess.Unlock() +} + func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) { initCCMUserAgent(logger) @@ -212,6 +233,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio Listen: options.ListenOptions, }), userManager: userManager, + sessionModels: make(map[sessionModelKey]time.Time), statusSubscriber: statusSubscriber, statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8), } diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 9ce78db530..94af4e90e8 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -218,6 +218,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale() + s.cleanSessionModels() anthropicBetaHeader := r.Header.Get("anthropic-beta") if isFastModeRequest(anthropicBetaHeader) { @@ -235,7 +236,22 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) return } - if isNew { + modelDisplay := requestModel + if requestModel != "" && isExtendedContextRequest(anthropicBetaHeader) { + modelDisplay += "[1m]" + } + isNewModel := false + if sessionID != "" && modelDisplay != "" { + key := sessionModelKey{sessionID, modelDisplay} + s.sessionModelAccess.Lock() + _, exists := s.sessionModels[key] + if !exists { + s.sessionModels[key] = time.Now() + isNewModel = true + } + s.sessionModelAccess.Unlock() + } + if isNew || isNewModel { logParts := []any{"assigned credential ", selectedCredential.tagName()} if sessionID != "" { logParts = append(logParts, " for session ", sessionID) @@ -243,11 +259,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if username != "" { logParts = append(logParts, " by user ", username) } - if requestModel != "" { - modelDisplay := requestModel - if isExtendedContextRequest(anthropicBetaHeader) { - modelDisplay += "[1m]" - } + if modelDisplay != "" { logParts = append(logParts, ", model=", modelDisplay) } s.logger.DebugContext(ctx, logParts...) From 0950783479031befd86956867800ac9ac0eb68bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 21 Mar 2026 11:42:49 +0800 Subject: [PATCH 76/96] fix(ccm,ocm): exclude unusable credentials from status aggregation computeAggregatedUtilization used isAvailable() which only checks permanent unavailability, so credentials rejected by upstream 400 still had their planWeight included in the total, inflating reported capacity and diluting utilization. --- service/ccm/service_status.go | 2 +- service/ocm/service_status.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index 424afa1279..41256b7f8b 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -176,7 +176,7 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user var totalWeightedHoursUntil5hReset, total5hResetWeight float64 var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64 for _, credential := range provider.allCredentials() { - if !credential.isAvailable() { + if !credential.isUsable() { continue } if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index 3e7e2ff65f..092209e981 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -176,7 +176,7 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user var totalWeightedHoursUntil5hReset, total5hResetWeight float64 var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64 for _, credential := range provider.allCredentials() { - if !credential.isAvailable() { + if !credential.isUsable() { continue } if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { From 084a6f1302147791beeff0533bb3a063b16d4687 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 22 Mar 2026 06:02:55 +0800 Subject: [PATCH 77/96] fix(ccm): align OAuth token refresh with Claude Code v2.1.81 After re-login with newer Claude Code (v2.1.75+), CCM refresh requests returned persistent 429s. Root cause: CCM omitted the `scope` parameter that the server now requires for tokens with `user:file_upload` scope. Changes to fully match Claude Code's OAuth behavior: - Add `scope` parameter to token refresh request body - Parse `scope` from refresh response and store back - Add `subscriptionType`/`rateLimitTier` to credential struct to preserve Claude Code's profile state on write-back - Change credential file write to read-modify-write, preserving other top-level JSON keys (matches Claude Code's BP6 pattern) - Same for macOS keychain write path - Increase token expiry buffer from 1 min to 5 min (matching CC's isOAuthTokenExpired with 300s buffer) - Add cross-process mkdir-based file lock compatible with Claude Code's proper-lockfile protocol (~/.claude.lock) - Add post-failure recovery: re-read credentials from disk after refresh failure in case another process succeeded - Add 401/403 "OAuth token has been revoked" recovery in proxy handler: reload credentials and retry once --- service/ccm/credential_darwin.go | 78 ++++++++++++++++--------- service/ccm/credential_default.go | 48 ++++++++++++++++ service/ccm/credential_lock.go | 84 +++++++++++++++++++++++++++ service/ccm/credential_oauth.go | 94 +++++++++++++++++++++++++------ service/ccm/service_handler.go | 37 ++++++++++++ 5 files changed, 297 insertions(+), 44 deletions(-) create mode 100644 service/ccm/credential_lock.go diff --git a/service/ccm/credential_darwin.go b/service/ccm/credential_darwin.go index aef10c8748..d025fd47c1 100644 --- a/service/ccm/credential_darwin.go +++ b/service/ccm/credential_darwin.go @@ -76,41 +76,65 @@ func platformCanWriteCredentials(customPath string) error { return checkCredentialFileWritable(customPath) } -func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error { +// platformWriteCredentials performs a read-modify-write on the keychain entry, +// preserving any fields or top-level keys not managed by CCM. +// +// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179444-179454) — read-modify-write +func platformWriteCredentials(credentials *oauthCredentials, customPath string) error { if customPath != "" { - return writeCredentialsToFile(oauthCredentials, customPath) + return writeCredentialsToFile(credentials, customPath) } userInfo, err := getRealUser() if err == nil { - data, err := json.Marshal(map[string]any{"claudeAiOauth": oauthCredentials}) + serviceName := getKeychainServiceName() + + existing := make(map[string]json.RawMessage) + query := keychain.NewItem() + query.SetSecClass(keychain.SecClassGenericPassword) + query.SetService(serviceName) + query.SetAccount(userInfo.Username) + query.SetMatchLimit(keychain.MatchLimitOne) + query.SetReturnData(true) + results, queryErr := keychain.QueryItem(query) + if queryErr == nil && len(results) == 1 { + _ = json.Unmarshal(results[0].Data, &existing) + } + + credentialData, err := json.Marshal(credentials) + if err != nil { + return E.Cause(err, "marshal credentials") + } + existing["claudeAiOauth"] = credentialData + data, err := json.Marshal(existing) + if err != nil { + return E.Cause(err, "marshal credential container") + } + + item := keychain.NewItem() + item.SetSecClass(keychain.SecClassGenericPassword) + item.SetService(serviceName) + item.SetAccount(userInfo.Username) + item.SetData(data) + item.SetAccessible(keychain.AccessibleWhenUnlocked) + + err = keychain.AddItem(item) if err == nil { - serviceName := getKeychainServiceName() - item := keychain.NewItem() - item.SetSecClass(keychain.SecClassGenericPassword) - item.SetService(serviceName) - item.SetAccount(userInfo.Username) - item.SetData(data) - item.SetAccessible(keychain.AccessibleWhenUnlocked) - - err = keychain.AddItem(item) - if err == nil { - return nil - } + return nil + } - if err == keychain.ErrorDuplicateItem { - query := keychain.NewItem() - query.SetSecClass(keychain.SecClassGenericPassword) - query.SetService(serviceName) - query.SetAccount(userInfo.Username) + if err == keychain.ErrorDuplicateItem { + updateQuery := keychain.NewItem() + updateQuery.SetSecClass(keychain.SecClassGenericPassword) + updateQuery.SetService(serviceName) + updateQuery.SetAccount(userInfo.Username) - updateItem := keychain.NewItem() - updateItem.SetData(data) + updateItem := keychain.NewItem() + updateItem.SetData(data) - updateErr := keychain.UpdateItem(query, updateItem) - if updateErr == nil { - return nil - } + updateErr := keychain.UpdateItem(updateQuery, updateItem) + if updateErr == nil { + return nil } } } @@ -119,5 +143,5 @@ func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath str if err != nil { return err } - return writeCredentialsToFile(oauthCredentials, defaultPath) + return writeCredentialsToFile(credentials, defaultPath) } diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 57e0ca912e..4c571a38ed 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -9,6 +9,7 @@ import ( "math" "net" "net/http" + "slices" "strconv" "sync" "time" @@ -29,6 +30,7 @@ type defaultCredential struct { serviceContext context.Context credentialPath string credentialFilePath string + configDir string statePath string credentials *oauthCredentials access sync.RWMutex @@ -132,6 +134,7 @@ func (c *defaultCredential) start() error { return E.Cause(err, "resolve credential path for ", c.tag) } c.credentialFilePath = credentialFilePath + c.configDir = resolveConfigDir(c.credentialPath, credentialFilePath) c.loadPersistedState() err = c.ensureCredentialWatcher() if err != nil { @@ -176,6 +179,7 @@ func (c *defaultCredential) statusSnapshotLocked() statusSnapshot { func (c *defaultCredential) getAccessToken() (string, error) { c.retryCredentialReloadIfNeeded() + // Fast path: cached token is still valid c.access.RLock() if c.credentials != nil && !c.credentials.needsRefresh() { token := c.credentials.AccessToken @@ -184,6 +188,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { } c.access.RUnlock() + // Reload from disk — Claude Code or another process may have refreshed err := c.reloadCredentials(true) if err == nil { c.access.RLock() @@ -195,6 +200,41 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.access.RUnlock() } + // ref (@anthropic-ai/claude-code @2.1.81): cli.js _P1 line 179526 + // Claude Code skips refresh for tokens without user:inference scope. + // Return existing token (may be expired); 401 recovery is the safety net. + c.access.RLock() + if c.credentials != nil && !slices.Contains(c.credentials.Scopes, "user:inference") { + token := c.credentials.AccessToken + c.access.RUnlock() + return token, nil + } + c.access.RUnlock() + + // Acquire cross-process lock before refresh (outside Go mutex to avoid holding mutex during sleep) + // ref: cli.js _P1 (line 179534-179536) — proper-lockfile lock on config dir + release, lockErr := acquireCredentialLock(c.configDir) + if lockErr != nil { + c.logger.Debug("acquire credential lock for ", c.tag, ": ", lockErr) + release = func() {} + } + defer release() + + // ref: cli.js _P1 (line 179559-179562) — re-read after lock, skip if race resolved + _ = c.reloadCredentials(true) + c.access.RLock() + noRefreshToken := c.credentials == nil || c.credentials.RefreshToken == "" + raceResolved := !noRefreshToken && !c.credentials.needsRefresh() + var racedToken string + if (noRefreshToken || raceResolved) && c.credentials != nil { + racedToken = c.credentials.AccessToken + } + c.access.RUnlock() + if noRefreshToken || raceResolved { + return racedToken, nil + } + + // Slow path: acquire Go mutex and refresh c.access.Lock() defer c.access.Unlock() @@ -227,6 +267,14 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.refreshRetryAt = time.Now().Add(retryDelay) c.refreshRetryError = err } + // ref: cli.js _P1 (line 179568-179573) — post-failure recovery: + // re-read from disk; if another process refreshed successfully, use that. + // Cannot call reloadCredentials here (deadlock: already holding c.access). + latestCredentials, readErr := platformReadCredentials(c.credentialPath) + if readErr == nil && latestCredentials != nil && !latestCredentials.needsRefresh() { + c.credentials = latestCredentials + return latestCredentials.AccessToken, nil + } return "", err } c.refreshRetryAt = time.Time{} diff --git a/service/ccm/credential_lock.go b/service/ccm/credential_lock.go new file mode 100644 index 0000000000..2374437c4f --- /dev/null +++ b/service/ccm/credential_lock.go @@ -0,0 +1,84 @@ +package ccm + +import ( + "math/rand/v2" + "os" + "path/filepath" + "time" + + E "github.com/sagernet/sing/common/exceptions" +) + +// acquireCredentialLock acquires a cross-process lock compatible with Claude Code's +// proper-lockfile protocol. The lock is a directory created via mkdir (atomic on +// POSIX filesystems). +// +// ref (@anthropic-ai/claude-code @2.1.81): cli.js _P1 (line 179530-179577) +// ref: proper-lockfile mkdir protocol (cli.js:43570) +// ref: proper-lockfile default options — stale=10s, update=stale/2=5s, realpath=true (cli.js:43661-43664) +// +// Claude Code locks d1() (= ~/.claude config dir). The lock directory is +// .lock (proper-lockfile default: .lock). +// Manual retry: initial + 5 retries = 6 total, delay 1+rand(1s) per retry. +func acquireCredentialLock(configDir string) (func(), error) { + // ref: cli.js _P1 line 179531 — mkdir -p configDir before locking + os.MkdirAll(configDir, 0o700) + // ref: proper-lockfile realpath:true (cli.js:43664) — resolve symlinks before appending .lock + resolved, err := filepath.EvalSymlinks(configDir) + if err != nil { + resolved = filepath.Clean(configDir) + } + lockPath := resolved + ".lock" + // ref: cli.js _P1 line 179539-179543 — initial + 5 retries = 6 total attempts + for attempt := 0; attempt < 6; attempt++ { + if attempt > 0 { + // ref: cli.js _P1 line 179542 — 1000 + Math.random() * 1000 + delay := time.Second + time.Duration(rand.IntN(1000))*time.Millisecond + time.Sleep(delay) + } + err = os.Mkdir(lockPath, 0o755) + if err == nil { + return startLockHeartbeat(lockPath), nil + } + if !os.IsExist(err) { + return nil, E.Cause(err, "create lock directory") + } + // ref: proper-lockfile stale check (cli.js:43603-43604) + // stale threshold = 10s (cli.js:43662) + info, statErr := os.Stat(lockPath) + if statErr != nil { + continue + } + if time.Since(info.ModTime()) > 10*time.Second { + os.Remove(lockPath) + } + } + return nil, E.New("credential lock timeout") +} + +// startLockHeartbeat spawns a goroutine that touches the lock directory's mtime +// every 5 seconds to prevent stale detection by other processes. +// +// ref: proper-lockfile update interval = stale/2 = 5s (cli.js:43662-43663) +// +// Returns a release function that stops the heartbeat and removes the lock directory. +func startLockHeartbeat(lockPath string) func() { + done := make(chan struct{}) + go func() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + now := time.Now() + os.Chtimes(lockPath, now, now) + case <-done: + return + } + } + }() + return func() { + close(done) + os.Remove(lockPath) + } +} diff --git a/service/ccm/credential_oauth.go b/service/ccm/credential_oauth.go index 9d0b5146e0..71caf62673 100644 --- a/service/ccm/credential_oauth.go +++ b/service/ccm/credential_oauth.go @@ -12,6 +12,7 @@ import ( "runtime" "slices" "strconv" + "strings" "sync" "time" @@ -23,10 +24,32 @@ const ( oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" oauth2TokenURL = "https://platform.claude.com/v1/oauth/token" claudeAPIBaseURL = "https://api.anthropic.com" - tokenRefreshBufferMs = 60000 anthropicBetaOAuthValue = "oauth-2025-04-20" + + // ref (@anthropic-ai/claude-code @2.1.81): cli.js vB (line 172879) + tokenRefreshBufferMs = 300000 ) +// ref (@anthropic-ai/claude-code @2.1.81): cli.js q78 (line 33167) +// These scopes may change across Claude Code versions. +var defaultOAuthScopes = []string{ + "user:profile", "user:inference", "user:sessions:claude_code", + "user:mcp_servers", "user:file_upload", +} + +// resolveRefreshScopes determines which scopes to send in the token refresh request. +// +// ref (@anthropic-ai/claude-code @2.1.81): cli.js NR() (line 172693) + mB6 scope logic (line 172761) +// +// Claude Code behavior: if stored scopes include "user:inference", send default +// scopes; otherwise send the stored scopes verbatim. +func resolveRefreshScopes(stored []string) string { + if len(stored) == 0 || slices.Contains(stored, "user:inference") { + return strings.Join(defaultOAuthScopes, " ") + } + return strings.Join(stored, " ") +} + const ccmUserAgentFallback = "claude-code/2.1.72" var ( @@ -71,6 +94,22 @@ func detectClaudeCodeVersion() (string, error) { return filepath.Base(target), nil } +// resolveConfigDir returns the Claude config directory for lock coordination. +// +// ref (@anthropic-ai/claude-code @2.1.81): cli.js d1() (line 2983) — config dir used for locking +func resolveConfigDir(credentialPath string, credentialFilePath string) string { + if credentialPath == "" { + if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" { + return configDir + } + userInfo, err := getRealUser() + if err == nil { + return filepath.Join(userInfo.HomeDir, ".claude") + } + } + return filepath.Dir(credentialFilePath) +} + func getRealUser() (*user.User, error) { if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { sudoUserInfo, err := user.Lookup(sudoUser) @@ -118,10 +157,24 @@ func checkCredentialFileWritable(path string) error { return file.Close() } -func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error { - data, err := json.MarshalIndent(map[string]any{ - "claudeAiOauth": oauthCredentials, - }, "", " ") +// writeCredentialsToFile performs a read-modify-write: reads the existing JSON, +// replaces only the claudeAiOauth key, and writes back. This preserves any +// other top-level keys in the credential file. +// +// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179444-179454) — read-modify-write +// ref: cli.js qD1.update (line 176156) — writeFileSync + chmod 0o600 +func writeCredentialsToFile(credentials *oauthCredentials, path string) error { + existing := make(map[string]json.RawMessage) + data, readErr := os.ReadFile(path) + if readErr == nil { + _ = json.Unmarshal(data, &existing) + } + credentialData, err := json.Marshal(credentials) + if err != nil { + return err + } + existing["claudeAiOauth"] = credentialData + data, err = json.MarshalIndent(existing, "", " ") if err != nil { return err } @@ -131,16 +184,14 @@ func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) err // oauthCredentials mirrors the claudeAiOauth object in Claude Code's // credential file ($CLAUDE_CONFIG_DIR/.credentials.json). // -// ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6() / refreshOAuthToken -// -// Note: subscriptionType, rateLimitTier, and isMax were removed from this -// struct — they are profile state, not auth credentials. Claude Code also -// stores them here, but we persist them separately via state_path instead. +// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179446-179452) type oauthCredentials struct { - AccessToken string `json:"accessToken"` - RefreshToken string `json:"refreshToken"` - ExpiresAt int64 `json:"expiresAt"` - Scopes []string `json:"scopes,omitempty"` + AccessToken string `json:"accessToken"` // ref: cli.js line 179447 + RefreshToken string `json:"refreshToken"` // ref: cli.js line 179448 + ExpiresAt int64 `json:"expiresAt"` // ref: cli.js line 179449 (epoch ms) + Scopes []string `json:"scopes"` // ref: cli.js line 179450 + SubscriptionType *string `json:"subscriptionType"` // ref: cli.js line 179451 (?? null) + RateLimitTier *string `json:"rateLimitTier"` // ref: cli.js line 179452 (?? null) } func (c *oauthCredentials) needsRefresh() bool { @@ -155,10 +206,12 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau return nil, 0, E.New("refresh token is empty") } + // ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6 (line 172757-172761) requestBody, err := json.Marshal(map[string]string{ "grant_type": "refresh_token", "refresh_token": credentials.RefreshToken, "client_id": oauth2ClientID, + "scope": resolveRefreshScopes(credentials.Scopes), }) if err != nil { return nil, 0, E.Cause(err, "marshal request") @@ -194,10 +247,12 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau return nil, 0, E.New("refresh failed: ", response.Status, " ", string(body)) } + // ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6 response (line 172769-172772) var tokenResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` // ref: cli.js line 172770 z + RefreshToken string `json:"refresh_token"` // ref: cli.js line 172770 w (defaults to input) + ExpiresIn int `json:"expires_in"` // ref: cli.js line 172770 O + Scope string `json:"scope"` // ref: cli.js line 172772 uB6(Y.scope) } err = json.NewDecoder(response.Body).Decode(&tokenResponse) if err != nil { @@ -210,6 +265,11 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau newCredentials.RefreshToken = tokenResponse.RefreshToken } newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000 + // ref: cli.js uB6 (line 172696-172697): A?.split(" ").filter(Boolean) + // strings.Fields matches .filter(Boolean): splits on whitespace runs, removes empty strings + if tokenResponse.Scope != "" { + newCredentials.Scopes = strings.Fields(tokenResponse.Scope) + } return &newCredentials, 0, nil } diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 94af4e90e8..870a39514d 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -372,6 +372,43 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // ref (@anthropic-ai/claude-code @2.1.81): cli.js NA9 (line 179488-179494) — 401 recovery + // ref: cli.js CR1 (line 314268-314273) — 403 "OAuth token has been revoked" recovery + if !selectedCredential.isExternal() && bodyBytes != nil && + (response.StatusCode == http.StatusUnauthorized || response.StatusCode == http.StatusForbidden) { + shouldRetry := response.StatusCode == http.StatusUnauthorized + if response.StatusCode == http.StatusForbidden { + peekBody, _ := io.ReadAll(response.Body) + shouldRetry = strings.Contains(string(peekBody), "OAuth token has been revoked") + if !shouldRetry { + response.Body.Close() + s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(peekBody)) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", + "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(peekBody)) + return + } + } + if shouldRetry { + response.Body.Close() + s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying") + if defaultCred, ok := selectedCredential.(*defaultCredential); ok { + _ = defaultCred.reloadCredentials(true) + } + retryRequest, buildErr := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if buildErr != nil { + writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(buildErr, "rebuild request after auth recovery").Error()) + return + } + retryResponse, retryErr := selectedCredential.httpClient().Do(retryRequest) + if retryErr != nil { + writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(retryErr, "retry request after auth recovery").Error()) + return + } + response = retryResponse + defer retryResponse.Body.Close() + } + } + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) From d996b60f44f5e7fe25518b8778e1743293785ea5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 22 Mar 2026 06:23:13 +0800 Subject: [PATCH 78/96] ccm/ocm: Add CLAUDE.md --- .gitignore | 4 ++-- service/ccm/CLAUDE.md | 13 +++++++++++++ service/ocm/CLAUDE.md | 7 +++++++ 3 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 service/ccm/CLAUDE.md create mode 100644 service/ocm/CLAUDE.md diff --git a/.gitignore b/.gitignore index d2b74d08cd..c3c9a17bc3 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,6 @@ .DS_Store /config.d/ /venv/ -CLAUDE.md -AGENTS.md +/CLAUDE.md +/AGENTS.md /.claude/ diff --git a/service/ccm/CLAUDE.md b/service/ccm/CLAUDE.md new file mode 100644 index 0000000000..d339aece16 --- /dev/null +++ b/service/ccm/CLAUDE.md @@ -0,0 +1,13 @@ +# Claude Code Multiplexer + +### Reverse Claude Code + +Claude distributes a huge binary by default in a Bun, which is difficult to reverse engineer (and is very likely the one the user have installed now). + +You must obtain the npm version of the Claude Code js source code: + +Example: + +```bash +cd /tmp && npm pack @anthropic-ai/claude-code && tar xzf anthropic-ai-claude-code-*.tgz && npx prettier --write package/cli.js +``` diff --git a/service/ocm/CLAUDE.md b/service/ocm/CLAUDE.md new file mode 100644 index 0000000000..986cb2673f --- /dev/null +++ b/service/ocm/CLAUDE.md @@ -0,0 +1,7 @@ +# OpenAI Codex Multiplexer + +### Reverse Codex + +Oh, Codex is just open source. + +Clone it and study its code: https://github.com/openai/codex From bb2169bc17d50a7eac030dcbf8962453c9747432 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 22 Mar 2026 06:35:34 +0800 Subject: [PATCH 79/96] release: Fix install_go.sh --- release/local/install_go.sh | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/release/local/install_go.sh b/release/local/install_go.sh index ea64fec45a..082d7ce81d 100755 --- a/release/local/install_go.sh +++ b/release/local/install_go.sh @@ -2,8 +2,16 @@ set -e -o pipefail -go_version=$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g') -curl -Lo go.tar.gz "https://go.dev/dl/go$go_version.linux-amd64.tar.gz" +manifest=$(curl -fS 'https://go.dev/VERSION?m=text') +go_version=$(echo "$manifest" | head -1 | sed 's/^go//') +os=$(uname -s | tr '[:upper:]' '[:lower:]') +arch=$(uname -m) +case "$arch" in + x86_64) arch="amd64" ;; + aarch64|arm64) arch="arm64" ;; +esac +curl -Lo go.tar.gz "https://go.dev/dl/go$go_version.$os-$arch.tar.gz" sudo rm -rf /usr/local/go sudo tar -C /usr/local -xzf go.tar.gz rm go.tar.gz +echo "Installed Go $go_version" From 441c98890d19ade96cf2dd216a1291e60bffcac2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 22 Mar 2026 16:59:36 +0800 Subject: [PATCH 80/96] feat(ccm): add claude_directory option to read Claude Code config --- option/ccm.go | 16 ++--- service/ccm/credential_config_file.go | 59 +++++++++++++++++ service/ccm/credential_default.go | 95 ++++++++++++++++++++------- service/ccm/credential_state_file.go | 64 ------------------ 4 files changed, 139 insertions(+), 95 deletions(-) create mode 100644 service/ccm/credential_config_file.go delete mode 100644 service/ccm/credential_state_file.go diff --git a/option/ccm.go b/option/ccm.go index 96200248a9..481068e617 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -76,14 +76,14 @@ func (c *CCMCredential) UnmarshalJSON(bytes []byte) error { } type CCMDefaultCredentialOptions struct { - CredentialPath string `json:"credential_path,omitempty"` - StatePath string `json:"state_path,omitempty"` - UsagesPath string `json:"usages_path,omitempty"` - Detour string `json:"detour,omitempty"` - Reserve5h uint8 `json:"reserve_5h"` - ReserveWeekly uint8 `json:"reserve_weekly"` - Limit5h uint8 `json:"limit_5h,omitempty"` - LimitWeekly uint8 `json:"limit_weekly,omitempty"` + CredentialPath string `json:"credential_path,omitempty"` + ClaudeDirectory string `json:"claude_directory,omitempty"` + UsagesPath string `json:"usages_path,omitempty"` + Detour string `json:"detour,omitempty"` + Reserve5h uint8 `json:"reserve_5h"` + ReserveWeekly uint8 `json:"reserve_weekly"` + Limit5h uint8 `json:"limit_5h,omitempty"` + LimitWeekly uint8 `json:"limit_weekly,omitempty"` } type CCMBalancerCredentialOptions struct { diff --git a/service/ccm/credential_config_file.go b/service/ccm/credential_config_file.go new file mode 100644 index 0000000000..d1b66d2e33 --- /dev/null +++ b/service/ccm/credential_config_file.go @@ -0,0 +1,59 @@ +package ccm + +import ( + "encoding/json" + "os" + "path/filepath" +) + +// claudeCodeConfig represents the persisted config written by Claude Code. +// +// ref (@anthropic-ai/claude-code @2.1.81): +// +// ref: cli.js P8() (line 174997) — reads config +// ref: cli.js c8() (line 174919) — writes config +// ref: cli.js _D() (line 39158-39163) — config file path resolution +type claudeCodeConfig struct { + UserID string `json:"userID"` // ref: cli.js XL() (line 175325) — random 32-byte hex, generated once + OAuthAccount *claudeOAuthAccount `json:"oauthAccount"` // ref: cli.js fP6() / storeOAuthAccountInfo — from /api/oauth/profile +} + +type claudeOAuthAccount struct { + AccountUUID string `json:"accountUuid"` +} + +// resolveClaudeConfigFile finds the Claude Code config file within the given directory. +// +// Config file path resolution mirrors cli.js _D() (line 39158-39163): +// 1. claudeDirectory/.config.json — newer format, checked first +// 2. claudeDirectory/.claude.json — used when CLAUDE_CONFIG_DIR is set +// 3. filepath.Dir(claudeDirectory)/.claude.json — default ~/.claude case → ~/.claude.json +// +// Returns the first path that exists, or "" if none found. +func resolveClaudeConfigFile(claudeDirectory string) string { + candidates := []string{ + filepath.Join(claudeDirectory, ".config.json"), + filepath.Join(claudeDirectory, ".claude.json"), + filepath.Join(filepath.Dir(claudeDirectory), ".claude.json"), + } + for _, candidate := range candidates { + _, err := os.Stat(candidate) + if err == nil { + return candidate + } + } + return "" +} + +func readClaudeCodeConfig(path string) (*claudeCodeConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var config claudeCodeConfig + err = json.Unmarshal(data, &config) + if err != nil { + return nil, err + } + return &config, nil +} diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 4c571a38ed..2172f06c55 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -9,6 +9,7 @@ import ( "math" "net" "net/http" + "path/filepath" "slices" "strconv" "sync" @@ -29,9 +30,11 @@ type defaultCredential struct { tag string serviceContext context.Context credentialPath string + claudeDirectory string credentialFilePath string configDir string - statePath string + deviceID string + configLoaded bool credentials *oauthCredentials access sync.RWMutex state credentialState @@ -109,7 +112,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef tag: tag, serviceContext: ctx, credentialPath: options.CredentialPath, - statePath: options.StatePath, + claudeDirectory: options.ClaudeDirectory, cap5h: cap5h, capWeekly: capWeekly, forwardHTTPClient: httpClient, @@ -129,13 +132,18 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef } func (c *defaultCredential) start() error { + if c.claudeDirectory != "" { + c.loadClaudeCodeConfig() + if c.credentialPath == "" { + c.credentialPath = filepath.Join(c.claudeDirectory, ".credentials.json") + } + } credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) if err != nil { return E.Cause(err, "resolve credential path for ", c.tag) } c.credentialFilePath = credentialFilePath c.configDir = resolveConfigDir(c.credentialPath, credentialFilePath) - c.loadPersistedState() err = c.ensureCredentialWatcher() if err != nil { c.logger.Debug("start credential watcher for ", c.tag, ": ", err) @@ -154,6 +162,28 @@ func (c *defaultCredential) start() error { return nil } +func (c *defaultCredential) loadClaudeCodeConfig() { + configFilePath := resolveClaudeConfigFile(c.claudeDirectory) + if configFilePath == "" { + return + } + config, err := readClaudeCodeConfig(configFilePath) + if err != nil { + c.logger.Warn("read claude code config for ", c.tag, ": ", err) + return + } + c.stateAccess.Lock() + if config.OAuthAccount != nil && config.OAuthAccount.AccountUUID != "" { + c.state.accountUUID = config.OAuthAccount.AccountUUID + } + c.stateAccess.Unlock() + if config.UserID != "" { + c.deviceID = config.UserID + } + c.configLoaded = true + c.logger.Debug("loaded claude code config for ", c.tag, ": account=", c.state.accountUUID, ", device=", c.deviceID) +} + func (c *defaultCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) { c.statusSubscriber = subscriber } @@ -697,7 +727,7 @@ func (c *defaultCredential) pollUsage() { } c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } - needsProfileFetch := c.state.rateLimitTier == "" + needsProfileFetch := !c.configLoaded && c.state.rateLimitTier == "" shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -782,7 +812,6 @@ func (c *defaultCredential) fetchProfile(httpClient *http.Client, accessToken st if shouldEmit { c.emitStatusUpdate() } - c.savePersistedState() c.logger.Info("fetched profile for ", c.tag, ": type=", resolvedAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(resolvedAccountType, rateLimitTier)) } @@ -841,7 +870,7 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt proxyURL := claudeAPIBaseURL + original.URL.RequestURI() var body io.Reader if bodyBytes != nil { - bodyBytes = c.injectAccountUUID(bodyBytes) + bodyBytes = c.injectMetadataFields(bodyBytes) body = bytes.NewReader(bodyBytes) } else { body = original.Body @@ -878,24 +907,20 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt return proxyRequest, nil } -// injectAccountUUID fills in the account_uuid field in metadata.user_id -// when the client sends it empty (e.g. using ANTHROPIC_AUTH_TOKEN). +// injectMetadataFields fills in account_uuid and device_id in metadata.user_id +// when the client sends them empty (e.g. using ANTHROPIC_AUTH_TOKEN). // // Claude Code >= 2.1.78 (@anthropic-ai/claude-code) sets metadata as: // // {user_id: JSON.stringify({device_id, account_uuid, session_id})} // // ref: cli.js L66() — metadata constructor -// -// account_uuid is populated from oauthAccount.accountUuid which comes from -// the /api/oauth/profile endpoint (ref: cli.js EX1() → fP6()). -// When the client uses ANTHROPIC_AUTH_TOKEN instead of Claude AI OAuth, -// account_uuid is empty. We inject it from the fetchProfile result. -func (c *defaultCredential) injectAccountUUID(bodyBytes []byte) []byte { +func (c *defaultCredential) injectMetadataFields(bodyBytes []byte) []byte { c.stateAccess.RLock() accountUUID := c.state.accountUUID c.stateAccess.RUnlock() - if accountUUID == "" { + deviceID := c.deviceID + if accountUUID == "" && deviceID == "" { return bodyBytes } @@ -931,19 +956,43 @@ func (c *defaultCredential) injectAccountUUID(bodyBytes []byte) []byte { return bodyBytes } - existingRaw, hasExisting := userIDObject["account_uuid"] - if hasExisting { - var existing string - if json.Unmarshal(existingRaw, &existing) == nil && existing != "" { - return bodyBytes + modified := false + + if accountUUID != "" { + existingRaw, hasExisting := userIDObject["account_uuid"] + needsInject := !hasExisting + if hasExisting { + var existing string + needsInject = json.Unmarshal(existingRaw, &existing) != nil || existing == "" + } + if needsInject { + accountUUIDJSON, marshalErr := json.Marshal(accountUUID) + if marshalErr == nil { + userIDObject["account_uuid"] = json.RawMessage(accountUUIDJSON) + modified = true + } } } - accountUUIDJSON, err := json.Marshal(accountUUID) - if err != nil { + if deviceID != "" { + existingRaw, hasExisting := userIDObject["device_id"] + needsInject := !hasExisting + if hasExisting { + var existing string + needsInject = json.Unmarshal(existingRaw, &existing) != nil || existing == "" + } + if needsInject { + deviceIDJSON, marshalErr := json.Marshal(deviceID) + if marshalErr == nil { + userIDObject["device_id"] = json.RawMessage(deviceIDJSON) + modified = true + } + } + } + + if !modified { return bodyBytes } - userIDObject["account_uuid"] = json.RawMessage(accountUUIDJSON) newUserIDBytes, err := json.Marshal(userIDObject) if err != nil { diff --git a/service/ccm/credential_state_file.go b/service/ccm/credential_state_file.go deleted file mode 100644 index 630e261887..0000000000 --- a/service/ccm/credential_state_file.go +++ /dev/null @@ -1,64 +0,0 @@ -package ccm - -import ( - "encoding/json" - "os" -) - -// persistedState holds profile data fetched from /api/oauth/profile, -// persisted to state_path so it survives restarts without re-fetching. -// -// Claude Code (@anthropic-ai/claude-code @2.1.81) stores equivalent data in -// its config file (~/.claude/.config.json) under the oauthAccount key: -// -// ref: cli.js fP6() / storeOAuthAccountInfo — writes accountUuid, billingType, etc. -// ref: cli.js P8() — reads config from $CLAUDE_CONFIG_DIR/.config.json -type persistedState struct { - AccountUUID string `json:"account_uuid,omitempty"` - AccountType string `json:"account_type,omitempty"` - RateLimitTier string `json:"rate_limit_tier,omitempty"` -} - -func (c *defaultCredential) loadPersistedState() { - if c.statePath == "" { - return - } - data, err := os.ReadFile(c.statePath) - if err != nil { - return - } - var state persistedState - err = json.Unmarshal(data, &state) - if err != nil { - return - } - c.stateAccess.Lock() - if state.AccountUUID != "" { - c.state.accountUUID = state.AccountUUID - } - if state.AccountType != "" { - c.state.accountType = state.AccountType - } - if state.RateLimitTier != "" { - c.state.rateLimitTier = state.RateLimitTier - } - c.stateAccess.Unlock() -} - -func (c *defaultCredential) savePersistedState() { - if c.statePath == "" { - return - } - c.stateAccess.RLock() - state := persistedState{ - AccountUUID: c.state.accountUUID, - AccountType: c.state.accountType, - RateLimitTier: c.state.rateLimitTier, - } - c.stateAccess.RUnlock() - data, err := json.MarshalIndent(state, "", " ") - if err != nil { - return - } - os.WriteFile(c.statePath, data, 0o600) -} From 92c8f4c5c8d754803f2e6e2fa1bf12bacdd7312f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 21:15:46 +0800 Subject: [PATCH 81/96] fix(ccm): align default credential with Claude Code --- service/ccm/credential.go | 1 + service/ccm/credential_config_file.go | 94 ++++- service/ccm/credential_darwin.go | 125 +++--- service/ccm/credential_default.go | 520 ++++++++++++++++--------- service/ccm/credential_default_test.go | 205 ++++++++++ service/ccm/credential_file.go | 19 +- service/ccm/credential_oauth.go | 78 ++-- service/ccm/credential_oauth_test.go | 141 +++++++ service/ccm/credential_storage.go | 124 ++++++ service/ccm/credential_storage_test.go | 125 ++++++ service/ccm/service_handler.go | 41 +- service/ccm/service_handler_test.go | 221 +++++++++++ service/ccm/test_helpers_test.go | 138 +++++++ 13 files changed, 1541 insertions(+), 291 deletions(-) create mode 100644 service/ccm/credential_default_test.go create mode 100644 service/ccm/credential_oauth_test.go create mode 100644 service/ccm/credential_storage.go create mode 100644 service/ccm/credential_storage_test.go create mode 100644 service/ccm/service_handler_test.go create mode 100644 service/ccm/test_helpers_test.go diff --git a/service/ccm/credential.go b/service/ccm/credential.go index f2d4003f12..20170fca23 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -62,6 +62,7 @@ type credentialState struct { accountUUID string accountType string rateLimitTier string + oauthAccount *claudeOAuthAccount remotePlanWeight float64 lastUpdated time.Time consecutivePollFailures int diff --git a/service/ccm/credential_config_file.go b/service/ccm/credential_config_file.go index d1b66d2e33..82c35cb10b 100644 --- a/service/ccm/credential_config_file.go +++ b/service/ccm/credential_config_file.go @@ -19,7 +19,14 @@ type claudeCodeConfig struct { } type claudeOAuthAccount struct { - AccountUUID string `json:"accountUuid"` + AccountUUID string `json:"accountUuid,omitempty"` + EmailAddress string `json:"emailAddress,omitempty"` + OrganizationUUID string `json:"organizationUuid,omitempty"` + DisplayName *string `json:"displayName,omitempty"` + HasExtraUsageEnabled *bool `json:"hasExtraUsageEnabled,omitempty"` + BillingType *string `json:"billingType,omitempty"` + AccountCreatedAt *string `json:"accountCreatedAt,omitempty"` + SubscriptionCreatedAt *string `json:"subscriptionCreatedAt,omitempty"` } // resolveClaudeConfigFile finds the Claude Code config file within the given directory. @@ -33,8 +40,8 @@ type claudeOAuthAccount struct { func resolveClaudeConfigFile(claudeDirectory string) string { candidates := []string{ filepath.Join(claudeDirectory, ".config.json"), - filepath.Join(claudeDirectory, ".claude.json"), - filepath.Join(filepath.Dir(claudeDirectory), ".claude.json"), + filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName()), + filepath.Join(filepath.Dir(claudeDirectory), claudeCodeLegacyConfigFileName()), } for _, candidate := range candidates { _, err := os.Stat(candidate) @@ -57,3 +64,84 @@ func readClaudeCodeConfig(path string) (*claudeCodeConfig, error) { } return &config, nil } + +func resolveClaudeConfigWritePath(claudeDirectory string) string { + if claudeDirectory == "" { + return "" + } + existingPath := resolveClaudeConfigFile(claudeDirectory) + if existingPath != "" { + return existingPath + } + if os.Getenv("CLAUDE_CONFIG_DIR") != "" { + return filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName()) + } + defaultClaudeDirectory := filepath.Join(filepath.Dir(claudeDirectory), ".claude") + if claudeDirectory != defaultClaudeDirectory { + return filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName()) + } + return filepath.Join(filepath.Dir(claudeDirectory), claudeCodeLegacyConfigFileName()) +} + +func writeClaudeCodeOAuthAccount(path string, account *claudeOAuthAccount) error { + if path == "" || account == nil { + return nil + } + storage := jsonFileStorage{path: path} + return writeStorageValue(storage, "oauthAccount", account) +} + +func claudeCodeLegacyConfigFileName() string { + if os.Getenv("CLAUDE_CODE_CUSTOM_OAUTH_URL") != "" { + return ".claude-custom-oauth.json" + } + return ".claude.json" +} + +func cloneClaudeOAuthAccount(account *claudeOAuthAccount) *claudeOAuthAccount { + if account == nil { + return nil + } + cloned := *account + cloned.DisplayName = cloneStringPointer(account.DisplayName) + cloned.HasExtraUsageEnabled = cloneBoolPointer(account.HasExtraUsageEnabled) + cloned.BillingType = cloneStringPointer(account.BillingType) + cloned.AccountCreatedAt = cloneStringPointer(account.AccountCreatedAt) + cloned.SubscriptionCreatedAt = cloneStringPointer(account.SubscriptionCreatedAt) + return &cloned +} + +func mergeClaudeOAuthAccount(base *claudeOAuthAccount, update *claudeOAuthAccount) *claudeOAuthAccount { + if update == nil { + return cloneClaudeOAuthAccount(base) + } + if base == nil { + return cloneClaudeOAuthAccount(update) + } + merged := cloneClaudeOAuthAccount(base) + if update.AccountUUID != "" { + merged.AccountUUID = update.AccountUUID + } + if update.EmailAddress != "" { + merged.EmailAddress = update.EmailAddress + } + if update.OrganizationUUID != "" { + merged.OrganizationUUID = update.OrganizationUUID + } + if update.DisplayName != nil { + merged.DisplayName = cloneStringPointer(update.DisplayName) + } + if update.HasExtraUsageEnabled != nil { + merged.HasExtraUsageEnabled = cloneBoolPointer(update.HasExtraUsageEnabled) + } + if update.BillingType != nil { + merged.BillingType = cloneStringPointer(update.BillingType) + } + if update.AccountCreatedAt != nil { + merged.AccountCreatedAt = cloneStringPointer(update.AccountCreatedAt) + } + if update.SubscriptionCreatedAt != nil { + merged.SubscriptionCreatedAt = cloneStringPointer(update.SubscriptionCreatedAt) + } + return merged +} diff --git a/service/ccm/credential_darwin.go b/service/ccm/credential_darwin.go index d025fd47c1..0cd28b79bd 100644 --- a/service/ccm/credential_darwin.go +++ b/service/ccm/credential_darwin.go @@ -14,6 +14,11 @@ import ( "github.com/keybase/go-keychain" ) +type keychainStorage struct { + service string + account string +} + func getKeychainServiceName() string { configDirectory := os.Getenv("CLAUDE_CONFIG_DIR") if configDirectory == "" { @@ -76,72 +81,90 @@ func platformCanWriteCredentials(customPath string) error { return checkCredentialFileWritable(customPath) } -// platformWriteCredentials performs a read-modify-write on the keychain entry, -// preserving any fields or top-level keys not managed by CCM. -// -// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179444-179454) — read-modify-write func platformWriteCredentials(credentials *oauthCredentials, customPath string) error { if customPath != "" { return writeCredentialsToFile(credentials, customPath) } + defaultPath, err := getDefaultCredentialsPath() + if err != nil { + return err + } + fileStorage := jsonFileStorage{path: defaultPath} + userInfo, err := getRealUser() - if err == nil { - serviceName := getKeychainServiceName() + if err != nil { + return writeCredentialsToFile(credentials, defaultPath) + } + return persistStorageValue(keychainStorage{ + service: getKeychainServiceName(), + account: userInfo.Username, + }, fileStorage, "claudeAiOauth", credentials) +} - existing := make(map[string]json.RawMessage) - query := keychain.NewItem() - query.SetSecClass(keychain.SecClassGenericPassword) - query.SetService(serviceName) - query.SetAccount(userInfo.Username) - query.SetMatchLimit(keychain.MatchLimitOne) - query.SetReturnData(true) - results, queryErr := keychain.QueryItem(query) - if queryErr == nil && len(results) == 1 { - _ = json.Unmarshal(results[0].Data, &existing) - } +func (s keychainStorage) readContainer() (map[string]json.RawMessage, bool, error) { + query := keychain.NewItem() + query.SetSecClass(keychain.SecClassGenericPassword) + query.SetService(s.service) + query.SetAccount(s.account) + query.SetMatchLimit(keychain.MatchLimitOne) + query.SetReturnData(true) - credentialData, err := json.Marshal(credentials) - if err != nil { - return E.Cause(err, "marshal credentials") - } - existing["claudeAiOauth"] = credentialData - data, err := json.Marshal(existing) - if err != nil { - return E.Cause(err, "marshal credential container") + results, err := keychain.QueryItem(query) + if err != nil { + if err == keychain.ErrorItemNotFound { + return make(map[string]json.RawMessage), false, nil } + return nil, false, E.Cause(err, "query keychain") + } + if len(results) != 1 { + return make(map[string]json.RawMessage), false, nil + } - item := keychain.NewItem() - item.SetSecClass(keychain.SecClassGenericPassword) - item.SetService(serviceName) - item.SetAccount(userInfo.Username) - item.SetData(data) - item.SetAccessible(keychain.AccessibleWhenUnlocked) + container := make(map[string]json.RawMessage) + if len(results[0].Data) == 0 { + return container, true, nil + } + if err := json.Unmarshal(results[0].Data, &container); err != nil { + return nil, true, err + } + return container, true, nil +} - err = keychain.AddItem(item) - if err == nil { - return nil - } +func (s keychainStorage) writeContainer(container map[string]json.RawMessage) error { + data, err := json.Marshal(container) + if err != nil { + return err + } - if err == keychain.ErrorDuplicateItem { - updateQuery := keychain.NewItem() - updateQuery.SetSecClass(keychain.SecClassGenericPassword) - updateQuery.SetService(serviceName) - updateQuery.SetAccount(userInfo.Username) + item := keychain.NewItem() + item.SetSecClass(keychain.SecClassGenericPassword) + item.SetService(s.service) + item.SetAccount(s.account) + item.SetData(data) + item.SetAccessible(keychain.AccessibleWhenUnlocked) + err = keychain.AddItem(item) + if err == nil { + return nil + } + if err != keychain.ErrorDuplicateItem { + return err + } - updateItem := keychain.NewItem() - updateItem.SetData(data) + updateQuery := keychain.NewItem() + updateQuery.SetSecClass(keychain.SecClassGenericPassword) + updateQuery.SetService(s.service) + updateQuery.SetAccount(s.account) - updateErr := keychain.UpdateItem(updateQuery, updateItem) - if updateErr == nil { - return nil - } - } - } + updateItem := keychain.NewItem() + updateItem.SetData(data) + return keychain.UpdateItem(updateQuery, updateItem) +} - defaultPath, err := getDefaultCredentialsPath() - if err != nil { +func (s keychainStorage) delete() error { + err := keychain.DeleteGenericPasswordItem(s.service, s.account) + if err != nil && err != keychain.ErrorItemNotFound { return err } - return writeCredentialsToFile(credentials, defaultPath) + return nil } diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 2172f06c55..ea74d51e73 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -26,6 +26,15 @@ import ( "github.com/sagernet/sing/common/observable" ) +var acquireCredentialLockFunc = acquireCredentialLock + +type claudeProfileSnapshot struct { + OAuthAccount *claudeOAuthAccount + AccountType string + RateLimitTier string + SubscriptionType *string +} + type defaultCredential struct { tag string serviceContext context.Context @@ -33,8 +42,9 @@ type defaultCredential struct { claudeDirectory string credentialFilePath string configDir string + claudeConfigPath string + syncClaudeConfig bool deviceID string - configLoaded bool credentials *oauthCredentials access sync.RWMutex state credentialState @@ -52,11 +62,6 @@ type defaultCredential struct { statusSubscriber *observable.Subscriber[struct{}] - // Refresh rate-limit cooldown (protected by access mutex) - refreshRetryAt time.Time - refreshRetryError error - refreshBlocked bool - // Connection interruption interrupted bool requestContext context.Context @@ -113,6 +118,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef serviceContext: ctx, credentialPath: options.CredentialPath, claudeDirectory: options.ClaudeDirectory, + syncClaudeConfig: options.ClaudeDirectory != "" || options.CredentialPath == "", cap5h: cap5h, capWeekly: capWeekly, forwardHTTPClient: httpClient, @@ -133,7 +139,6 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef func (c *defaultCredential) start() error { if c.claudeDirectory != "" { - c.loadClaudeCodeConfig() if c.credentialPath == "" { c.credentialPath = filepath.Join(c.claudeDirectory, ".credentials.json") } @@ -144,6 +149,13 @@ func (c *defaultCredential) start() error { } c.credentialFilePath = credentialFilePath c.configDir = resolveConfigDir(c.credentialPath, credentialFilePath) + if c.syncClaudeConfig { + if c.claudeDirectory == "" { + c.claudeDirectory = c.configDir + } + c.claudeConfigPath = resolveClaudeConfigWritePath(c.claudeDirectory) + c.loadClaudeCodeConfig() + } err = c.ensureCredentialWatcher() if err != nil { c.logger.Debug("start credential watcher for ", c.tag, ": ", err) @@ -173,6 +185,7 @@ func (c *defaultCredential) loadClaudeCodeConfig() { return } c.stateAccess.Lock() + c.state.oauthAccount = cloneClaudeOAuthAccount(config.OAuthAccount) if config.OAuthAccount != nil && config.OAuthAccount.AccountUUID != "" { c.state.accountUUID = config.OAuthAccount.AccountUUID } @@ -180,7 +193,7 @@ func (c *defaultCredential) loadClaudeCodeConfig() { if config.UserID != "" { c.deviceID = config.UserID } - c.configLoaded = true + c.claudeConfigPath = configFilePath c.logger.Debug("loaded claude code config for ", c.tag, ": account=", c.state.accountUUID, ", device=", c.deviceID) } @@ -209,147 +222,358 @@ func (c *defaultCredential) statusSnapshotLocked() statusSnapshot { func (c *defaultCredential) getAccessToken() (string, error) { c.retryCredentialReloadIfNeeded() - // Fast path: cached token is still valid c.access.RLock() - if c.credentials != nil && !c.credentials.needsRefresh() { - token := c.credentials.AccessToken - c.access.RUnlock() - return token, nil - } + currentCredentials := cloneCredentials(c.credentials) c.access.RUnlock() - - // Reload from disk — Claude Code or another process may have refreshed - err := c.reloadCredentials(true) - if err == nil { - c.access.RLock() - if c.credentials != nil && !c.credentials.needsRefresh() { - token := c.credentials.AccessToken - c.access.RUnlock() - return token, nil + if currentCredentials == nil { + err := c.reloadCredentials(true) + if err != nil { + return "", err } + c.access.RLock() + currentCredentials = cloneCredentials(c.credentials) c.access.RUnlock() } - - // ref (@anthropic-ai/claude-code @2.1.81): cli.js _P1 line 179526 - // Claude Code skips refresh for tokens without user:inference scope. - // Return existing token (may be expired); 401 recovery is the safety net. + if currentCredentials == nil { + return "", c.unavailableError() + } + if !currentCredentials.needsRefresh() || !slices.Contains(currentCredentials.Scopes, "user:inference") { + return currentCredentials.AccessToken, nil + } + c.tryRefreshCredentials(false) c.access.RLock() - if c.credentials != nil && !slices.Contains(c.credentials.Scopes, "user:inference") { - token := c.credentials.AccessToken - c.access.RUnlock() - return token, nil + defer c.access.RUnlock() + if c.credentials != nil && c.credentials.AccessToken != "" { + return c.credentials.AccessToken, nil } - c.access.RUnlock() + return "", c.unavailableError() +} + +func (c *defaultCredential) shouldUseClaudeConfig() bool { + return c.syncClaudeConfig && c.claudeConfigPath != "" +} - // Acquire cross-process lock before refresh (outside Go mutex to avoid holding mutex during sleep) - // ref: cli.js _P1 (line 179534-179536) — proper-lockfile lock on config dir - release, lockErr := acquireCredentialLock(c.configDir) - if lockErr != nil { - c.logger.Debug("acquire credential lock for ", c.tag, ": ", lockErr) - release = func() {} +func (c *defaultCredential) absorbCredentials(credentials *oauthCredentials) { + c.access.Lock() + c.credentials = cloneCredentials(credentials) + c.access.Unlock() + + c.stateAccess.Lock() + before := c.statusSnapshotLocked() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.applyCredentialMetadataLocked(credentials) + c.checkTransitionLocked() + shouldEmit := before != c.statusSnapshotLocked() + c.stateAccess.Unlock() + if shouldEmit { + c.emitStatusUpdate() } - defer release() +} - // ref: cli.js _P1 (line 179559-179562) — re-read after lock, skip if race resolved - _ = c.reloadCredentials(true) - c.access.RLock() - noRefreshToken := c.credentials == nil || c.credentials.RefreshToken == "" - raceResolved := !noRefreshToken && !c.credentials.needsRefresh() - var racedToken string - if (noRefreshToken || raceResolved) && c.credentials != nil { - racedToken = c.credentials.AccessToken +func (c *defaultCredential) applyCredentialMetadataLocked(credentials *oauthCredentials) { + if credentials == nil { + return } - c.access.RUnlock() - if noRefreshToken || raceResolved { - return racedToken, nil + if credentials.SubscriptionType != nil && *credentials.SubscriptionType != "" { + c.state.accountType = *credentials.SubscriptionType } + if credentials.RateLimitTier != nil && *credentials.RateLimitTier != "" { + c.state.rateLimitTier = *credentials.RateLimitTier + } +} - // Slow path: acquire Go mutex and refresh - c.access.Lock() - defer c.access.Unlock() +func (c *defaultCredential) absorbOAuthAccount(account *claudeOAuthAccount) { + c.stateAccess.Lock() + c.state.oauthAccount = mergeClaudeOAuthAccount(c.state.oauthAccount, account) + if c.state.oauthAccount != nil && c.state.oauthAccount.AccountUUID != "" { + c.state.accountUUID = c.state.oauthAccount.AccountUUID + } + c.stateAccess.Unlock() +} - if c.credentials == nil { - return "", c.unavailableError() +func (c *defaultCredential) persistOAuthAccount() { + if !c.shouldUseClaudeConfig() { + return } - if !c.credentials.needsRefresh() { - return c.credentials.AccessToken, nil + c.stateAccess.RLock() + account := cloneClaudeOAuthAccount(c.state.oauthAccount) + c.stateAccess.RUnlock() + if account == nil { + return } + if err := writeClaudeCodeOAuthAccount(c.claudeConfigPath, account); err != nil { + c.logger.Debug("write claude code config for ", c.tag, ": ", err) + } +} - if c.refreshBlocked { - return "", c.refreshRetryError +func (c *defaultCredential) needsProfileHydration() bool { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.needsProfileHydrationLocked() +} + +func (c *defaultCredential) needsProfileHydrationLocked() bool { + if c.state.accountUUID == "" || c.state.accountType == "" || c.state.rateLimitTier == "" { + return true } - if !c.refreshRetryAt.IsZero() && time.Now().Before(c.refreshRetryAt) { - return "", c.refreshRetryError + if c.state.oauthAccount == nil { + return true } + return c.state.oauthAccount.BillingType == nil || + c.state.oauthAccount.AccountCreatedAt == nil || + c.state.oauthAccount.SubscriptionCreatedAt == nil +} - err = platformCanWriteCredentials(c.credentialPath) +func (c *defaultCredential) currentCredentials() *oauthCredentials { + c.access.RLock() + defer c.access.RUnlock() + return cloneCredentials(c.credentials) +} + +func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) { + if credentials == nil { + return + } + if err := platformWriteCredentials(credentials, c.credentialPath); err != nil { + c.logger.Error("persist refreshed token for ", c.tag, ": ", err) + } +} + +func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials, force bool) bool { + if credentials == nil || credentials.RefreshToken == "" { + return false + } + if !slices.Contains(credentials.Scopes, "user:inference") { + return false + } + if force { + return true + } + return credentials.needsRefresh() +} + +func (c *defaultCredential) tryRefreshCredentials(force bool) bool { + latestCredentials, err := platformReadCredentials(c.credentialPath) + if err == nil && latestCredentials != nil { + c.absorbCredentials(latestCredentials) + } + currentCredentials := c.currentCredentials() + if !c.shouldAttemptRefresh(currentCredentials, force) { + return false + } + release, err := acquireCredentialLockFunc(c.configDir) if err != nil { - return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") + c.logger.Debug("acquire credential lock for ", c.tag, ": ", err) + return false } + defer release() - baseCredentials := cloneCredentials(c.credentials) - newCredentials, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials) + latestCredentials, err = platformReadCredentials(c.credentialPath) + if err == nil && latestCredentials != nil { + c.absorbCredentials(latestCredentials) + currentCredentials = latestCredentials + } else { + currentCredentials = c.currentCredentials() + } + if !c.shouldAttemptRefresh(currentCredentials, force) { + return false + } + if err := platformCanWriteCredentials(c.credentialPath); err != nil { + c.logger.Debug("credential file not writable for ", c.tag, ": ", err) + return false + } + + baseCredentials := cloneCredentials(currentCredentials) + refreshResult, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, currentCredentials) if err != nil { - if retryDelay < 0 { - c.refreshBlocked = true - c.refreshRetryError = err - } else if retryDelay > 0 { - c.refreshRetryAt = time.Now().Add(retryDelay) - c.refreshRetryError = err - } - // ref: cli.js _P1 (line 179568-179573) — post-failure recovery: - // re-read from disk; if another process refreshed successfully, use that. - // Cannot call reloadCredentials here (deadlock: already holding c.access). + if retryDelay != 0 { + c.logger.Debug("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err) + } else { + c.logger.Debug("refresh token for ", c.tag, ": ", err) + } latestCredentials, readErr := platformReadCredentials(c.credentialPath) - if readErr == nil && latestCredentials != nil && !latestCredentials.needsRefresh() { - c.credentials = latestCredentials - return latestCredentials.AccessToken, nil + if readErr == nil && latestCredentials != nil { + c.absorbCredentials(latestCredentials) + return latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh()) + } + return false + } + if refreshResult == nil || refreshResult.Credentials == nil { + return false + } + + refreshedCredentials := cloneCredentials(refreshResult.Credentials) + c.absorbCredentials(refreshedCredentials) + c.persistCredentials(refreshedCredentials) + + if refreshResult.TokenAccount != nil { + c.absorbOAuthAccount(refreshResult.TokenAccount) + c.persistOAuthAccount() + } + if c.needsProfileHydration() { + profileSnapshot, profileErr := c.fetchProfileSnapshot(c.forwardHTTPClient, refreshedCredentials.AccessToken) + if profileErr != nil { + c.logger.Debug("fetch profile for ", c.tag, ": ", profileErr) + } else if profileSnapshot != nil { + credentialsChanged := c.applyProfileSnapshot(profileSnapshot) + c.persistOAuthAccount() + if credentialsChanged { + c.persistCredentials(c.currentCredentials()) + } } - return "", err } - c.refreshRetryAt = time.Time{} - c.refreshRetryError = nil - c.refreshBlocked = false + return true +} - latestCredentials, latestErr := platformReadCredentials(c.credentialPath) - if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { - c.credentials = latestCredentials - c.stateAccess.Lock() - before := c.statusSnapshotLocked() - c.state.unavailable = false - c.state.lastCredentialLoadAttempt = time.Now() - c.state.lastCredentialLoadError = "" - c.checkTransitionLocked() - shouldEmit := before != c.statusSnapshotLocked() - c.stateAccess.Unlock() - if shouldEmit { - c.emitStatusUpdate() +func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool { + latestCredentials, err := platformReadCredentials(c.credentialPath) + if err == nil && latestCredentials != nil { + c.absorbCredentials(latestCredentials) + if latestCredentials.AccessToken != "" && latestCredentials.AccessToken != failedAccessToken { + return true + } + } + c.tryRefreshCredentials(true) + currentCredentials := c.currentCredentials() + return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken +} + +func (c *defaultCredential) applyProfileSnapshot(snapshot *claudeProfileSnapshot) bool { + if snapshot == nil { + return false + } + + credentialsChanged := false + c.access.Lock() + if c.credentials != nil { + updatedCredentials := cloneCredentials(c.credentials) + if snapshot.SubscriptionType != nil { + updatedCredentials.SubscriptionType = cloneStringPointer(snapshot.SubscriptionType) } - if !latestCredentials.needsRefresh() { - return latestCredentials.AccessToken, nil + if snapshot.RateLimitTier != "" { + updatedCredentials.RateLimitTier = cloneStringPointer(&snapshot.RateLimitTier) } - return "", E.New("credential ", c.tag, " changed while refreshing") + credentialsChanged = !credentialsEqual(c.credentials, updatedCredentials) + c.credentials = updatedCredentials } + c.access.Unlock() - c.credentials = newCredentials c.stateAccess.Lock() before := c.statusSnapshotLocked() - c.state.unavailable = false - c.state.lastCredentialLoadAttempt = time.Now() - c.state.lastCredentialLoadError = "" + if snapshot.OAuthAccount != nil { + c.state.oauthAccount = mergeClaudeOAuthAccount(c.state.oauthAccount, snapshot.OAuthAccount) + if c.state.oauthAccount != nil && c.state.oauthAccount.AccountUUID != "" { + c.state.accountUUID = c.state.oauthAccount.AccountUUID + } + } + if snapshot.AccountType != "" { + c.state.accountType = snapshot.AccountType + } + if snapshot.RateLimitTier != "" { + c.state.rateLimitTier = snapshot.RateLimitTier + } c.checkTransitionLocked() shouldEmit := before != c.statusSnapshotLocked() c.stateAccess.Unlock() if shouldEmit { c.emitStatusUpdate() } + return credentialsChanged +} - err = platformWriteCredentials(newCredentials, c.credentialPath) +func (c *defaultCredential) fetchProfileSnapshot(httpClient *http.Client, accessToken string) (*claudeProfileSnapshot, error) { + ctx := c.serviceContext + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+accessToken) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + return request, nil + }) if err != nil { - c.logger.Error("persist refreshed token for ", c.tag, ": ", err) + return nil, err } + defer response.Body.Close() - return newCredentials.AccessToken, nil + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + return nil, E.New("status ", response.StatusCode, " ", string(body)) + } + + var profileResponse struct { + Account *struct { + UUID string `json:"uuid"` + Email string `json:"email"` + DisplayName string `json:"display_name"` + CreatedAt string `json:"created_at"` + } `json:"account"` + Organization *struct { + UUID string `json:"uuid"` + OrganizationType string `json:"organization_type"` + RateLimitTier string `json:"rate_limit_tier"` + HasExtraUsageEnabled *bool `json:"has_extra_usage_enabled"` + BillingType *string `json:"billing_type"` + SubscriptionCreatedAt *string `json:"subscription_created_at"` + } `json:"organization"` + } + if err := json.NewDecoder(response.Body).Decode(&profileResponse); err != nil { + return nil, err + } + if profileResponse.Organization == nil { + return nil, nil + } + + accountType := normalizeClaudeOrganizationType(profileResponse.Organization.OrganizationType) + snapshot := &claudeProfileSnapshot{ + AccountType: accountType, + RateLimitTier: profileResponse.Organization.RateLimitTier, + } + if accountType != "" { + snapshot.SubscriptionType = cloneStringPointer(&accountType) + } + account := &claudeOAuthAccount{} + if profileResponse.Account != nil { + account.AccountUUID = profileResponse.Account.UUID + account.EmailAddress = profileResponse.Account.Email + account.DisplayName = optionalStringPointer(profileResponse.Account.DisplayName) + account.AccountCreatedAt = optionalStringPointer(profileResponse.Account.CreatedAt) + } + account.OrganizationUUID = profileResponse.Organization.UUID + account.HasExtraUsageEnabled = cloneBoolPointer(profileResponse.Organization.HasExtraUsageEnabled) + account.BillingType = cloneStringPointer(profileResponse.Organization.BillingType) + account.SubscriptionCreatedAt = cloneStringPointer(profileResponse.Organization.SubscriptionCreatedAt) + if account.AccountUUID != "" || account.EmailAddress != "" || account.OrganizationUUID != "" || account.DisplayName != nil || + account.HasExtraUsageEnabled != nil || account.BillingType != nil || account.AccountCreatedAt != nil || account.SubscriptionCreatedAt != nil { + snapshot.OAuthAccount = account + } + return snapshot, nil +} + +func normalizeClaudeOrganizationType(organizationType string) string { + switch organizationType { + case "claude_pro": + return "pro" + case "claude_max": + return "max" + case "claude_team": + return "team" + case "claude_enterprise": + return "enterprise" + default: + return "" + } +} + +func optionalStringPointer(value string) *string { + if value == "" { + return nil + } + return &value } func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { @@ -727,7 +951,7 @@ func (c *defaultCredential) pollUsage() { } c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } - needsProfileFetch := !c.configLoaded && c.state.rateLimitTier == "" + needsProfileFetch := c.needsProfileHydrationLocked() shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -736,83 +960,19 @@ func (c *defaultCredential) pollUsage() { c.emitStatusUpdate() if needsProfileFetch { - c.fetchProfile(httpClient, accessToken) - } -} - -// fetchProfile calls GET /api/oauth/profile to retrieve account and organization info. -// Same endpoint used by Claude Code (@anthropic-ai/claude-code @2.1.81): -// -// ref: cli.js GB() — fetches profile -// ref: cli.js AH8() / fetchProfileInfo — parses organization_type, rate_limit_tier -// ref: cli.js EX1() / populateOAuthAccountInfoIfNeeded — stores account.uuid -func (c *defaultCredential) fetchProfile(httpClient *http.Client, accessToken string) { - ctx := c.serviceContext - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil) + profileSnapshot, err := c.fetchProfileSnapshot(httpClient, accessToken) if err != nil { - return nil, err + c.logger.Debug("fetch profile for ", c.tag, ": ", err) + return + } + if profileSnapshot != nil { + credentialsChanged := c.applyProfileSnapshot(profileSnapshot) + c.persistOAuthAccount() + if credentialsChanged { + c.persistCredentials(c.currentCredentials()) + } } - request.Header.Set("Authorization", "Bearer "+accessToken) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - return request, nil - }) - if err != nil { - c.logger.Debug("fetch profile for ", c.tag, ": ", err) - return - } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - return - } - - var profileResponse struct { - Account *struct { - UUID string `json:"uuid"` - } `json:"account"` - Organization *struct { - OrganizationType string `json:"organization_type"` - RateLimitTier string `json:"rate_limit_tier"` - } `json:"organization"` - } - err = json.NewDecoder(response.Body).Decode(&profileResponse) - if err != nil || profileResponse.Organization == nil { - return - } - - accountType := "" - switch profileResponse.Organization.OrganizationType { - case "claude_pro": - accountType = "pro" - case "claude_max": - accountType = "max" - case "claude_team": - accountType = "team" - case "claude_enterprise": - accountType = "enterprise" - } - rateLimitTier := profileResponse.Organization.RateLimitTier - - c.stateAccess.Lock() - before := c.statusSnapshotLocked() - if profileResponse.Account != nil && profileResponse.Account.UUID != "" { - c.state.accountUUID = profileResponse.Account.UUID - } - if accountType != "" && c.state.accountType == "" { - c.state.accountType = accountType - } - if rateLimitTier != "" { - c.state.rateLimitTier = rateLimitTier - } - resolvedAccountType := c.state.accountType - shouldEmit := before != c.statusSnapshotLocked() - c.stateAccess.Unlock() - if shouldEmit { - c.emitStatusUpdate() } - c.logger.Info("fetched profile for ", c.tag, ": type=", resolvedAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(resolvedAccountType, rateLimitTier)) } func (c *defaultCredential) close() { diff --git a/service/ccm/credential_default_test.go b/service/ccm/credential_default_test.go new file mode 100644 index 0000000000..a5535aa26b --- /dev/null +++ b/service/ccm/credential_default_test.go @@ -0,0 +1,205 @@ +package ccm + +import ( + "errors" + "net/http" + "os" + "path/filepath" + "testing" + "time" +) + +func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) { + t.Parallel() + + directory := t.TempDir() + credentialPath := filepath.Join(directory, ".credentials.json") + writeTestCredentials(t, credentialPath, &oauthCredentials{ + AccessToken: "old-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(), + Scopes: []string{"user:profile", "user:inference"}, + SubscriptionType: optionalStringPointer("max"), + RateLimitTier: optionalStringPointer("default_claude_max_20x"), + }) + + credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) { + t.Fatal("refresh should not be attempted when lock acquisition fails") + return nil, nil + })) + if err := credential.reloadCredentials(true); err != nil { + t.Fatal(err) + } + + originalLockFunc := acquireCredentialLockFunc + acquireCredentialLockFunc = func(string) (func(), error) { + return nil, errors.New("locked") + } + t.Cleanup(func() { + acquireCredentialLockFunc = originalLockFunc + }) + + token, err := credential.getAccessToken() + if err != nil { + t.Fatal(err) + } + if token != "old-token" { + t.Fatalf("expected old token, got %q", token) + } +} + +func TestGetAccessTokenAbsorbsRefreshDoneByAnotherProcess(t *testing.T) { + t.Parallel() + + directory := t.TempDir() + credentialPath := filepath.Join(directory, ".credentials.json") + oldCredentials := &oauthCredentials{ + AccessToken: "old-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(), + Scopes: []string{"user:profile", "user:inference"}, + SubscriptionType: optionalStringPointer("max"), + RateLimitTier: optionalStringPointer("default_claude_max_20x"), + } + writeTestCredentials(t, credentialPath, oldCredentials) + + newCredentials := cloneCredentials(oldCredentials) + newCredentials.AccessToken = "new-token" + newCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli() + transport := roundTripFunc(func(request *http.Request) (*http.Response, error) { + if request.URL.Path == "/v1/oauth/token" { + writeTestCredentials(t, credentialPath, newCredentials) + return newJSONResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil + } + t.Fatalf("unexpected path %s", request.URL.Path) + return nil, nil + }) + + credential := newTestDefaultCredential(t, credentialPath, transport) + if err := credential.reloadCredentials(true); err != nil { + t.Fatal(err) + } + + token, err := credential.getAccessToken() + if err != nil { + t.Fatal(err) + } + if token != "new-token" { + t.Fatalf("expected refreshed token from disk, got %q", token) + } +} + +func TestCustomCredentialPathDoesNotEnableClaudeConfigSync(t *testing.T) { + t.Parallel() + + directory := t.TempDir() + credentialPath := filepath.Join(directory, ".credentials.json") + writeTestCredentials(t, credentialPath, &oauthCredentials{ + AccessToken: "token", + ExpiresAt: time.Now().Add(time.Hour).UnixMilli(), + Scopes: []string{"user:profile"}, + }) + + credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) { + t.Fatalf("unexpected request to %s", request.URL.Path) + return nil, nil + })) + if err := credential.reloadCredentials(true); err != nil { + t.Fatal(err) + } + + token, err := credential.getAccessToken() + if err != nil { + t.Fatal(err) + } + if token != "token" { + t.Fatalf("expected token, got %q", token) + } + if credential.shouldUseClaudeConfig() { + t.Fatal("custom credential path should not enable Claude config sync") + } + if _, err := os.Stat(filepath.Join(directory, ".claude.json")); !os.IsNotExist(err) { + t.Fatalf("did not expect config file to be created, stat err=%v", err) + } +} + +func TestDefaultCredentialHydratesProfileAndWritesConfig(t *testing.T) { + configDir := t.TempDir() + credentialPath := filepath.Join(configDir, ".credentials.json") + + writeTestCredentials(t, credentialPath, &oauthCredentials{ + AccessToken: "old-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(), + Scopes: []string{"user:profile", "user:inference"}, + }) + + transport := roundTripFunc(func(request *http.Request) (*http.Response, error) { + switch request.URL.Path { + case "/v1/oauth/token": + return newJSONResponse(http.StatusOK, `{ + "access_token":"new-token", + "refresh_token":"new-refresh", + "expires_in":3600, + "account":{"uuid":"account","email_address":"user@example.com"}, + "organization":{"uuid":"org"} + }`), nil + case "/api/oauth/profile": + return newJSONResponse(http.StatusOK, `{ + "account":{ + "uuid":"account", + "email":"user@example.com", + "display_name":"User", + "created_at":"2024-01-01T00:00:00Z" + }, + "organization":{ + "uuid":"org", + "organization_type":"claude_max", + "rate_limit_tier":"default_claude_max_20x", + "has_extra_usage_enabled":true, + "billing_type":"individual", + "subscription_created_at":"2024-01-02T00:00:00Z" + } + }`), nil + default: + t.Fatalf("unexpected path %s", request.URL.Path) + return nil, nil + } + }) + + credential := newTestDefaultCredential(t, credentialPath, transport) + credential.syncClaudeConfig = true + credential.claudeDirectory = configDir + credential.claudeConfigPath = resolveClaudeConfigWritePath(configDir) + if err := credential.reloadCredentials(true); err != nil { + t.Fatal(err) + } + + token, err := credential.getAccessToken() + if err != nil { + t.Fatal(err) + } + if token != "new-token" { + t.Fatalf("expected refreshed token, got %q", token) + } + + updatedCredentials := readTestCredentials(t, credentialPath) + if updatedCredentials.SubscriptionType == nil || *updatedCredentials.SubscriptionType != "max" { + t.Fatalf("expected subscription type to be persisted, got %#v", updatedCredentials.SubscriptionType) + } + if updatedCredentials.RateLimitTier == nil || *updatedCredentials.RateLimitTier != "default_claude_max_20x" { + t.Fatalf("expected rate limit tier to be persisted, got %#v", updatedCredentials.RateLimitTier) + } + + configPath := tempConfigPath(t, configDir) + config, err := readClaudeCodeConfig(configPath) + if err != nil { + t.Fatal(err) + } + if config.OAuthAccount == nil || config.OAuthAccount.AccountUUID != "account" || config.OAuthAccount.EmailAddress != "user@example.com" { + t.Fatalf("unexpected oauth account: %#v", config.OAuthAccount) + } + if config.OAuthAccount.BillingType == nil || *config.OAuthAccount.BillingType != "individual" { + t.Fatalf("expected billing type to be hydrated, got %#v", config.OAuthAccount.BillingType) + } +} diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index 23e71d5442..cf2af4d2d0 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -106,24 +106,7 @@ func (c *defaultCredential) reloadCredentials(force bool) error { return c.markCredentialsUnavailable(E.Cause(err, "read credentials")) } - c.access.Lock() - c.credentials = credentials - c.refreshRetryAt = time.Time{} - c.refreshRetryError = nil - c.refreshBlocked = false - c.access.Unlock() - - c.stateAccess.Lock() - before := c.statusSnapshotLocked() - c.state.unavailable = false - c.state.lastCredentialLoadError = "" - c.checkTransitionLocked() - shouldEmit := before != c.statusSnapshotLocked() - c.stateAccess.Unlock() - if shouldEmit { - c.emitStatusUpdate() - } - + c.absorbCredentials(credentials) return nil } diff --git a/service/ccm/credential_oauth.go b/service/ccm/credential_oauth.go index 71caf62673..03fcb4023a 100644 --- a/service/ccm/credential_oauth.go +++ b/service/ccm/credential_oauth.go @@ -164,21 +164,7 @@ func checkCredentialFileWritable(path string) error { // ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179444-179454) — read-modify-write // ref: cli.js qD1.update (line 176156) — writeFileSync + chmod 0o600 func writeCredentialsToFile(credentials *oauthCredentials, path string) error { - existing := make(map[string]json.RawMessage) - data, readErr := os.ReadFile(path) - if readErr == nil { - _ = json.Unmarshal(data, &existing) - } - credentialData, err := json.Marshal(credentials) - if err != nil { - return err - } - existing["claudeAiOauth"] = credentialData - data, err = json.MarshalIndent(existing, "", " ") - if err != nil { - return err - } - return os.WriteFile(path, data, 0o600) + return writeStorageValue(jsonFileStorage{path: path}, "claudeAiOauth", credentials) } // oauthCredentials mirrors the claudeAiOauth object in Claude Code's @@ -194,6 +180,12 @@ type oauthCredentials struct { RateLimitTier *string `json:"rateLimitTier"` // ref: cli.js line 179452 (?? null) } +type oauthRefreshResult struct { + Credentials *oauthCredentials + TokenAccount *claudeOAuthAccount + Profile *claudeProfileSnapshot +} + func (c *oauthCredentials) needsRefresh() bool { if c.ExpiresAt == 0 { return false @@ -201,7 +193,7 @@ func (c *oauthCredentials) needsRefresh() bool { return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs } -func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, time.Duration, error) { +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthRefreshResult, time.Duration, error) { if credentials.RefreshToken == "" { return nil, 0, E.New("refresh token is empty") } @@ -249,10 +241,17 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau // ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6 response (line 172769-172772) var tokenResponse struct { - AccessToken string `json:"access_token"` // ref: cli.js line 172770 z - RefreshToken string `json:"refresh_token"` // ref: cli.js line 172770 w (defaults to input) - ExpiresIn int `json:"expires_in"` // ref: cli.js line 172770 O - Scope string `json:"scope"` // ref: cli.js line 172772 uB6(Y.scope) + AccessToken string `json:"access_token"` // ref: cli.js line 172770 z + RefreshToken string `json:"refresh_token"` // ref: cli.js line 172770 w (defaults to input) + ExpiresIn int `json:"expires_in"` // ref: cli.js line 172770 O + Scope *string `json:"scope"` // ref: cli.js line 172772 uB6(Y.scope) + Account *struct { + UUID string `json:"uuid"` + EmailAddress string `json:"email_address"` + } `json:"account"` + Organization *struct { + UUID string `json:"uuid"` + } `json:"organization"` } err = json.NewDecoder(response.Body).Decode(&tokenResponse) if err != nil { @@ -267,11 +266,14 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000 // ref: cli.js uB6 (line 172696-172697): A?.split(" ").filter(Boolean) // strings.Fields matches .filter(Boolean): splits on whitespace runs, removes empty strings - if tokenResponse.Scope != "" { - newCredentials.Scopes = strings.Fields(tokenResponse.Scope) + if tokenResponse.Scope != nil { + newCredentials.Scopes = strings.Fields(*tokenResponse.Scope) } - return &newCredentials, 0, nil + return &oauthRefreshResult{ + Credentials: &newCredentials, + TokenAccount: extractTokenAccount(tokenResponse.Account, tokenResponse.Organization), + }, 0, nil } func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { @@ -280,6 +282,8 @@ func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { } cloned := *credentials cloned.Scopes = append([]string(nil), credentials.Scopes...) + cloned.SubscriptionType = cloneStringPointer(credentials.SubscriptionType) + cloned.RateLimitTier = cloneStringPointer(credentials.RateLimitTier) return &cloned } @@ -290,5 +294,31 @@ func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { return left.AccessToken == right.AccessToken && left.RefreshToken == right.RefreshToken && left.ExpiresAt == right.ExpiresAt && - slices.Equal(left.Scopes, right.Scopes) + slices.Equal(left.Scopes, right.Scopes) && + equalStringPointer(left.SubscriptionType, right.SubscriptionType) && + equalStringPointer(left.RateLimitTier, right.RateLimitTier) +} + +func extractTokenAccount(account *struct { + UUID string `json:"uuid"` + EmailAddress string `json:"email_address"` +}, organization *struct { + UUID string `json:"uuid"` +}, +) *claudeOAuthAccount { + if account == nil && organization == nil { + return nil + } + tokenAccount := &claudeOAuthAccount{} + if account != nil { + tokenAccount.AccountUUID = account.UUID + tokenAccount.EmailAddress = account.EmailAddress + } + if organization != nil { + tokenAccount.OrganizationUUID = organization.UUID + } + if tokenAccount.AccountUUID == "" && tokenAccount.EmailAddress == "" && tokenAccount.OrganizationUUID == "" { + return nil + } + return tokenAccount } diff --git a/service/ccm/credential_oauth_test.go b/service/ccm/credential_oauth_test.go new file mode 100644 index 0000000000..2f98cb0929 --- /dev/null +++ b/service/ccm/credential_oauth_test.go @@ -0,0 +1,141 @@ +package ccm + +import ( + "context" + "encoding/json" + "io" + "net/http" + "slices" + "strings" + "testing" + "time" +) + +func TestRefreshTokenScopeParsing(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + storedScopes []string + responseBody string + expectedScope string + expected []string + }{ + { + name: "missing scope preserves stored scopes", + storedScopes: []string{"user:profile", "user:inference"}, + responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`, + expectedScope: strings.Join(defaultOAuthScopes, " "), + expected: []string{"user:profile", "user:inference"}, + }, + { + name: "empty scope clears stored scopes", + storedScopes: []string{"user:profile", "user:inference"}, + responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600,"scope":""}`, + expectedScope: strings.Join(defaultOAuthScopes, " "), + expected: []string{}, + }, + { + name: "stored non inference scopes are sent verbatim", + storedScopes: []string{"user:profile"}, + responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600,"scope":"user:profile user:file_upload"}`, + expectedScope: "user:profile", + expected: []string{"user:profile", "user:file_upload"}, + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + var seenScope string + client := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { + body, err := io.ReadAll(request.Body) + if err != nil { + t.Fatal(err) + } + var payload map[string]string + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatal(err) + } + seenScope = payload["scope"] + return newJSONResponse(http.StatusOK, testCase.responseBody), nil + })} + + result, _, err := refreshToken(context.Background(), client, &oauthCredentials{ + AccessToken: "old-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(), + Scopes: testCase.storedScopes, + }) + if err != nil { + t.Fatal(err) + } + if seenScope != testCase.expectedScope { + t.Fatalf("expected request scope %q, got %q", testCase.expectedScope, seenScope) + } + if result == nil || result.Credentials == nil { + t.Fatal("expected refresh result credentials") + } + if !slices.Equal(result.Credentials.Scopes, testCase.expected) { + t.Fatalf("expected scopes %v, got %v", testCase.expected, result.Credentials.Scopes) + } + }) + } +} + +func TestRefreshTokenExtractsTokenAccount(t *testing.T) { + t.Parallel() + + client := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { + return newJSONResponse(http.StatusOK, `{ + "access_token":"new-token", + "refresh_token":"new-refresh", + "expires_in":3600, + "account":{"uuid":"account","email_address":"user@example.com"}, + "organization":{"uuid":"org"} + }`), nil + })} + + result, _, err := refreshToken(context.Background(), client, &oauthCredentials{ + AccessToken: "old-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(), + Scopes: []string{"user:profile", "user:inference"}, + }) + if err != nil { + t.Fatal(err) + } + if result == nil || result.TokenAccount == nil { + t.Fatal("expected token account") + } + if result.TokenAccount.AccountUUID != "account" || result.TokenAccount.EmailAddress != "user@example.com" || result.TokenAccount.OrganizationUUID != "org" { + t.Fatalf("unexpected token account: %#v", result.TokenAccount) + } +} + +func TestCredentialsEqualIncludesProfileFields(t *testing.T) { + t.Parallel() + + subscriptionType := "max" + rateLimitTier := "default_claude_max_20x" + left := &oauthCredentials{ + AccessToken: "token", + RefreshToken: "refresh", + ExpiresAt: 123, + Scopes: []string{"user:inference"}, + SubscriptionType: &subscriptionType, + RateLimitTier: &rateLimitTier, + } + right := cloneCredentials(left) + if !credentialsEqual(left, right) { + t.Fatal("expected cloned credentials to be equal") + } + + otherTier := "default_claude_max_5x" + right.RateLimitTier = &otherTier + if credentialsEqual(left, right) { + t.Fatal("expected different rate limit tier to break equality") + } +} diff --git a/service/ccm/credential_storage.go b/service/ccm/credential_storage.go new file mode 100644 index 0000000000..74479f8467 --- /dev/null +++ b/service/ccm/credential_storage.go @@ -0,0 +1,124 @@ +package ccm + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" +) + +type jsonContainerStorage interface { + readContainer() (map[string]json.RawMessage, bool, error) + writeContainer(map[string]json.RawMessage) error + delete() error +} + +type jsonFileStorage struct { + path string +} + +func (s jsonFileStorage) readContainer() (map[string]json.RawMessage, bool, error) { + data, err := os.ReadFile(s.path) + if err != nil { + if os.IsNotExist(err) { + return make(map[string]json.RawMessage), false, nil + } + return nil, false, err + } + container := make(map[string]json.RawMessage) + if len(data) == 0 { + return container, true, nil + } + if err := json.Unmarshal(data, &container); err != nil { + return nil, true, err + } + return container, true, nil +} + +func (s jsonFileStorage) writeContainer(container map[string]json.RawMessage) error { + if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil { + return err + } + data, err := json.MarshalIndent(container, "", " ") + if err != nil { + return err + } + return os.WriteFile(s.path, data, 0o600) +} + +func (s jsonFileStorage) delete() error { + err := os.Remove(s.path) + if err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +func writeStorageValue(storage jsonContainerStorage, key string, value any) error { + container, _, err := storage.readContainer() + if err != nil { + var syntaxError *json.SyntaxError + var typeError *json.UnmarshalTypeError + if !errors.As(err, &syntaxError) && !errors.As(err, &typeError) { + return err + } + container = make(map[string]json.RawMessage) + } + if container == nil { + container = make(map[string]json.RawMessage) + } + encodedValue, err := json.Marshal(value) + if err != nil { + return err + } + container[key] = encodedValue + return storage.writeContainer(container) +} + +func persistStorageValue(primary jsonContainerStorage, fallback jsonContainerStorage, key string, value any) error { + primaryErr := writeStorageValue(primary, key, value) + if primaryErr == nil { + if fallback != nil { + _ = fallback.delete() + } + return nil + } + if fallback == nil { + return primaryErr + } + if err := writeStorageValue(fallback, key, value); err != nil { + return err + } + _ = primary.delete() + return nil +} + +func cloneStringPointer(value *string) *string { + if value == nil { + return nil + } + cloned := *value + return &cloned +} + +func cloneBoolPointer(value *bool) *bool { + if value == nil { + return nil + } + cloned := *value + return &cloned +} + +func equalStringPointer(left *string, right *string) bool { + if left == nil || right == nil { + return left == right + } + return *left == *right +} + +func equalBoolPointer(left *bool, right *bool) bool { + if left == nil || right == nil { + return left == right + } + return *left == *right +} diff --git a/service/ccm/credential_storage_test.go b/service/ccm/credential_storage_test.go new file mode 100644 index 0000000000..fa22ca0ddb --- /dev/null +++ b/service/ccm/credential_storage_test.go @@ -0,0 +1,125 @@ +package ccm + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +type fakeJSONStorage struct { + container map[string]json.RawMessage + writeErr error + deleted bool +} + +func (s *fakeJSONStorage) readContainer() (map[string]json.RawMessage, bool, error) { + if s.container == nil { + return make(map[string]json.RawMessage), false, nil + } + cloned := make(map[string]json.RawMessage, len(s.container)) + for key, value := range s.container { + cloned[key] = value + } + return cloned, true, nil +} + +func (s *fakeJSONStorage) writeContainer(container map[string]json.RawMessage) error { + if s.writeErr != nil { + return s.writeErr + } + s.container = make(map[string]json.RawMessage, len(container)) + for key, value := range container { + s.container[key] = value + } + return nil +} + +func (s *fakeJSONStorage) delete() error { + s.deleted = true + s.container = nil + return nil +} + +func TestPersistStorageValueDeletesFallbackOnPrimarySuccess(t *testing.T) { + t.Parallel() + + primary := &fakeJSONStorage{} + fallback := &fakeJSONStorage{container: map[string]json.RawMessage{"stale": json.RawMessage(`true`)}} + if err := persistStorageValue(primary, fallback, "claudeAiOauth", &oauthCredentials{AccessToken: "token"}); err != nil { + t.Fatal(err) + } + if !fallback.deleted { + t.Fatal("expected fallback storage to be deleted after primary write") + } +} + +func TestPersistStorageValueDeletesPrimaryAfterFallbackSuccess(t *testing.T) { + t.Parallel() + + primary := &fakeJSONStorage{ + container: map[string]json.RawMessage{"claudeAiOauth": json.RawMessage(`{"accessToken":"old"}`)}, + writeErr: os.ErrPermission, + } + fallback := &fakeJSONStorage{} + if err := persistStorageValue(primary, fallback, "claudeAiOauth", &oauthCredentials{AccessToken: "new"}); err != nil { + t.Fatal(err) + } + if !primary.deleted { + t.Fatal("expected primary storage to be deleted after fallback write") + } +} + +func TestWriteCredentialsToFilePreservesTopLevelKeys(t *testing.T) { + t.Parallel() + + directory := t.TempDir() + path := filepath.Join(directory, ".credentials.json") + initial := []byte(`{"keep":{"nested":true},"claudeAiOauth":{"accessToken":"old"}}`) + if err := os.WriteFile(path, initial, 0o600); err != nil { + t.Fatal(err) + } + + if err := writeCredentialsToFile(&oauthCredentials{AccessToken: "new"}, path); err != nil { + t.Fatal(err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + var container map[string]json.RawMessage + if err := json.Unmarshal(data, &container); err != nil { + t.Fatal(err) + } + if _, exists := container["keep"]; !exists { + t.Fatal("expected unknown top-level key to be preserved") + } +} + +func TestWriteClaudeCodeOAuthAccountPreservesTopLevelKeys(t *testing.T) { + t.Parallel() + + directory := t.TempDir() + path := filepath.Join(directory, ".claude.json") + initial := []byte(`{"keep":{"nested":true},"oauthAccount":{"accountUuid":"old"}}`) + if err := os.WriteFile(path, initial, 0o600); err != nil { + t.Fatal(err) + } + + if err := writeClaudeCodeOAuthAccount(path, &claudeOAuthAccount{AccountUUID: "new"}); err != nil { + t.Fatal(err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + var container map[string]json.RawMessage + if err := json.Unmarshal(data, &container); err != nil { + t.Fatal(err) + } + if _, exists := container["keep"]; !exists { + t.Fatal("expected unknown config key to be preserved") + } +} diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 870a39514d..ad5dfcaee0 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -377,8 +377,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !selectedCredential.isExternal() && bodyBytes != nil && (response.StatusCode == http.StatusUnauthorized || response.StatusCode == http.StatusForbidden) { shouldRetry := response.StatusCode == http.StatusUnauthorized + var peekBody []byte if response.StatusCode == http.StatusForbidden { - peekBody, _ := io.ReadAll(response.Body) + peekBody, _ = io.ReadAll(response.Body) shouldRetry = strings.Contains(string(peekBody), "OAuth token has been revoked") if !shouldRetry { response.Body.Close() @@ -389,23 +390,33 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } if shouldRetry { - response.Body.Close() - s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying") + recovered := false if defaultCred, ok := selectedCredential.(*defaultCredential); ok { - _ = defaultCred.reloadCredentials(true) - } - retryRequest, buildErr := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) - if buildErr != nil { - writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(buildErr, "rebuild request after auth recovery").Error()) - return + failedAccessToken := "" + currentCredentials := defaultCred.currentCredentials() + if currentCredentials != nil { + failedAccessToken = currentCredentials.AccessToken + } + s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying") + recovered = defaultCred.recoverAuthFailure(failedAccessToken) } - retryResponse, retryErr := selectedCredential.httpClient().Do(retryRequest) - if retryErr != nil { - writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(retryErr, "retry request after auth recovery").Error()) - return + if recovered { + response.Body.Close() + retryRequest, buildErr := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if buildErr != nil { + writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(buildErr, "rebuild request after auth recovery").Error()) + return + } + retryResponse, retryErr := selectedCredential.httpClient().Do(retryRequest) + if retryErr != nil { + writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(retryErr, "retry request after auth recovery").Error()) + return + } + response = retryResponse + defer retryResponse.Body.Close() + } else if response.StatusCode == http.StatusForbidden { + response.Body = io.NopCloser(bytes.NewReader(peekBody)) } - response = retryResponse - defer retryResponse.Body.Close() } } diff --git a/service/ccm/service_handler_test.go b/service/ccm/service_handler_test.go new file mode 100644 index 0000000000..968d238631 --- /dev/null +++ b/service/ccm/service_handler_test.go @@ -0,0 +1,221 @@ +package ccm + +import ( + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" +) + +func newHandlerCredential(t *testing.T, transport http.RoundTripper) (*defaultCredential, string) { + t.Helper() + directory := t.TempDir() + credentialPath := filepath.Join(directory, ".credentials.json") + writeTestCredentials(t, credentialPath, &oauthCredentials{ + AccessToken: "old-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(time.Hour).UnixMilli(), + Scopes: []string{"user:profile", "user:inference"}, + SubscriptionType: optionalStringPointer("max"), + RateLimitTier: optionalStringPointer("default_claude_max_20x"), + }) + credential := newTestDefaultCredential(t, credentialPath, transport) + if err := credential.reloadCredentials(true); err != nil { + t.Fatal(err) + } + seedTestCredentialState(credential) + return credential, credentialPath +} + +func TestServiceHandlerRecoversFrom401(t *testing.T) { + t.Parallel() + + var messageRequests atomic.Int32 + var refreshRequests atomic.Int32 + credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) { + switch request.URL.Path { + case "/v1/messages": + call := messageRequests.Add(1) + switch request.Header.Get("Authorization") { + case "Bearer old-token": + if call != 1 { + t.Fatalf("unexpected old-token call count %d", call) + } + return newTextResponse(http.StatusUnauthorized, "unauthorized"), nil + case "Bearer new-token": + return newJSONResponse(http.StatusOK, `{}`), nil + default: + t.Fatalf("unexpected authorization header %q", request.Header.Get("Authorization")) + } + case "/v1/oauth/token": + refreshRequests.Add(1) + return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil + default: + t.Fatalf("unexpected path %s", request.URL.Path) + } + return nil, nil + })) + + service := newTestService(credential) + recorder := httptest.NewRecorder() + service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`)) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String()) + } + if messageRequests.Load() != 2 { + t.Fatalf("expected two upstream message requests, got %d", messageRequests.Load()) + } + if refreshRequests.Load() != 1 { + t.Fatalf("expected one refresh request, got %d", refreshRequests.Load()) + } +} + +func TestServiceHandlerRecoversFromRevoked403(t *testing.T) { + t.Parallel() + + var messageRequests atomic.Int32 + var refreshRequests atomic.Int32 + credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) { + switch request.URL.Path { + case "/v1/messages": + messageRequests.Add(1) + if request.Header.Get("Authorization") == "Bearer old-token" { + return newTextResponse(http.StatusForbidden, "OAuth token has been revoked"), nil + } + return newJSONResponse(http.StatusOK, `{}`), nil + case "/v1/oauth/token": + refreshRequests.Add(1) + return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil + default: + t.Fatalf("unexpected path %s", request.URL.Path) + } + return nil, nil + })) + + service := newTestService(credential) + recorder := httptest.NewRecorder() + service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`)) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String()) + } + if refreshRequests.Load() != 1 { + t.Fatalf("expected one refresh request, got %d", refreshRequests.Load()) + } +} + +func TestServiceHandlerDoesNotRecoverFromOrdinary403(t *testing.T) { + t.Parallel() + + var refreshRequests atomic.Int32 + credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) { + switch request.URL.Path { + case "/v1/messages": + return newTextResponse(http.StatusForbidden, "forbidden"), nil + case "/v1/oauth/token": + refreshRequests.Add(1) + return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil + default: + t.Fatalf("unexpected path %s", request.URL.Path) + } + return nil, nil + })) + + service := newTestService(credential) + recorder := httptest.NewRecorder() + service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`)) + + if recorder.Code != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", recorder.Code) + } + if refreshRequests.Load() != 0 { + t.Fatalf("expected no refresh request, got %d", refreshRequests.Load()) + } + if !strings.Contains(recorder.Body.String(), "forbidden") { + t.Fatalf("expected forbidden body, got %s", recorder.Body.String()) + } +} + +func TestServiceHandlerUsesReloadedTokenBeforeRefreshing(t *testing.T) { + t.Parallel() + + var messageRequests atomic.Int32 + var refreshRequests atomic.Int32 + var credentialPath string + var credential *defaultCredential + credential, credentialPath = newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) { + switch request.URL.Path { + case "/v1/messages": + call := messageRequests.Add(1) + if request.Header.Get("Authorization") == "Bearer old-token" { + updatedCredentials := readTestCredentials(t, credentialPath) + updatedCredentials.AccessToken = "disk-token" + updatedCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli() + writeTestCredentials(t, credentialPath, updatedCredentials) + if call != 1 { + t.Fatalf("unexpected old-token call count %d", call) + } + return newTextResponse(http.StatusUnauthorized, "unauthorized"), nil + } + if request.Header.Get("Authorization") != "Bearer disk-token" { + t.Fatalf("expected disk token retry, got %q", request.Header.Get("Authorization")) + } + return newJSONResponse(http.StatusOK, `{}`), nil + case "/v1/oauth/token": + refreshRequests.Add(1) + return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil + default: + t.Fatalf("unexpected path %s", request.URL.Path) + } + return nil, nil + })) + + service := newTestService(credential) + recorder := httptest.NewRecorder() + service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`)) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String()) + } + if refreshRequests.Load() != 0 { + t.Fatalf("expected zero refresh requests, got %d", refreshRequests.Load()) + } +} + +func TestServiceHandlerRetriesAuthRecoveryOnlyOnce(t *testing.T) { + t.Parallel() + + var messageRequests atomic.Int32 + var refreshRequests atomic.Int32 + credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) { + switch request.URL.Path { + case "/v1/messages": + messageRequests.Add(1) + return newTextResponse(http.StatusUnauthorized, "still unauthorized"), nil + case "/v1/oauth/token": + refreshRequests.Add(1) + return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil + default: + t.Fatalf("unexpected path %s", request.URL.Path) + } + return nil, nil + })) + + service := newTestService(credential) + recorder := httptest.NewRecorder() + service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`)) + + if recorder.Code != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", recorder.Code) + } + if messageRequests.Load() != 2 { + t.Fatalf("expected exactly two upstream attempts, got %d", messageRequests.Load()) + } + if refreshRequests.Load() != 1 { + t.Fatalf("expected exactly one refresh request, got %d", refreshRequests.Load()) + } +} diff --git a/service/ccm/test_helpers_test.go b/service/ccm/test_helpers_test.go new file mode 100644 index 0000000000..175dc71add --- /dev/null +++ b/service/ccm/test_helpers_test.go @@ -0,0 +1,138 @@ +package ccm + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) { + return f(request) +} + +func newJSONResponse(statusCode int, body string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Status: http.StatusText(statusCode), + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func newTextResponse(statusCode int, body string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Status: http.StatusText(statusCode), + Header: http.Header{"Content-Type": []string{"text/plain"}}, + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func writeTestCredentials(t *testing.T, path string, credentials *oauthCredentials) { + t.Helper() + if path == "" { + var err error + path, err = getDefaultCredentialsPath() + if err != nil { + t.Fatal(err) + } + } + if err := writeCredentialsToFile(credentials, path); err != nil { + t.Fatal(err) + } +} + +func readTestCredentials(t *testing.T, path string) *oauthCredentials { + t.Helper() + if path == "" { + var err error + path, err = getDefaultCredentialsPath() + if err != nil { + t.Fatal(err) + } + } + credentials, err := readCredentialsFromFile(path) + if err != nil { + t.Fatal(err) + } + return credentials +} + +func newTestDefaultCredential(t *testing.T, credentialPath string, transport http.RoundTripper) *defaultCredential { + t.Helper() + credentialFilePath, err := resolveCredentialFilePath(credentialPath) + if err != nil { + t.Fatal(err) + } + requestContext, cancelRequests := context.WithCancel(context.Background()) + credential := &defaultCredential{ + tag: "test", + serviceContext: context.Background(), + credentialPath: credentialPath, + credentialFilePath: credentialFilePath, + configDir: resolveConfigDir(credentialPath, credentialFilePath), + syncClaudeConfig: credentialPath == "", + cap5h: 99, + capWeekly: 99, + forwardHTTPClient: &http.Client{Transport: transport}, + logger: log.NewNOPFactory().Logger(), + requestContext: requestContext, + cancelRequests: cancelRequests, + } + if credential.syncClaudeConfig { + credential.claudeDirectory = credential.configDir + credential.claudeConfigPath = resolveClaudeConfigWritePath(credential.claudeDirectory) + } + credential.state.lastUpdated = time.Now() + return credential +} + +func seedTestCredentialState(credential *defaultCredential) { + billingType := "individual" + accountCreatedAt := "2024-01-01T00:00:00Z" + subscriptionCreatedAt := "2024-01-02T00:00:00Z" + credential.stateAccess.Lock() + credential.state.accountUUID = "account" + credential.state.accountType = "max" + credential.state.rateLimitTier = "default_claude_max_20x" + credential.state.oauthAccount = &claudeOAuthAccount{ + AccountUUID: "account", + EmailAddress: "user@example.com", + OrganizationUUID: "org", + BillingType: &billingType, + AccountCreatedAt: &accountCreatedAt, + SubscriptionCreatedAt: &subscriptionCreatedAt, + } + credential.stateAccess.Unlock() +} + +func newTestService(credential *defaultCredential) *Service { + return &Service{ + logger: log.NewNOPFactory().Logger(), + options: option.CCMServiceOptions{Credentials: []option.CCMCredential{{Tag: "default"}}}, + httpHeaders: make(http.Header), + providers: map[string]credentialProvider{"default": &singleCredentialProvider{credential: credential}}, + sessionModels: make(map[sessionModelKey]time.Time), + } +} + +func newMessageRequest(body string) *http.Request { + request := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + request.Header.Set("Content-Type", "application/json") + return request +} + +func tempConfigPath(t *testing.T, dir string) string { + t.Helper() + return filepath.Join(dir, claudeCodeLegacyConfigFileName()) +} From 4592164a7aa69bac529fb941c06ec5f9177444a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 22:06:10 +0800 Subject: [PATCH 82/96] Align CCM and OCM rate limits --- service/ccm/credential.go | 79 ++++++ service/ccm/credential_default.go | 37 +++ service/ccm/credential_external.go | 102 ++++++- service/ccm/rate_limit_state.go | 124 +++++++++ service/ccm/service.go | 4 + service/ccm/service_status.go | 260 ++++++++++++++++-- service/ccm/service_status_test.go | 173 ++++++++++++ service/ocm/credential.go | 75 ++++++ service/ocm/credential_default.go | 81 ++++-- service/ocm/credential_external.go | 146 ++++++++-- service/ocm/rate_limit_state.go | 384 ++++++++++++++++++++++++++ service/ocm/service.go | 7 +- service/ocm/service_status.go | 419 +++++++++++++++++++++++------ service/ocm/service_status_test.go | 220 +++++++++++++++ service/ocm/service_websocket.go | 180 ++++++++++--- 15 files changed, 2096 insertions(+), 195 deletions(-) create mode 100644 service/ccm/rate_limit_state.go create mode 100644 service/ccm/service_status_test.go create mode 100644 service/ocm/rate_limit_state.go create mode 100644 service/ocm/service_status_test.go diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 20170fca23..e3788c43fc 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -59,6 +59,17 @@ type credentialState struct { weeklyReset time.Time hardRateLimited bool rateLimitResetAt time.Time + availabilityState availabilityState + availabilityReason availabilityReason + availabilityResetAt time.Time + lastKnownDataAt time.Time + unifiedStatus unifiedRateLimitStatus + unifiedResetAt time.Time + representativeClaim string + unifiedFallbackAvailable bool + overageStatus string + overageResetAt time.Time + overageDisabledReason string accountUUID string accountType string rateLimitTier string @@ -103,6 +114,7 @@ type Credential interface { isAvailable() bool isUsable() bool isExternal() bool + hasSnapshotData() bool fiveHourUtilization() float64 weeklyUtilization() float64 fiveHourCap() float64 @@ -112,6 +124,8 @@ type Credential interface { weeklyResetTime() time.Time markRateLimited(resetAt time.Time) markUpstreamRejected() + availabilityStatus() availabilityStatus + unifiedRateLimitState() unifiedRateLimitInfo earliestReset() time.Time unavailableError() error @@ -185,6 +199,71 @@ func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) t return parseAnthropicResetHeaderValue(headerName, headerValue) } +func (s *credentialState) noteSnapshotData() { + s.lastKnownDataAt = time.Now() +} + +func (s credentialState) hasSnapshotData() bool { + return !s.lastKnownDataAt.IsZero() || + s.fiveHourUtilization > 0 || + s.weeklyUtilization > 0 || + !s.fiveHourReset.IsZero() || + !s.weeklyReset.IsZero() +} + +func (s *credentialState) setAvailability(state availabilityState, reason availabilityReason, resetAt time.Time) { + s.availabilityState = state + s.availabilityReason = reason + s.availabilityResetAt = resetAt +} + +func (s credentialState) currentAvailability() availabilityStatus { + now := time.Now() + switch { + case s.unavailable: + return availabilityStatus{ + State: availabilityStateUnavailable, + Reason: availabilityReasonUnknown, + ResetAt: s.availabilityResetAt, + } + case s.hardRateLimited && (s.rateLimitResetAt.IsZero() || now.Before(s.rateLimitResetAt)): + reason := s.availabilityReason + if reason == "" { + reason = availabilityReasonHardRateLimit + } + return availabilityStatus{ + State: availabilityStateRateLimited, + Reason: reason, + ResetAt: s.rateLimitResetAt, + } + case !s.upstreamRejectedUntil.IsZero() && now.Before(s.upstreamRejectedUntil): + return availabilityStatus{ + State: availabilityStateTemporarilyBlocked, + Reason: availabilityReasonUpstreamRejected, + ResetAt: s.upstreamRejectedUntil, + } + case s.consecutivePollFailures > 0: + return availabilityStatus{ + State: availabilityStateTemporarilyBlocked, + Reason: availabilityReasonPollFailed, + } + default: + return availabilityStatus{State: availabilityStateUsable} + } +} + +func (s credentialState) currentUnifiedRateLimit() unifiedRateLimitInfo { + return unifiedRateLimitInfo{ + Status: s.unifiedStatus, + ResetAt: s.unifiedResetAt, + RepresentativeClaim: s.representativeClaim, + FallbackAvailable: s.unifiedFallbackAvailable, + OverageStatus: s.overageStatus, + OverageResetAt: s.overageResetAt, + OverageDisabledReason: s.overageDisabledReason, + }.normalized() +} + func parseRateLimitResetFromHeaders(headers http.Header) time.Time { claim := headers.Get("anthropic-ratelimit-unified-representative-claim") switch claim { diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index ea74d51e73..5467d3d8d8 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -623,7 +623,21 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { if hadData { c.state.consecutivePollFailures = 0 c.state.lastUpdated = time.Now() + c.state.noteSnapshotData() } + if unifiedStatus := unifiedRateLimitStatus(headers.Get("anthropic-ratelimit-unified-status")); unifiedStatus != "" { + c.state.unifiedStatus = unifiedStatus + } + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-reset"); exists { + c.state.unifiedResetAt = value + } + c.state.representativeClaim = headers.Get("anthropic-ratelimit-unified-representative-claim") + c.state.unifiedFallbackAvailable = headers.Get("anthropic-ratelimit-unified-fallback") == "available" + c.state.overageStatus = headers.Get("anthropic-ratelimit-unified-overage-status") + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-overage-reset"); exists { + c.state.overageResetAt = value + } + c.state.overageDisabledReason = headers.Get("anthropic-ratelimit-unified-overage-disabled-reason") if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { @@ -647,6 +661,9 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) { c.stateAccess.Lock() c.state.hardRateLimited = true c.state.rateLimitResetAt = resetAt + c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt) + c.state.unifiedStatus = unifiedRateLimitStatusRejected + c.state.unifiedResetAt = resetAt shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -741,6 +758,12 @@ func (c *defaultCredential) weeklyUtilization() float64 { return c.state.weeklyUtilization } +func (c *defaultCredential) hasSnapshotData() bool { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.hasSnapshotData() +} + func (c *defaultCredential) planWeight() float64 { c.stateAccess.RLock() defer c.stateAccess.RUnlock() @@ -767,6 +790,18 @@ func (c *defaultCredential) isAvailable() bool { return !c.state.unavailable } +func (c *defaultCredential) availabilityStatus() availabilityStatus { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.currentAvailability() +} + +func (c *defaultCredential) unifiedRateLimitState() unifiedRateLimitInfo { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.currentUnifiedRateLimit() +} + func (c *defaultCredential) unavailableError() error { c.stateAccess.RLock() defer c.stateAccess.RUnlock() @@ -794,6 +829,7 @@ func (c *defaultCredential) markUsagePollAttempted() { func (c *defaultCredential) incrementPollFailures() { c.stateAccess.Lock() c.state.consecutivePollFailures++ + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{}) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -944,6 +980,7 @@ func (c *defaultCredential) pollUsage() { if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } + c.state.noteSnapshotData() if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 94f353b6a2..e8a0cf1ea6 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -343,6 +343,9 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) { c.stateAccess.Lock() c.state.hardRateLimited = true c.state.rateLimitResetAt = resetAt + c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt) + c.state.unifiedStatus = unifiedRateLimitStatusRejected + c.state.unifiedResetAt = resetAt shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -355,6 +358,7 @@ func (c *externalCredential) markUpstreamRejected() { c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(c.pollInterval)) c.stateAccess.Lock() c.state.upstreamRejectedUntil = time.Now().Add(c.pollInterval) + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonUpstreamRejected, c.state.upstreamRejectedUntil) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -493,7 +497,21 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.consecutivePollFailures = 0 c.state.upstreamRejectedUntil = time.Time{} c.state.lastUpdated = time.Now() + c.state.noteSnapshotData() } + if unifiedStatus := unifiedRateLimitStatus(headers.Get("anthropic-ratelimit-unified-status")); unifiedStatus != "" { + c.state.unifiedStatus = unifiedStatus + } + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-reset"); exists { + c.state.unifiedResetAt = value + } + c.state.representativeClaim = headers.Get("anthropic-ratelimit-unified-representative-claim") + c.state.unifiedFallbackAvailable = headers.Get("anthropic-ratelimit-unified-fallback") == "available" + c.state.overageStatus = headers.Get("anthropic-ratelimit-unified-overage-status") + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-overage-reset"); exists { + c.state.overageResetAt = value + } + c.state.overageDisabledReason = headers.Get("anthropic-ratelimit-unified-overage-disabled-reason") if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { @@ -635,13 +653,7 @@ func (c *externalCredential) pollUsage() { c.clearPollFailures() return } - var statusResponse struct { - FiveHourUtilization float64 `json:"five_hour_utilization"` - FiveHourReset int64 `json:"five_hour_reset"` - WeeklyUtilization float64 `json:"weekly_utilization"` - WeeklyReset int64 `json:"weekly_reset"` - PlanWeight float64 `json:"plan_weight"` - } + var statusResponse statusPayload err = json.Unmarshal(body, &statusResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) @@ -657,6 +669,11 @@ func (c *externalCredential) pollUsage() { c.state.upstreamRejectedUntil = time.Time{} c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization + c.state.unifiedStatus = unifiedRateLimitStatus(statusResponse.UnifiedStatus) + c.state.representativeClaim = statusResponse.RepresentativeClaim + c.state.unifiedFallbackAvailable = statusResponse.FallbackAvailable + c.state.overageStatus = statusResponse.OverageStatus + c.state.overageDisabledReason = statusResponse.OverageDisabledReason if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } @@ -666,6 +683,30 @@ func (c *externalCredential) pollUsage() { if statusResponse.WeeklyReset > 0 { c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) } + if statusResponse.UnifiedReset > 0 { + c.state.unifiedResetAt = time.Unix(statusResponse.UnifiedReset, 0) + } + if statusResponse.OverageReset > 0 { + c.state.overageResetAt = time.Unix(statusResponse.OverageReset, 0) + } + if statusResponse.Availability != nil { + switch availabilityState(statusResponse.Availability.State) { + case availabilityStateRateLimited: + c.state.hardRateLimited = true + if statusResponse.Availability.ResetAt > 0 { + c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0) + } + case availabilityStateTemporarilyBlocked: + resetAt := time.Time{} + if statusResponse.Availability.ResetAt > 0 { + resetAt = time.Unix(statusResponse.Availability.ResetAt, 0) + } + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt) + if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() { + c.state.upstreamRejectedUntil = resetAt + } + } + } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } @@ -766,6 +807,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr c.state.upstreamRejectedUntil = time.Time{} c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization + c.state.unifiedStatus = unifiedRateLimitStatus(statusResponse.UnifiedStatus) + c.state.representativeClaim = statusResponse.RepresentativeClaim + c.state.unifiedFallbackAvailable = statusResponse.FallbackAvailable + c.state.overageStatus = statusResponse.OverageStatus + c.state.overageDisabledReason = statusResponse.OverageDisabledReason if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } @@ -775,6 +821,30 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr if statusResponse.WeeklyReset > 0 { c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) } + if statusResponse.UnifiedReset > 0 { + c.state.unifiedResetAt = time.Unix(statusResponse.UnifiedReset, 0) + } + if statusResponse.OverageReset > 0 { + c.state.overageResetAt = time.Unix(statusResponse.OverageReset, 0) + } + if statusResponse.Availability != nil { + switch availabilityState(statusResponse.Availability.State) { + case availabilityStateRateLimited: + c.state.hardRateLimited = true + if statusResponse.Availability.ResetAt > 0 { + c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0) + } + case availabilityStateTemporarilyBlocked: + resetAt := time.Time{} + if statusResponse.Availability.ResetAt > 0 { + resetAt = time.Unix(statusResponse.Availability.ResetAt, 0) + } + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt) + if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() { + c.state.upstreamRejectedUntil = resetAt + } + } + } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } @@ -846,6 +916,24 @@ func (c *externalCredential) lastUpdatedTime() time.Time { return c.state.lastUpdated } +func (c *externalCredential) hasSnapshotData() bool { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.hasSnapshotData() +} + +func (c *externalCredential) availabilityStatus() availabilityStatus { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.currentAvailability() +} + +func (c *externalCredential) unifiedRateLimitState() unifiedRateLimitInfo { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.currentUnifiedRateLimit() +} + func (c *externalCredential) markUsageStreamUpdated() { c.stateAccess.Lock() defer c.stateAccess.Unlock() diff --git a/service/ccm/rate_limit_state.go b/service/ccm/rate_limit_state.go new file mode 100644 index 0000000000..ab584419fb --- /dev/null +++ b/service/ccm/rate_limit_state.go @@ -0,0 +1,124 @@ +package ccm + +import "time" + +type availabilityState string + +const ( + availabilityStateUsable availabilityState = "usable" + availabilityStateRateLimited availabilityState = "rate_limited" + availabilityStateTemporarilyBlocked availabilityState = "temporarily_blocked" + availabilityStateUnavailable availabilityState = "unavailable" + availabilityStateUnknown availabilityState = "unknown" +) + +type availabilityReason string + +const ( + availabilityReasonHardRateLimit availabilityReason = "hard_rate_limit" + availabilityReasonConnectionLimit availabilityReason = "connection_limit" + availabilityReasonPollFailed availabilityReason = "poll_failed" + availabilityReasonUpstreamRejected availabilityReason = "upstream_rejected" + availabilityReasonNoCredentials availabilityReason = "no_credentials" + availabilityReasonUnknown availabilityReason = "unknown" +) + +type availabilityStatus struct { + State availabilityState + Reason availabilityReason + ResetAt time.Time +} + +type availabilityPayload struct { + State string `json:"state"` + Reason string `json:"reason,omitempty"` + ResetAt int64 `json:"reset_at,omitempty"` +} + +func (s availabilityStatus) normalized() availabilityStatus { + if s.State == "" { + s.State = availabilityStateUnknown + } + if s.Reason == "" && s.State != availabilityStateUsable { + s.Reason = availabilityReasonUnknown + } + return s +} + +func (s availabilityStatus) toPayload() *availabilityPayload { + s = s.normalized() + if s.State == "" { + return nil + } + payload := &availabilityPayload{ + State: string(s.State), + } + if s.Reason != "" && s.Reason != availabilityReasonUnknown { + payload.Reason = string(s.Reason) + } + if !s.ResetAt.IsZero() { + payload.ResetAt = s.ResetAt.Unix() + } + return payload +} + +type unifiedRateLimitStatus string + +const ( + unifiedRateLimitStatusAllowed unifiedRateLimitStatus = "allowed" + unifiedRateLimitStatusAllowedWarning unifiedRateLimitStatus = "allowed_warning" + unifiedRateLimitStatusRejected unifiedRateLimitStatus = "rejected" +) + +type unifiedRateLimitInfo struct { + Status unifiedRateLimitStatus + ResetAt time.Time + RepresentativeClaim string + FallbackAvailable bool + OverageStatus string + OverageResetAt time.Time + OverageDisabledReason string +} + +func (s unifiedRateLimitInfo) normalized() unifiedRateLimitInfo { + if s.Status == "" { + s.Status = unifiedRateLimitStatusAllowed + } + return s +} + +func claudeWindowProgress(resetAt time.Time, windowSeconds float64, now time.Time) float64 { + if resetAt.IsZero() || windowSeconds <= 0 { + return 0 + } + windowStart := resetAt.Add(-time.Duration(windowSeconds * float64(time.Second))) + if now.Before(windowStart) { + return 0 + } + progress := now.Sub(windowStart).Seconds() / windowSeconds + if progress < 0 { + return 0 + } + if progress > 1 { + return 1 + } + return progress +} + +func claudeFiveHourWarning(utilizationPercent float64, resetAt time.Time, now time.Time) bool { + return utilizationPercent >= 90 && claudeWindowProgress(resetAt, 5*60*60, now) >= 0.72 +} + +func claudeWeeklyWarning(utilizationPercent float64, resetAt time.Time, now time.Time) bool { + progress := claudeWindowProgress(resetAt, 7*24*60*60, now) + switch { + case utilizationPercent >= 75: + return progress >= 0.60 + case utilizationPercent >= 50: + return progress >= 0.35 + case utilizationPercent >= 25: + return progress >= 0.15 + default: + return false + } +} diff --git a/service/ccm/service.go b/service/ccm/service.go index d3f76381f0..a95e3175a2 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -105,6 +105,10 @@ func writeCredentialUnavailableError( writeRetryableUsageError(w, r) return } + if provider != nil && strings.HasPrefix(allCredentialsUnavailableError(provider.allCredentials()).Error(), "all credentials rate-limited") { + writeRetryableUsageError(w, r) + return + } writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback)) } diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index 41256b7f8b..11ae3fd3ad 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "net/http" + "reflect" "strconv" "strings" "time" @@ -12,11 +13,19 @@ import ( ) type statusPayload struct { - FiveHourUtilization float64 `json:"five_hour_utilization"` - FiveHourReset int64 `json:"five_hour_reset"` - WeeklyUtilization float64 `json:"weekly_utilization"` - WeeklyReset int64 `json:"weekly_reset"` - PlanWeight float64 `json:"plan_weight"` + FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` + WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` + PlanWeight float64 `json:"plan_weight"` + UnifiedStatus string `json:"unified_status,omitempty"` + UnifiedReset int64 `json:"unified_reset,omitempty"` + RepresentativeClaim string `json:"representative_claim,omitempty"` + FallbackAvailable bool `json:"fallback_available,omitempty"` + OverageStatus string `json:"overage_status,omitempty"` + OverageReset int64 `json:"overage_reset,omitempty"` + OverageDisabledReason string `json:"overage_disabled_reason,omitempty"` + Availability *availabilityPayload `json:"availability,omitempty"` } type aggregatedStatus struct { @@ -25,6 +34,8 @@ type aggregatedStatus struct { totalWeight float64 fiveHourReset time.Time weeklyReset time.Time + unifiedRateLimit unifiedRateLimitInfo + availability availabilityStatus } func resetToEpoch(t time.Time) int64 { @@ -35,23 +46,176 @@ func resetToEpoch(t time.Time) int64 { } func (s aggregatedStatus) equal(other aggregatedStatus) bool { - return s.fiveHourUtilization == other.fiveHourUtilization && - s.weeklyUtilization == other.weeklyUtilization && - s.totalWeight == other.totalWeight && - resetToEpoch(s.fiveHourReset) == resetToEpoch(other.fiveHourReset) && - resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset) + return reflect.DeepEqual(s.toPayload(), other.toPayload()) } func (s aggregatedStatus) toPayload() statusPayload { + unified := s.unifiedRateLimit.normalized() return statusPayload{ - FiveHourUtilization: s.fiveHourUtilization, - FiveHourReset: resetToEpoch(s.fiveHourReset), - WeeklyUtilization: s.weeklyUtilization, - WeeklyReset: resetToEpoch(s.weeklyReset), - PlanWeight: s.totalWeight, + FiveHourUtilization: s.fiveHourUtilization, + FiveHourReset: resetToEpoch(s.fiveHourReset), + WeeklyUtilization: s.weeklyUtilization, + WeeklyReset: resetToEpoch(s.weeklyReset), + PlanWeight: s.totalWeight, + UnifiedStatus: string(unified.Status), + UnifiedReset: resetToEpoch(unified.ResetAt), + RepresentativeClaim: unified.RepresentativeClaim, + FallbackAvailable: unified.FallbackAvailable, + OverageStatus: unified.OverageStatus, + OverageReset: resetToEpoch(unified.OverageResetAt), + OverageDisabledReason: unified.OverageDisabledReason, + Availability: s.availability.toPayload(), } } +type aggregateInput struct { + availability availabilityStatus + unified unifiedRateLimitInfo +} + +func aggregateAvailability(inputs []aggregateInput) availabilityStatus { + if len(inputs) == 0 { + return availabilityStatus{ + State: availabilityStateUnavailable, + Reason: availabilityReasonNoCredentials, + } + } + var earliestRateLimit time.Time + var hasRateLimited bool + var blocked availabilityStatus + var hasBlocked bool + var hasUnavailable bool + for _, input := range inputs { + availability := input.availability.normalized() + switch availability.State { + case availabilityStateUsable: + return availabilityStatus{State: availabilityStateUsable} + case availabilityStateRateLimited: + hasRateLimited = true + if !availability.ResetAt.IsZero() && (earliestRateLimit.IsZero() || availability.ResetAt.Before(earliestRateLimit)) { + earliestRateLimit = availability.ResetAt + } + if blocked.State == "" { + blocked = availabilityStatus{ + State: availabilityStateRateLimited, + Reason: availabilityReasonHardRateLimit, + ResetAt: earliestRateLimit, + } + } + case availabilityStateTemporarilyBlocked: + if !hasBlocked { + blocked = availability + hasBlocked = true + } + if !availability.ResetAt.IsZero() && (blocked.ResetAt.IsZero() || availability.ResetAt.Before(blocked.ResetAt)) { + blocked.ResetAt = availability.ResetAt + } + case availabilityStateUnavailable: + hasUnavailable = true + } + } + if hasRateLimited { + blocked.ResetAt = earliestRateLimit + return blocked + } + if hasBlocked { + return blocked + } + if hasUnavailable { + return availabilityStatus{ + State: availabilityStateUnavailable, + Reason: availabilityReasonUnknown, + } + } + return availabilityStatus{ + State: availabilityStateUnknown, + Reason: availabilityReasonUnknown, + } +} + +func chooseRepresentativeClaim(status unifiedRateLimitStatus, fiveHourUtilization float64, fiveHourReset time.Time, weeklyUtilization float64, weeklyReset time.Time, now time.Time) string { + type claimCandidate struct { + name string + priority int + utilization float64 + } + candidateFor := func(name string, utilization float64, warning bool) claimCandidate { + priority := 0 + switch { + case status == unifiedRateLimitStatusRejected && utilization >= 100: + priority = 2 + case warning: + priority = 1 + } + return claimCandidate{name: name, priority: priority, utilization: utilization} + } + five := candidateFor("5h", fiveHourUtilization, claudeFiveHourWarning(fiveHourUtilization, fiveHourReset, now)) + weekly := candidateFor("7d", weeklyUtilization, claudeWeeklyWarning(weeklyUtilization, weeklyReset, now)) + switch { + case five.priority > weekly.priority: + return five.name + case weekly.priority > five.priority: + return weekly.name + case five.utilization > weekly.utilization: + return five.name + case weekly.utilization > five.utilization: + return weekly.name + case !fiveHourReset.IsZero(): + return five.name + case !weeklyReset.IsZero(): + return weekly.name + default: + return "5h" + } +} + +func aggregateUnifiedRateLimit(inputs []aggregateInput, fiveHourUtilization float64, fiveHourReset time.Time, weeklyUtilization float64, weeklyReset time.Time, availability availabilityStatus) unifiedRateLimitInfo { + now := time.Now() + info := unifiedRateLimitInfo{} + usableCount := 0 + for _, input := range inputs { + if input.availability.State == availabilityStateUsable { + usableCount++ + } + if input.unified.OverageStatus != "" && info.OverageStatus == "" { + info.OverageStatus = input.unified.OverageStatus + info.OverageResetAt = input.unified.OverageResetAt + info.OverageDisabledReason = input.unified.OverageDisabledReason + } + if input.unified.Status == unifiedRateLimitStatusRejected { + info.Status = unifiedRateLimitStatusRejected + if !input.unified.ResetAt.IsZero() && (info.ResetAt.IsZero() || input.unified.ResetAt.Before(info.ResetAt)) { + info.ResetAt = input.unified.ResetAt + info.RepresentativeClaim = input.unified.RepresentativeClaim + } + } + } + if info.Status == "" { + switch { + case availability.State == availabilityStateRateLimited || fiveHourUtilization >= 100 || weeklyUtilization >= 100: + info.Status = unifiedRateLimitStatusRejected + info.ResetAt = availability.ResetAt + case claudeFiveHourWarning(fiveHourUtilization, fiveHourReset, now) || claudeWeeklyWarning(weeklyUtilization, weeklyReset, now): + info.Status = unifiedRateLimitStatusAllowedWarning + default: + info.Status = unifiedRateLimitStatusAllowed + } + } + info.FallbackAvailable = usableCount > 0 && len(inputs) > 1 + if info.RepresentativeClaim == "" { + info.RepresentativeClaim = chooseRepresentativeClaim(info.Status, fiveHourUtilization, fiveHourReset, weeklyUtilization, weeklyReset, now) + } + if info.ResetAt.IsZero() { + switch info.RepresentativeClaim { + case "7d": + info.ResetAt = weeklyReset + default: + info.ResetAt = fiveHourReset + } + } + return info.normalized() +} + func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") @@ -171,20 +335,27 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) aggregatedStatus { + visibleInputs := make([]aggregateInput, 0, len(provider.allCredentials())) var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 now := time.Now() var totalWeightedHoursUntil5hReset, total5hResetWeight float64 var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64 + var hasSnapshotData bool for _, credential := range provider.allCredentials() { - if !credential.isUsable() { - continue - } if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { continue } if userConfig != nil && !userConfig.AllowExternalUsage && credential.isExternal() { continue } + visibleInputs = append(visibleInputs, aggregateInput{ + availability: credential.availabilityStatus(), + unified: credential.unifiedRateLimitState(), + }) + if !credential.hasSnapshotData() { + continue + } + hasSnapshotData = true weight := credential.planWeight() remaining5h := credential.fiveHourCap() - credential.fiveHourUtilization() if remaining5h < 0 { @@ -215,16 +386,21 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user } } } + availability := aggregateAvailability(visibleInputs) if totalWeight == 0 { - return aggregatedStatus{ - fiveHourUtilization: 100, - weeklyUtilization: 100, + result := aggregatedStatus{availability: availability} + if !hasSnapshotData { + result.fiveHourUtilization = 100 + result.weeklyUtilization = 100 } + result.unifiedRateLimit = aggregateUnifiedRateLimit(visibleInputs, result.fiveHourUtilization, result.fiveHourReset, result.weeklyUtilization, result.weeklyReset, availability) + return result } result := aggregatedStatus{ fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight, weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight, totalWeight: totalWeight, + availability: availability, } if total5hResetWeight > 0 { avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight @@ -234,6 +410,7 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour))) } + result.unifiedRateLimit = aggregateUnifiedRateLimit(visibleInputs, result.fiveHourUtilization, result.fiveHourReset, result.weeklyUtilization, result.weeklyReset, availability) return result } @@ -254,4 +431,45 @@ func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentia if status.totalWeight > 0 { headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) } + headers.Set("anthropic-ratelimit-unified-status", string(status.unifiedRateLimit.normalized().Status)) + if !status.unifiedRateLimit.ResetAt.IsZero() { + headers.Set("anthropic-ratelimit-unified-reset", strconv.FormatInt(status.unifiedRateLimit.ResetAt.Unix(), 10)) + } else { + headers.Del("anthropic-ratelimit-unified-reset") + } + if status.unifiedRateLimit.RepresentativeClaim != "" { + headers.Set("anthropic-ratelimit-unified-representative-claim", status.unifiedRateLimit.RepresentativeClaim) + } else { + headers.Del("anthropic-ratelimit-unified-representative-claim") + } + if status.unifiedRateLimit.FallbackAvailable { + headers.Set("anthropic-ratelimit-unified-fallback", "available") + } else { + headers.Del("anthropic-ratelimit-unified-fallback") + } + if status.unifiedRateLimit.OverageStatus != "" { + headers.Set("anthropic-ratelimit-unified-overage-status", status.unifiedRateLimit.OverageStatus) + } else { + headers.Del("anthropic-ratelimit-unified-overage-status") + } + if !status.unifiedRateLimit.OverageResetAt.IsZero() { + headers.Set("anthropic-ratelimit-unified-overage-reset", strconv.FormatInt(status.unifiedRateLimit.OverageResetAt.Unix(), 10)) + } else { + headers.Del("anthropic-ratelimit-unified-overage-reset") + } + if status.unifiedRateLimit.OverageDisabledReason != "" { + headers.Set("anthropic-ratelimit-unified-overage-disabled-reason", status.unifiedRateLimit.OverageDisabledReason) + } else { + headers.Del("anthropic-ratelimit-unified-overage-disabled-reason") + } + if claudeFiveHourWarning(status.fiveHourUtilization, status.fiveHourReset, time.Now()) || status.fiveHourUtilization >= 100 { + headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true") + } else { + headers.Del("anthropic-ratelimit-unified-5h-surpassed-threshold") + } + if claudeWeeklyWarning(status.weeklyUtilization, status.weeklyReset, time.Now()) || status.weeklyUtilization >= 100 { + headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "true") + } else { + headers.Del("anthropic-ratelimit-unified-7d-surpassed-threshold") + } } diff --git a/service/ccm/service_status_test.go b/service/ccm/service_status_test.go new file mode 100644 index 0000000000..9aef16de39 --- /dev/null +++ b/service/ccm/service_status_test.go @@ -0,0 +1,173 @@ +package ccm + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/sagernet/sing/common/observable" +) + +type testCredential struct { + tag string + external bool + available bool + usable bool + hasData bool + fiveHour float64 + weekly float64 + fiveHourCapV float64 + weeklyCapV float64 + weight float64 + fiveReset time.Time + weeklyReset time.Time + availability availabilityStatus + unified unifiedRateLimitInfo +} + +func (c *testCredential) tagName() string { return c.tag } +func (c *testCredential) isAvailable() bool { return c.available } +func (c *testCredential) isUsable() bool { return c.usable } +func (c *testCredential) isExternal() bool { return c.external } +func (c *testCredential) hasSnapshotData() bool { return c.hasData } +func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour } +func (c *testCredential) weeklyUtilization() float64 { return c.weekly } +func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV } +func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV } +func (c *testCredential) planWeight() float64 { return c.weight } +func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset } +func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset } +func (c *testCredential) markRateLimited(time.Time) {} +func (c *testCredential) markUpstreamRejected() {} +func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability } +func (c *testCredential) unifiedRateLimitState() unifiedRateLimitInfo { return c.unified } +func (c *testCredential) earliestReset() time.Time { return c.fiveReset } +func (c *testCredential) unavailableError() error { return nil } +func (c *testCredential) getAccessToken() (string, error) { return "", nil } +func (c *testCredential) buildProxyRequest(context.Context, *http.Request, []byte, http.Header) (*http.Request, error) { + return nil, nil +} +func (c *testCredential) updateStateFromHeaders(http.Header) {} +func (c *testCredential) wrapRequestContext(context.Context) *credentialRequestContext { return nil } +func (c *testCredential) interruptConnections() {} +func (c *testCredential) setStatusSubscriber(*observable.Subscriber[struct{}]) {} +func (c *testCredential) start() error { return nil } +func (c *testCredential) pollUsage() {} +func (c *testCredential) lastUpdatedTime() time.Time { return time.Now() } +func (c *testCredential) pollBackoff(time.Duration) time.Duration { return 0 } +func (c *testCredential) usageTrackerOrNil() *AggregatedUsage { return nil } +func (c *testCredential) httpClient() *http.Client { return nil } +func (c *testCredential) close() {} + +type testProvider struct { + credentials []Credential +} + +func (p *testProvider) selectCredential(string, credentialSelection) (Credential, bool, error) { + return nil, false, nil +} +func (p *testProvider) onRateLimited(string, Credential, time.Time, credentialSelection) Credential { + return nil +} +func (p *testProvider) linkProviderInterrupt(Credential, credentialSelection, func()) func() bool { + return func() bool { return true } +} +func (p *testProvider) pollIfStale() {} +func (p *testProvider) pollCredentialIfStale(Credential) {} +func (p *testProvider) allCredentials() []Credential { return p.credentials } +func (p *testProvider) close() {} + +func TestComputeAggregatedUtilizationPreservesSnapshotForRateLimitedCredential(t *testing.T) { + t.Parallel() + + reset := time.Now().Add(15 * time.Minute) + service := &Service{} + status := service.computeAggregatedUtilization(&testProvider{credentials: []Credential{ + &testCredential{ + tag: "a", + available: true, + usable: false, + hasData: true, + fiveHour: 42, + weekly: 18, + fiveHourCapV: 100, + weeklyCapV: 100, + weight: 1, + fiveReset: reset, + weeklyReset: reset.Add(2 * time.Hour), + availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: reset}, + unified: unifiedRateLimitInfo{Status: unifiedRateLimitStatusRejected, ResetAt: reset, RepresentativeClaim: "5h"}, + }, + }}, nil) + + if status.fiveHourUtilization != 42 || status.weeklyUtilization != 18 { + t.Fatalf("expected preserved utilization, got 5h=%v weekly=%v", status.fiveHourUtilization, status.weeklyUtilization) + } + if status.unifiedRateLimit.Status != unifiedRateLimitStatusRejected { + t.Fatalf("expected rejected unified status, got %q", status.unifiedRateLimit.Status) + } + if status.availability.State != availabilityStateRateLimited { + t.Fatalf("expected rate-limited availability, got %#v", status.availability) + } +} + +func TestRewriteResponseHeadersIncludesUnifiedHeaders(t *testing.T) { + t.Parallel() + + reset := time.Now().Add(80 * time.Minute) + service := &Service{} + headers := make(http.Header) + service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{ + &testCredential{ + tag: "a", + available: true, + usable: true, + hasData: true, + fiveHour: 92, + weekly: 30, + fiveHourCapV: 100, + weeklyCapV: 100, + weight: 1, + fiveReset: reset, + weeklyReset: time.Now().Add(4 * 24 * time.Hour), + availability: availabilityStatus{State: availabilityStateUsable}, + }, + }}, nil) + + if headers.Get("anthropic-ratelimit-unified-status") != "allowed_warning" { + t.Fatalf("expected allowed_warning, got %q", headers.Get("anthropic-ratelimit-unified-status")) + } + if headers.Get("anthropic-ratelimit-unified-representative-claim") != "5h" { + t.Fatalf("expected 5h representative claim, got %q", headers.Get("anthropic-ratelimit-unified-representative-claim")) + } + if headers.Get("anthropic-ratelimit-unified-5h-surpassed-threshold") != "true" { + t.Fatalf("expected 5h threshold header") + } +} + +func TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/v1/messages", nil) + provider := &testProvider{credentials: []Credential{ + &testCredential{ + tag: "a", + available: true, + usable: false, + hasData: true, + fiveHourCapV: 100, + weeklyCapV: 100, + weight: 1, + availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)}, + }, + }} + + writeCredentialUnavailableError(recorder, request, provider, provider.credentials[0], credentialSelection{}, "all credentials rate-limited") + + if recorder.Code != http.StatusTooManyRequests { + t.Fatalf("expected 429, got %d", recorder.Code) + } +} diff --git a/service/ocm/credential.go b/service/ocm/credential.go index 0c4e56cddd..2e2589366f 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -61,8 +61,14 @@ type credentialState struct { weeklyReset time.Time hardRateLimited bool rateLimitResetAt time.Time + availabilityState availabilityState + availabilityReason availabilityReason + availabilityResetAt time.Time + lastKnownDataAt time.Time accountType string remotePlanWeight float64 + activeLimitID string + rateLimitSnapshots map[string]rateLimitSnapshot lastUpdated time.Time consecutivePollFailures int usageAPIRetryDelay time.Duration @@ -102,6 +108,7 @@ type Credential interface { isAvailable() bool isUsable() bool isExternal() bool + hasSnapshotData() bool fiveHourUtilization() float64 weeklyUtilization() float64 fiveHourCap() float64 @@ -111,6 +118,10 @@ type Credential interface { fiveHourResetTime() time.Time markRateLimited(resetAt time.Time) markUpstreamRejected() + markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) + availabilityStatus() availabilityStatus + rateLimitSnapshots() []rateLimitSnapshot + activeLimitID() string earliestReset() time.Time unavailableError() error @@ -200,3 +211,67 @@ func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time { } return time.Now().Add(5 * time.Minute) } + +func (s *credentialState) noteSnapshotData() { + s.lastKnownDataAt = time.Now() +} + +func (s credentialState) hasSnapshotData() bool { + return !s.lastKnownDataAt.IsZero() || + s.fiveHourUtilization > 0 || + s.weeklyUtilization > 0 || + !s.fiveHourReset.IsZero() || + !s.weeklyReset.IsZero() || + len(s.rateLimitSnapshots) > 0 +} + +func (s *credentialState) setAvailability(state availabilityState, reason availabilityReason, resetAt time.Time) { + s.availabilityState = state + s.availabilityReason = reason + s.availabilityResetAt = resetAt +} + +func (s credentialState) currentAvailability() availabilityStatus { + now := time.Now() + switch { + case s.unavailable: + return availabilityStatus{ + State: availabilityStateUnavailable, + Reason: availabilityReasonUnknown, + } + case s.availabilityState == availabilityStateTemporarilyBlocked && + (s.availabilityResetAt.IsZero() || now.Before(s.availabilityResetAt)): + reason := s.availabilityReason + if reason == "" { + reason = availabilityReasonUnknown + } + return availabilityStatus{ + State: availabilityStateTemporarilyBlocked, + Reason: reason, + ResetAt: s.availabilityResetAt, + } + case s.hardRateLimited && (s.rateLimitResetAt.IsZero() || now.Before(s.rateLimitResetAt)): + reason := s.availabilityReason + if reason == "" { + reason = availabilityReasonHardRateLimit + } + return availabilityStatus{ + State: availabilityStateRateLimited, + Reason: reason, + ResetAt: s.rateLimitResetAt, + } + case !s.upstreamRejectedUntil.IsZero() && now.Before(s.upstreamRejectedUntil): + return availabilityStatus{ + State: availabilityStateTemporarilyBlocked, + Reason: availabilityReasonUpstreamRejected, + ResetAt: s.upstreamRejectedUntil, + } + case s.consecutivePollFailures > 0: + return availabilityStatus{ + State: availabilityStateTemporarilyBlocked, + Reason: availabilityReasonPollFailed, + } + default: + return availabilityStatus{State: availabilityStateUsable} + } +} diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index 61f1314195..977a545e64 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -359,9 +359,14 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { } } } + if snapshots := parseRateLimitSnapshotsFromHeaders(headers); len(snapshots) > 0 { + hadData = true + applyRateLimitSnapshotsLocked(&c.state, snapshots, headers.Get("x-codex-active-limit"), c.state.remotePlanWeight, c.state.accountType) + } if hadData { c.state.consecutivePollFailures = 0 c.state.lastUpdated = time.Now() + c.state.noteSnapshotData() } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" @@ -386,6 +391,7 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) { c.stateAccess.Lock() c.state.hardRateLimited = true c.state.rateLimitResetAt = resetAt + c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -396,6 +402,17 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) { func (c *defaultCredential) markUpstreamRejected() {} +func (c *defaultCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) { + c.stateAccess.Lock() + c.state.setAvailability(availabilityStateTemporarilyBlocked, reason, resetAt) + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + c.emitStatusUpdate() +} + func (c *defaultCredential) isUsable() bool { c.retryCredentialReloadIfNeeded() @@ -483,6 +500,12 @@ func (c *defaultCredential) fiveHourUtilization() float64 { return c.state.fiveHourUtilization } +func (c *defaultCredential) hasSnapshotData() bool { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.hasSnapshotData() +} + func (c *defaultCredential) weeklyUtilization() float64 { c.stateAccess.RLock() defer c.stateAccess.RUnlock() @@ -515,6 +538,32 @@ func (c *defaultCredential) isAvailable() bool { return !c.state.unavailable } +func (c *defaultCredential) availabilityStatus() availabilityStatus { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.currentAvailability() +} + +func (c *defaultCredential) rateLimitSnapshots() []rateLimitSnapshot { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if len(c.state.rateLimitSnapshots) == 0 { + return nil + } + snapshots := make([]rateLimitSnapshot, 0, len(c.state.rateLimitSnapshots)) + for _, snapshot := range c.state.rateLimitSnapshots { + snapshots = append(snapshots, cloneRateLimitSnapshot(snapshot)) + } + sortRateLimitSnapshots(snapshots) + return snapshots +} + +func (c *defaultCredential) activeLimitID() string { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.activeLimitID +} + func (c *defaultCredential) unavailableError() error { c.stateAccess.RLock() defer c.stateAccess.RUnlock() @@ -542,6 +591,7 @@ func (c *defaultCredential) markUsagePollAttempted() { func (c *defaultCredential) incrementPollFailures() { c.stateAccess.Lock() c.state.consecutivePollFailures++ + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{}) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -696,17 +746,7 @@ func (c *defaultCredential) pollUsage() { return } - type usageWindow struct { - UsedPercent float64 `json:"used_percent"` - ResetAt int64 `json:"reset_at"` - } - var usageResponse struct { - PlanType string `json:"plan_type"` - RateLimit *struct { - PrimaryWindow *usageWindow `json:"primary_window"` - SecondaryWindow *usageWindow `json:"secondary_window"` - } `json:"rate_limit"` - } + var usageResponse usageRateLimitStatusPayload err = json.NewDecoder(response.Body).Decode(&usageResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) @@ -720,26 +760,11 @@ func (c *defaultCredential) pollUsage() { oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 c.state.usageAPIRetryDelay = 0 - if usageResponse.RateLimit != nil { - if w := usageResponse.RateLimit.PrimaryWindow; w != nil { - c.state.fiveHourUtilization = w.UsedPercent - if w.ResetAt > 0 { - c.state.fiveHourReset = time.Unix(w.ResetAt, 0) - } - } - if w := usageResponse.RateLimit.SecondaryWindow; w != nil { - c.state.weeklyUtilization = w.UsedPercent - if w.ResetAt > 0 { - c.state.weeklyReset = time.Unix(w.ResetAt, 0) - } - } - } - if usageResponse.PlanType != "" { - c.state.accountType = usageResponse.PlanType - } + applyRateLimitSnapshotsLocked(&c.state, snapshotsFromUsagePayload(usageResponse), c.state.activeLimitID, c.state.remotePlanWeight, usageResponse.PlanType) if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } + c.state.noteSnapshotData() if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index f8a73684e3..7553cdf439 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -367,6 +367,7 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) { c.stateAccess.Lock() c.state.hardRateLimited = true c.state.rateLimitResetAt = resetAt + c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -379,6 +380,18 @@ func (c *externalCredential) markUpstreamRejected() { c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(c.pollInterval)) c.stateAccess.Lock() c.state.upstreamRejectedUntil = time.Now().Add(c.pollInterval) + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonUpstreamRejected, c.state.upstreamRejectedUntil) + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + c.emitStatusUpdate() +} + +func (c *externalCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) { + c.stateAccess.Lock() + c.state.setAvailability(availabilityStateTemporarilyBlocked, reason, resetAt) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -528,10 +541,15 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.remotePlanWeight = value } } + if snapshots := parseRateLimitSnapshotsFromHeaders(headers); len(snapshots) > 0 { + hadData = true + applyRateLimitSnapshotsLocked(&c.state, snapshots, headers.Get("x-codex-active-limit"), c.state.remotePlanWeight, c.state.accountType) + } if hadData { c.state.consecutivePollFailures = 0 c.state.upstreamRejectedUntil = time.Time{} c.state.lastUpdated = time.Now() + c.state.noteSnapshotData() } if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" @@ -670,20 +688,14 @@ func (c *externalCredential) pollUsage() { c.clearPollFailures() return } - if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || + if rawFields["limits"] == nil && (rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil || - rawFields["plan_weight"] == nil { + rawFields["plan_weight"] == nil) { c.logger.Error("poll usage for ", c.tag, ": invalid response") c.clearPollFailures() return } - var statusResponse struct { - FiveHourUtilization float64 `json:"five_hour_utilization"` - FiveHourReset int64 `json:"five_hour_reset"` - WeeklyUtilization float64 `json:"weekly_utilization"` - WeeklyReset int64 `json:"weekly_reset"` - PlanWeight float64 `json:"plan_weight"` - } + var statusResponse statusPayload err = json.Unmarshal(body, &statusResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) @@ -697,16 +709,38 @@ func (c *externalCredential) pollUsage() { oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 c.state.upstreamRejectedUntil = time.Time{} - c.state.fiveHourUtilization = statusResponse.FiveHourUtilization - c.state.weeklyUtilization = statusResponse.WeeklyUtilization - if statusResponse.FiveHourReset > 0 { - c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) - } - if statusResponse.WeeklyReset > 0 { - c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + if len(statusResponse.Limits) > 0 { + applyRateLimitSnapshotsLocked(&c.state, statusResponse.Limits, statusResponse.ActiveLimit, statusResponse.PlanWeight, c.state.accountType) + } else { + c.state.fiveHourUtilization = statusResponse.FiveHourUtilization + c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) + } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } + if statusResponse.PlanWeight > 0 { + c.state.remotePlanWeight = statusResponse.PlanWeight + } } - if statusResponse.PlanWeight > 0 { - c.state.remotePlanWeight = statusResponse.PlanWeight + if statusResponse.Availability != nil { + switch availabilityState(statusResponse.Availability.State) { + case availabilityStateRateLimited: + c.state.hardRateLimited = true + if statusResponse.Availability.ResetAt > 0 { + c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0) + } + case availabilityStateTemporarilyBlocked: + resetAt := time.Time{} + if statusResponse.Availability.ResetAt > 0 { + resetAt = time.Unix(statusResponse.Availability.ResetAt, 0) + } + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt) + if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() { + c.state.upstreamRejectedUntil = resetAt + } + } } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false @@ -787,9 +821,9 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr result.duration = time.Since(startTime) return result, E.Cause(err, "decode status frame") } - if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || + if rawFields["limits"] == nil && (rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil || - rawFields["plan_weight"] == nil { + rawFields["plan_weight"] == nil) { result.duration = time.Since(startTime) return result, E.New("invalid response") } @@ -806,16 +840,38 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 c.state.upstreamRejectedUntil = time.Time{} - c.state.fiveHourUtilization = statusResponse.FiveHourUtilization - c.state.weeklyUtilization = statusResponse.WeeklyUtilization - if statusResponse.FiveHourReset > 0 { - c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) - } - if statusResponse.WeeklyReset > 0 { - c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + if len(statusResponse.Limits) > 0 { + applyRateLimitSnapshotsLocked(&c.state, statusResponse.Limits, statusResponse.ActiveLimit, statusResponse.PlanWeight, c.state.accountType) + } else { + c.state.fiveHourUtilization = statusResponse.FiveHourUtilization + c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) + } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } + if statusResponse.PlanWeight > 0 { + c.state.remotePlanWeight = statusResponse.PlanWeight + } } - if statusResponse.PlanWeight > 0 { - c.state.remotePlanWeight = statusResponse.PlanWeight + if statusResponse.Availability != nil { + switch availabilityState(statusResponse.Availability.State) { + case availabilityStateRateLimited: + c.state.hardRateLimited = true + if statusResponse.Availability.ResetAt > 0 { + c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0) + } + case availabilityStateTemporarilyBlocked: + resetAt := time.Time{} + if statusResponse.Availability.ResetAt > 0 { + resetAt = time.Unix(statusResponse.Availability.ResetAt, 0) + } + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt) + if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() { + c.state.upstreamRejectedUntil = resetAt + } + } } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false @@ -888,6 +944,38 @@ func (c *externalCredential) lastUpdatedTime() time.Time { return c.state.lastUpdated } +func (c *externalCredential) hasSnapshotData() bool { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.hasSnapshotData() +} + +func (c *externalCredential) availabilityStatus() availabilityStatus { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.currentAvailability() +} + +func (c *externalCredential) rateLimitSnapshots() []rateLimitSnapshot { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if len(c.state.rateLimitSnapshots) == 0 { + return nil + } + snapshots := make([]rateLimitSnapshot, 0, len(c.state.rateLimitSnapshots)) + for _, snapshot := range c.state.rateLimitSnapshots { + snapshots = append(snapshots, cloneRateLimitSnapshot(snapshot)) + } + sortRateLimitSnapshots(snapshots) + return snapshots +} + +func (c *externalCredential) activeLimitID() string { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.activeLimitID +} + func (c *externalCredential) markUsageStreamUpdated() { c.stateAccess.Lock() defer c.stateAccess.Unlock() diff --git a/service/ocm/rate_limit_state.go b/service/ocm/rate_limit_state.go new file mode 100644 index 0000000000..f0e4f34b1b --- /dev/null +++ b/service/ocm/rate_limit_state.go @@ -0,0 +1,384 @@ +package ocm + +import ( + "net/http" + "slices" + "strconv" + "strings" + "time" +) + +type availabilityState string + +const ( + availabilityStateUsable availabilityState = "usable" + availabilityStateRateLimited availabilityState = "rate_limited" + availabilityStateTemporarilyBlocked availabilityState = "temporarily_blocked" + availabilityStateUnavailable availabilityState = "unavailable" + availabilityStateUnknown availabilityState = "unknown" +) + +type availabilityReason string + +const ( + availabilityReasonHardRateLimit availabilityReason = "hard_rate_limit" + availabilityReasonConnectionLimit availabilityReason = "connection_limit" + availabilityReasonPollFailed availabilityReason = "poll_failed" + availabilityReasonUpstreamRejected availabilityReason = "upstream_rejected" + availabilityReasonNoCredentials availabilityReason = "no_credentials" + availabilityReasonUnknown availabilityReason = "unknown" +) + +type availabilityStatus struct { + State availabilityState + Reason availabilityReason + ResetAt time.Time +} + +type availabilityPayload struct { + State string `json:"state"` + Reason string `json:"reason,omitempty"` + ResetAt int64 `json:"reset_at,omitempty"` +} + +func (s availabilityStatus) normalized() availabilityStatus { + if s.State == "" { + s.State = availabilityStateUnknown + } + if s.Reason == "" && s.State != availabilityStateUsable { + s.Reason = availabilityReasonUnknown + } + return s +} + +func (s availabilityStatus) toPayload() *availabilityPayload { + s = s.normalized() + payload := &availabilityPayload{ + State: string(s.State), + } + if s.Reason != "" && s.Reason != availabilityReasonUnknown { + payload.Reason = string(s.Reason) + } + if !s.ResetAt.IsZero() { + payload.ResetAt = s.ResetAt.Unix() + } + return payload +} + +type creditsSnapshot struct { + HasCredits bool `json:"has_credits"` + Unlimited bool `json:"unlimited"` + Balance string `json:"balance,omitempty"` +} + +type rateLimitWindow struct { + UsedPercent float64 `json:"used_percent"` + WindowMinutes int64 `json:"window_minutes,omitempty"` + ResetAt int64 `json:"reset_at,omitempty"` +} + +type rateLimitSnapshot struct { + LimitID string `json:"limit_id,omitempty"` + LimitName string `json:"limit_name,omitempty"` + Primary *rateLimitWindow `json:"primary,omitempty"` + Secondary *rateLimitWindow `json:"secondary,omitempty"` + Credits *creditsSnapshot `json:"credits,omitempty"` + PlanType string `json:"plan_type,omitempty"` +} + +func normalizeStoredLimitID(limitID string) string { + normalized := normalizeRateLimitIdentifier(limitID) + if normalized == "" { + return "" + } + return strings.ReplaceAll(normalized, "-", "_") +} + +func headerLimitID(limitID string) string { + if limitID == "" { + return "codex" + } + return strings.ReplaceAll(normalizeStoredLimitID(limitID), "_", "-") +} + +func defaultRateLimitSnapshot(limitID string) rateLimitSnapshot { + if limitID == "" { + limitID = "codex" + } + return rateLimitSnapshot{LimitID: normalizeStoredLimitID(limitID)} +} + +func cloneCreditsSnapshot(snapshot *creditsSnapshot) *creditsSnapshot { + if snapshot == nil { + return nil + } + cloned := *snapshot + return &cloned +} + +func cloneRateLimitWindow(window *rateLimitWindow) *rateLimitWindow { + if window == nil { + return nil + } + cloned := *window + return &cloned +} + +func cloneRateLimitSnapshot(snapshot rateLimitSnapshot) rateLimitSnapshot { + snapshot.Primary = cloneRateLimitWindow(snapshot.Primary) + snapshot.Secondary = cloneRateLimitWindow(snapshot.Secondary) + snapshot.Credits = cloneCreditsSnapshot(snapshot.Credits) + return snapshot +} + +func sortRateLimitSnapshots(snapshots []rateLimitSnapshot) { + slices.SortFunc(snapshots, func(a, b rateLimitSnapshot) int { + return strings.Compare(a.LimitID, b.LimitID) + }) +} + +func parseHeaderFloat(headers http.Header, name string) (float64, bool) { + value := strings.TrimSpace(headers.Get(name)) + if value == "" { + return 0, false + } + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return 0, false + } + if !isFinite(parsed) { + return 0, false + } + return parsed, true +} + +func isFinite(value float64) bool { + return !((value != value) || value > 1e308 || value < -1e308) +} + +func parseCreditsSnapshotFromHeaders(headers http.Header) *creditsSnapshot { + hasCreditsValue := strings.TrimSpace(headers.Get("x-codex-credits-has-credits")) + unlimitedValue := strings.TrimSpace(headers.Get("x-codex-credits-unlimited")) + if hasCreditsValue == "" || unlimitedValue == "" { + return nil + } + hasCredits := strings.EqualFold(hasCreditsValue, "true") || hasCreditsValue == "1" + unlimited := strings.EqualFold(unlimitedValue, "true") || unlimitedValue == "1" + return &creditsSnapshot{ + HasCredits: hasCredits, + Unlimited: unlimited, + Balance: strings.TrimSpace(headers.Get("x-codex-credits-balance")), + } +} + +func parseRateLimitWindowFromHeaders(headers http.Header, prefix string, windowName string) *rateLimitWindow { + usedPercent, hasPercent := parseHeaderFloat(headers, prefix+"-"+windowName+"-used-percent") + windowMinutes, hasWindow := parseInt64Header(headers, prefix+"-"+windowName+"-window-minutes") + resetAt, hasReset := parseInt64Header(headers, prefix+"-"+windowName+"-reset-at") + if !hasPercent && !hasWindow && !hasReset { + return nil + } + window := &rateLimitWindow{} + if hasPercent { + window.UsedPercent = usedPercent + } + if hasWindow { + window.WindowMinutes = windowMinutes + } + if hasReset { + window.ResetAt = resetAt + } + return window +} + +func parseRateLimitSnapshotsFromHeaders(headers http.Header) []rateLimitSnapshot { + limitIDs := map[string]struct{}{} + for key := range headers { + lowerKey := strings.ToLower(key) + if strings.HasPrefix(lowerKey, "x-") && strings.Contains(lowerKey, "-primary-") { + limitID := strings.TrimPrefix(lowerKey, "x-") + if suffix := strings.Index(limitID, "-primary-"); suffix > 0 { + limitIDs[normalizeStoredLimitID(limitID[:suffix])] = struct{}{} + } + } + if strings.HasPrefix(lowerKey, "x-") && strings.Contains(lowerKey, "-secondary-") { + limitID := strings.TrimPrefix(lowerKey, "x-") + if suffix := strings.Index(limitID, "-secondary-"); suffix > 0 { + limitIDs[normalizeStoredLimitID(limitID[:suffix])] = struct{}{} + } + } + } + if activeLimit := normalizeStoredLimitID(headers.Get("x-codex-active-limit")); activeLimit != "" { + limitIDs[activeLimit] = struct{}{} + } + if credits := parseCreditsSnapshotFromHeaders(headers); credits != nil { + _ = credits + limitIDs["codex"] = struct{}{} + } + if len(limitIDs) == 0 { + return nil + } + snapshots := make([]rateLimitSnapshot, 0, len(limitIDs)) + for limitID := range limitIDs { + prefix := "x-" + headerLimitID(limitID) + snapshot := defaultRateLimitSnapshot(limitID) + snapshot.LimitName = strings.TrimSpace(headers.Get(prefix + "-limit-name")) + snapshot.Primary = parseRateLimitWindowFromHeaders(headers, prefix, "primary") + snapshot.Secondary = parseRateLimitWindowFromHeaders(headers, prefix, "secondary") + if limitID == "codex" { + snapshot.Credits = parseCreditsSnapshotFromHeaders(headers) + } + if snapshot.Primary == nil && snapshot.Secondary == nil && snapshot.Credits == nil { + continue + } + snapshots = append(snapshots, snapshot) + } + sortRateLimitSnapshots(snapshots) + return snapshots +} + +type usageRateLimitWindowPayload struct { + UsedPercent float64 `json:"used_percent"` + LimitWindowSeconds int64 `json:"limit_window_seconds"` + ResetAt int64 `json:"reset_at"` +} + +type usageRateLimitDetailsPayload struct { + PrimaryWindow *usageRateLimitWindowPayload `json:"primary_window"` + SecondaryWindow *usageRateLimitWindowPayload `json:"secondary_window"` +} + +type usageCreditsPayload struct { + HasCredits bool `json:"has_credits"` + Unlimited bool `json:"unlimited"` + Balance *string `json:"balance"` +} + +type additionalRateLimitPayload struct { + LimitName string `json:"limit_name"` + MeteredFeature string `json:"metered_feature"` + RateLimit *usageRateLimitDetailsPayload `json:"rate_limit"` +} + +type usageRateLimitStatusPayload struct { + PlanType string `json:"plan_type"` + RateLimit *usageRateLimitDetailsPayload `json:"rate_limit"` + Credits *usageCreditsPayload `json:"credits"` + AdditionalRateLimits []additionalRateLimitPayload `json:"additional_rate_limits"` +} + +func windowFromUsagePayload(window *usageRateLimitWindowPayload) *rateLimitWindow { + if window == nil { + return nil + } + result := &rateLimitWindow{ + UsedPercent: window.UsedPercent, + } + if window.LimitWindowSeconds > 0 { + result.WindowMinutes = (window.LimitWindowSeconds + 59) / 60 + } + if window.ResetAt > 0 { + result.ResetAt = window.ResetAt + } + return result +} + +func snapshotsFromUsagePayload(payload usageRateLimitStatusPayload) []rateLimitSnapshot { + snapshots := make([]rateLimitSnapshot, 0, 1+len(payload.AdditionalRateLimits)) + codex := defaultRateLimitSnapshot("codex") + codex.PlanType = payload.PlanType + if payload.RateLimit != nil { + codex.Primary = windowFromUsagePayload(payload.RateLimit.PrimaryWindow) + codex.Secondary = windowFromUsagePayload(payload.RateLimit.SecondaryWindow) + } + if payload.Credits != nil { + codex.Credits = &creditsSnapshot{ + HasCredits: payload.Credits.HasCredits, + Unlimited: payload.Credits.Unlimited, + } + if payload.Credits.Balance != nil { + codex.Credits.Balance = *payload.Credits.Balance + } + } + if codex.Primary != nil || codex.Secondary != nil || codex.Credits != nil || codex.PlanType != "" { + snapshots = append(snapshots, codex) + } + for _, additional := range payload.AdditionalRateLimits { + snapshot := defaultRateLimitSnapshot(additional.MeteredFeature) + snapshot.LimitName = additional.LimitName + snapshot.PlanType = payload.PlanType + if additional.RateLimit != nil { + snapshot.Primary = windowFromUsagePayload(additional.RateLimit.PrimaryWindow) + snapshot.Secondary = windowFromUsagePayload(additional.RateLimit.SecondaryWindow) + } + if snapshot.Primary == nil && snapshot.Secondary == nil { + continue + } + snapshots = append(snapshots, snapshot) + } + sortRateLimitSnapshots(snapshots) + return snapshots +} + +func applyRateLimitSnapshotsLocked(state *credentialState, snapshots []rateLimitSnapshot, activeLimitID string, planWeight float64, planType string) { + if len(snapshots) == 0 { + return + } + if state.rateLimitSnapshots == nil { + state.rateLimitSnapshots = make(map[string]rateLimitSnapshot, len(snapshots)) + } else { + clear(state.rateLimitSnapshots) + } + for _, snapshot := range snapshots { + snapshot = cloneRateLimitSnapshot(snapshot) + if snapshot.LimitID == "" { + snapshot.LimitID = "codex" + } + if snapshot.LimitName == "" && snapshot.LimitID != "codex" { + snapshot.LimitName = strings.ReplaceAll(snapshot.LimitID, "_", "-") + } + if snapshot.PlanType == "" { + snapshot.PlanType = planType + } + state.rateLimitSnapshots[snapshot.LimitID] = snapshot + } + if planWeight > 0 { + state.remotePlanWeight = planWeight + } + if planType != "" { + state.accountType = planType + } + if normalizedActive := normalizeStoredLimitID(activeLimitID); normalizedActive != "" { + state.activeLimitID = normalizedActive + } else if state.activeLimitID == "" { + if _, exists := state.rateLimitSnapshots["codex"]; exists { + state.activeLimitID = "codex" + } else { + for limitID := range state.rateLimitSnapshots { + state.activeLimitID = limitID + break + } + } + } + legacy := state.rateLimitSnapshots["codex"] + if legacy.LimitID == "" && state.activeLimitID != "" { + legacy = state.rateLimitSnapshots[state.activeLimitID] + } + state.fiveHourUtilization = 0 + state.fiveHourReset = time.Time{} + state.weeklyUtilization = 0 + state.weeklyReset = time.Time{} + if legacy.Primary != nil { + state.fiveHourUtilization = legacy.Primary.UsedPercent + if legacy.Primary.ResetAt > 0 { + state.fiveHourReset = time.Unix(legacy.Primary.ResetAt, 0) + } + } + if legacy.Secondary != nil { + state.weeklyUtilization = legacy.Secondary.UsedPercent + if legacy.Secondary.ResetAt > 0 { + state.weeklyReset = time.Unix(legacy.Secondary.ResetAt, 0) + } + } + state.noteSnapshotData() +} diff --git a/service/ocm/service.go b/service/ocm/service.go index 7c54115a3b..76a63f94bd 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -65,7 +65,6 @@ func writePlainTextError(w http.ResponseWriter, statusCode int, message string) const ( retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential" - retryableUsageCode = "credential_usage_exhausted" ) func hasAlternativeCredential(provider credentialProvider, currentCredential Credential, selection credentialSelection) bool { @@ -98,7 +97,7 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string) } func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) { - writeJSONErrorWithCode(w, r, http.StatusServiceUnavailable, "server_error", retryableUsageCode, retryableUsageMessage) + writeJSONErrorWithCode(w, r, http.StatusTooManyRequests, "usage_limit_reached", "", retryableUsageMessage) } func writeNonRetryableCredentialError(w http.ResponseWriter, message string) { @@ -117,6 +116,10 @@ func writeCredentialUnavailableError( writeRetryableUsageError(w, r) return } + if provider != nil && strings.HasPrefix(allRateLimitedError(provider.allCredentials()).Error(), "all credentials rate-limited") { + writeRetryableUsageError(w, r) + return + } writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback)) } diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index 092209e981..92959e5e84 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/json" "net/http" + "reflect" + "slices" "strconv" "strings" "time" @@ -12,11 +14,14 @@ import ( ) type statusPayload struct { - FiveHourUtilization float64 `json:"five_hour_utilization"` - FiveHourReset int64 `json:"five_hour_reset"` - WeeklyUtilization float64 `json:"weekly_utilization"` - WeeklyReset int64 `json:"weekly_reset"` - PlanWeight float64 `json:"plan_weight"` + FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` + WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` + PlanWeight float64 `json:"plan_weight"` + ActiveLimit string `json:"active_limit,omitempty"` + Limits []rateLimitSnapshot `json:"limits,omitempty"` + Availability *availabilityPayload `json:"availability,omitempty"` } type aggregatedStatus struct { @@ -25,6 +30,9 @@ type aggregatedStatus struct { totalWeight float64 fiveHourReset time.Time weeklyReset time.Time + activeLimitID string + limits []rateLimitSnapshot + availability availabilityStatus } func resetToEpoch(t time.Time) int64 { @@ -35,11 +43,7 @@ func resetToEpoch(t time.Time) int64 { } func (s aggregatedStatus) equal(other aggregatedStatus) bool { - return s.fiveHourUtilization == other.fiveHourUtilization && - s.weeklyUtilization == other.weeklyUtilization && - s.totalWeight == other.totalWeight && - resetToEpoch(s.fiveHourReset) == resetToEpoch(other.fiveHourReset) && - resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset) + return reflect.DeepEqual(s.toPayload(), other.toPayload()) } func (s aggregatedStatus) toPayload() statusPayload { @@ -49,9 +53,253 @@ func (s aggregatedStatus) toPayload() statusPayload { WeeklyUtilization: s.weeklyUtilization, WeeklyReset: resetToEpoch(s.weeklyReset), PlanWeight: s.totalWeight, + ActiveLimit: s.activeLimitID, + Limits: slices.Clone(s.limits), + Availability: s.availability.toPayload(), } } +type aggregateInput struct { + weight float64 + snapshots []rateLimitSnapshot + activeLimit string + availability availabilityStatus +} + +type snapshotContribution struct { + weight float64 + snapshot rateLimitSnapshot +} + +func aggregateAvailability(inputs []aggregateInput) availabilityStatus { + if len(inputs) == 0 { + return availabilityStatus{ + State: availabilityStateUnavailable, + Reason: availabilityReasonNoCredentials, + } + } + var earliestRateLimited time.Time + var hasRateLimited bool + var bestBlocked availabilityStatus + var hasBlocked bool + var hasUnavailable bool + blockedPriority := func(reason availabilityReason) int { + switch reason { + case availabilityReasonConnectionLimit: + return 3 + case availabilityReasonPollFailed: + return 2 + case availabilityReasonUpstreamRejected: + return 1 + default: + return 0 + } + } + for _, input := range inputs { + availability := input.availability.normalized() + switch availability.State { + case availabilityStateUsable: + return availabilityStatus{State: availabilityStateUsable} + case availabilityStateRateLimited: + hasRateLimited = true + if !availability.ResetAt.IsZero() && (earliestRateLimited.IsZero() || availability.ResetAt.Before(earliestRateLimited)) { + earliestRateLimited = availability.ResetAt + } + case availabilityStateTemporarilyBlocked: + if !hasBlocked || blockedPriority(availability.Reason) > blockedPriority(bestBlocked.Reason) { + bestBlocked = availability + hasBlocked = true + } + if hasBlocked && !availability.ResetAt.IsZero() && (bestBlocked.ResetAt.IsZero() || availability.ResetAt.Before(bestBlocked.ResetAt)) { + bestBlocked.ResetAt = availability.ResetAt + } + case availabilityStateUnavailable: + hasUnavailable = true + } + } + if hasRateLimited { + return availabilityStatus{ + State: availabilityStateRateLimited, + Reason: availabilityReasonHardRateLimit, + ResetAt: earliestRateLimited, + } + } + if hasBlocked { + return bestBlocked + } + if hasUnavailable { + return availabilityStatus{ + State: availabilityStateUnavailable, + Reason: availabilityReasonUnknown, + } + } + return availabilityStatus{ + State: availabilityStateUnknown, + Reason: availabilityReasonUnknown, + } +} + +func aggregateRateLimitWindow(contributions []snapshotContribution, selector func(rateLimitSnapshot) *rateLimitWindow) *rateLimitWindow { + var totalWeight float64 + var totalRemaining float64 + var totalWindowMinutes float64 + var totalResetHours float64 + var resetWeight float64 + now := time.Now() + for _, contribution := range contributions { + window := selector(contribution.snapshot) + if window == nil { + continue + } + totalWeight += contribution.weight + totalRemaining += (100 - window.UsedPercent) * contribution.weight + if window.WindowMinutes > 0 { + totalWindowMinutes += float64(window.WindowMinutes) * contribution.weight + } + if window.ResetAt > 0 { + resetTime := time.Unix(window.ResetAt, 0) + hours := resetTime.Sub(now).Hours() + if hours > 0 { + totalResetHours += hours * contribution.weight + resetWeight += contribution.weight + } + } + } + if totalWeight == 0 { + return nil + } + window := &rateLimitWindow{ + UsedPercent: 100 - totalRemaining/totalWeight, + } + if totalWindowMinutes > 0 { + window.WindowMinutes = int64(totalWindowMinutes / totalWeight) + } + if resetWeight > 0 { + window.ResetAt = now.Add(time.Duration(totalResetHours / resetWeight * float64(time.Hour))).Unix() + } + return window +} + +func aggregateCredits(contributions []snapshotContribution) *creditsSnapshot { + var hasCredits bool + var unlimited bool + var balanceTotal float64 + var hasBalance bool + for _, contribution := range contributions { + if contribution.snapshot.Credits == nil { + continue + } + hasCredits = hasCredits || contribution.snapshot.Credits.HasCredits + unlimited = unlimited || contribution.snapshot.Credits.Unlimited + if balance := strings.TrimSpace(contribution.snapshot.Credits.Balance); balance != "" { + value, err := strconv.ParseFloat(balance, 64) + if err == nil { + balanceTotal += value + hasBalance = true + } + } + } + if !hasCredits && !unlimited && !hasBalance { + return nil + } + credits := &creditsSnapshot{ + HasCredits: hasCredits, + Unlimited: unlimited, + } + if hasBalance && !unlimited { + credits.Balance = strconv.FormatFloat(balanceTotal, 'f', -1, 64) + } + return credits +} + +func aggregateSnapshots(inputs []aggregateInput) []rateLimitSnapshot { + grouped := make(map[string][]snapshotContribution) + for _, input := range inputs { + for _, snapshot := range input.snapshots { + limitID := snapshot.LimitID + if limitID == "" { + limitID = "codex" + } + grouped[limitID] = append(grouped[limitID], snapshotContribution{ + weight: input.weight, + snapshot: snapshot, + }) + } + } + if len(grouped) == 0 { + return nil + } + aggregated := make([]rateLimitSnapshot, 0, len(grouped)) + for limitID, contributions := range grouped { + snapshot := defaultRateLimitSnapshot(limitID) + var bestPlanWeight float64 + for _, contribution := range contributions { + if contribution.snapshot.LimitName != "" && snapshot.LimitName == "" { + snapshot.LimitName = contribution.snapshot.LimitName + } + if contribution.snapshot.PlanType != "" && contribution.weight >= bestPlanWeight { + bestPlanWeight = contribution.weight + snapshot.PlanType = contribution.snapshot.PlanType + } + } + snapshot.Primary = aggregateRateLimitWindow(contributions, func(snapshot rateLimitSnapshot) *rateLimitWindow { + return snapshot.Primary + }) + snapshot.Secondary = aggregateRateLimitWindow(contributions, func(snapshot rateLimitSnapshot) *rateLimitWindow { + return snapshot.Secondary + }) + snapshot.Credits = aggregateCredits(contributions) + if snapshot.Primary == nil && snapshot.Secondary == nil && snapshot.Credits == nil { + continue + } + aggregated = append(aggregated, snapshot) + } + sortRateLimitSnapshots(aggregated) + return aggregated +} + +func selectActiveLimitID(inputs []aggregateInput, snapshots []rateLimitSnapshot) string { + if len(snapshots) == 0 { + return "" + } + weights := make(map[string]float64) + for _, input := range inputs { + if input.activeLimit == "" { + continue + } + weights[normalizeStoredLimitID(input.activeLimit)] += input.weight + } + var ( + bestID string + bestWeight float64 + ) + for limitID, weight := range weights { + if weight > bestWeight { + bestID = limitID + bestWeight = weight + } + } + if bestID != "" { + return bestID + } + for _, snapshot := range snapshots { + if snapshot.LimitID == "codex" { + return "codex" + } + } + return snapshots[0].LimitID +} + +func findSnapshotByLimitID(snapshots []rateLimitSnapshot, limitID string) *rateLimitSnapshot { + for _, snapshot := range snapshots { + if snapshot.LimitID == limitID { + snapshotCopy := snapshot + return &snapshotCopy + } + } + return nil +} + func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") @@ -171,74 +419,86 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) aggregatedStatus { - var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 - now := time.Now() - var totalWeightedHoursUntil5hReset, total5hResetWeight float64 - var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64 + inputs := make([]aggregateInput, 0, len(provider.allCredentials())) + var totalWeight float64 + var hasSnapshotData bool for _, credential := range provider.allCredentials() { - if !credential.isUsable() { - continue - } if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { continue } if userConfig != nil && !userConfig.AllowExternalUsage && credential.isExternal() { continue } - weight := credential.planWeight() - remaining5h := credential.fiveHourCap() - credential.fiveHourUtilization() - if remaining5h < 0 { - remaining5h = 0 + input := aggregateInput{ + weight: credential.planWeight(), + snapshots: credential.rateLimitSnapshots(), + activeLimit: credential.activeLimitID(), + availability: credential.availabilityStatus(), } - remainingWeekly := credential.weeklyCap() - credential.weeklyUtilization() - if remainingWeekly < 0 { - remainingWeekly = 0 + inputs = append(inputs, input) + if credential.hasSnapshotData() { + hasSnapshotData = true } - totalWeightedRemaining5h += remaining5h * weight - totalWeightedRemainingWeekly += remainingWeekly * weight - totalWeight += weight - - fiveHourReset := credential.fiveHourResetTime() - if !fiveHourReset.IsZero() { - hours := fiveHourReset.Sub(now).Hours() - if hours > 0 { - totalWeightedHoursUntil5hReset += hours * weight - total5hResetWeight += weight + totalWeight += input.weight + } + limits := aggregateSnapshots(inputs) + result := aggregatedStatus{ + totalWeight: totalWeight, + availability: aggregateAvailability(inputs), + limits: limits, + activeLimitID: selectActiveLimitID(inputs, limits), + } + if legacy := findSnapshotByLimitID(result.limits, "codex"); legacy != nil { + if legacy.Primary != nil { + result.fiveHourUtilization = legacy.Primary.UsedPercent + if legacy.Primary.ResetAt > 0 { + result.fiveHourReset = time.Unix(legacy.Primary.ResetAt, 0) } } - weeklyReset := credential.weeklyResetTime() - if !weeklyReset.IsZero() { - hours := weeklyReset.Sub(now).Hours() - if hours > 0 { - totalWeightedHoursUntilWeeklyReset += hours * weight - totalWeeklyResetWeight += weight + if legacy.Secondary != nil { + result.weeklyUtilization = legacy.Secondary.UsedPercent + if legacy.Secondary.ResetAt > 0 { + result.weeklyReset = time.Unix(legacy.Secondary.ResetAt, 0) } } - } - if totalWeight == 0 { - return aggregatedStatus{ - fiveHourUtilization: 100, - weeklyUtilization: 100, + } else if legacy := findSnapshotByLimitID(result.limits, result.activeLimitID); legacy != nil { + if legacy.Primary != nil { + result.fiveHourUtilization = legacy.Primary.UsedPercent + if legacy.Primary.ResetAt > 0 { + result.fiveHourReset = time.Unix(legacy.Primary.ResetAt, 0) + } + } + if legacy.Secondary != nil { + result.weeklyUtilization = legacy.Secondary.UsedPercent + if legacy.Secondary.ResetAt > 0 { + result.weeklyReset = time.Unix(legacy.Secondary.ResetAt, 0) + } } } - result := aggregatedStatus{ - fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight, - weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight, - totalWeight: totalWeight, - } - if total5hResetWeight > 0 { - avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight - result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour))) - } - if totalWeeklyResetWeight > 0 { - avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight - result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour))) + if len(result.limits) == 0 && !hasSnapshotData { + result.fiveHourUtilization = 100 + result.weeklyUtilization = 100 } return result } func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) { status := s.computeAggregatedUtilization(provider, userConfig) + for key := range headers { + lowerKey := strings.ToLower(key) + if lowerKey == "x-codex-active-limit" || + strings.HasSuffix(lowerKey, "-primary-used-percent") || + strings.HasSuffix(lowerKey, "-primary-window-minutes") || + strings.HasSuffix(lowerKey, "-primary-reset-at") || + strings.HasSuffix(lowerKey, "-secondary-used-percent") || + strings.HasSuffix(lowerKey, "-secondary-window-minutes") || + strings.HasSuffix(lowerKey, "-secondary-reset-at") || + strings.HasSuffix(lowerKey, "-limit-name") || + strings.HasPrefix(lowerKey, "x-codex-credits-") { + headers.Del(key) + } + } + headers.Set("x-codex-active-limit", headerLimitID(status.activeLimitID)) headers.Set("x-codex-primary-used-percent", strconv.FormatFloat(status.fiveHourUtilization, 'f', 2, 64)) headers.Set("x-codex-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64)) if !status.fiveHourReset.IsZero() { @@ -254,25 +514,34 @@ func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentia if status.totalWeight > 0 { headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) } - rateLimitSuffixes := [...]string{ - "-primary-used-percent", - "-primary-reset-at", - "-secondary-used-percent", - "-secondary-reset-at", - "-secondary-window-minutes", - "-limit-name", - } - for key := range headers { - lowerKey := strings.ToLower(key) - if !strings.HasPrefix(lowerKey, "x-") { - continue + for _, snapshot := range status.limits { + prefix := "x-" + headerLimitID(snapshot.LimitID) + if snapshot.Primary != nil { + headers.Set(prefix+"-primary-used-percent", strconv.FormatFloat(snapshot.Primary.UsedPercent, 'f', 2, 64)) + if snapshot.Primary.WindowMinutes > 0 { + headers.Set(prefix+"-primary-window-minutes", strconv.FormatInt(snapshot.Primary.WindowMinutes, 10)) + } + if snapshot.Primary.ResetAt > 0 { + headers.Set(prefix+"-primary-reset-at", strconv.FormatInt(snapshot.Primary.ResetAt, 10)) + } } - for _, suffix := range rateLimitSuffixes { - if strings.HasSuffix(lowerKey, suffix) { - if strings.TrimSuffix(lowerKey, suffix) != "x-codex" { - headers.Del(key) - } - break + if snapshot.Secondary != nil { + headers.Set(prefix+"-secondary-used-percent", strconv.FormatFloat(snapshot.Secondary.UsedPercent, 'f', 2, 64)) + if snapshot.Secondary.WindowMinutes > 0 { + headers.Set(prefix+"-secondary-window-minutes", strconv.FormatInt(snapshot.Secondary.WindowMinutes, 10)) + } + if snapshot.Secondary.ResetAt > 0 { + headers.Set(prefix+"-secondary-reset-at", strconv.FormatInt(snapshot.Secondary.ResetAt, 10)) + } + } + if snapshot.LimitName != "" { + headers.Set(prefix+"-limit-name", snapshot.LimitName) + } + if snapshot.LimitID == "codex" && snapshot.Credits != nil { + headers.Set("x-codex-credits-has-credits", strconv.FormatBool(snapshot.Credits.HasCredits)) + headers.Set("x-codex-credits-unlimited", strconv.FormatBool(snapshot.Credits.Unlimited)) + if snapshot.Credits.Balance != "" { + headers.Set("x-codex-credits-balance", snapshot.Credits.Balance) } } } diff --git a/service/ocm/service_status_test.go b/service/ocm/service_status_test.go new file mode 100644 index 0000000000..ba7e9324ad --- /dev/null +++ b/service/ocm/service_status_test.go @@ -0,0 +1,220 @@ +package ocm + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/observable" +) + +type testCredential struct { + tag string + external bool + available bool + usable bool + hasData bool + fiveHour float64 + weekly float64 + fiveHourCapV float64 + weeklyCapV float64 + weight float64 + fiveReset time.Time + weeklyReset time.Time + availability availabilityStatus + activeLimit string + snapshots []rateLimitSnapshot +} + +func (c *testCredential) tagName() string { return c.tag } +func (c *testCredential) isAvailable() bool { return c.available } +func (c *testCredential) isUsable() bool { return c.usable } +func (c *testCredential) isExternal() bool { return c.external } +func (c *testCredential) hasSnapshotData() bool { return c.hasData } +func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour } +func (c *testCredential) weeklyUtilization() float64 { return c.weekly } +func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV } +func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV } +func (c *testCredential) planWeight() float64 { return c.weight } +func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset } +func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset } +func (c *testCredential) markRateLimited(time.Time) {} +func (c *testCredential) markUpstreamRejected() {} +func (c *testCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) { + c.availability = availabilityStatus{State: availabilityStateTemporarilyBlocked, Reason: reason, ResetAt: resetAt} +} +func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability } +func (c *testCredential) rateLimitSnapshots() []rateLimitSnapshot { + return slicesCloneSnapshots(c.snapshots) +} +func (c *testCredential) activeLimitID() string { return c.activeLimit } +func (c *testCredential) earliestReset() time.Time { return c.fiveReset } +func (c *testCredential) unavailableError() error { return nil } +func (c *testCredential) getAccessToken() (string, error) { return "", nil } +func (c *testCredential) buildProxyRequest(context.Context, *http.Request, []byte, http.Header) (*http.Request, error) { + return nil, nil +} +func (c *testCredential) updateStateFromHeaders(http.Header) {} +func (c *testCredential) wrapRequestContext(context.Context) *credentialRequestContext { return nil } +func (c *testCredential) interruptConnections() {} +func (c *testCredential) setOnBecameUnusable(func()) {} +func (c *testCredential) setStatusSubscriber(*observable.Subscriber[struct{}]) {} +func (c *testCredential) start() error { return nil } +func (c *testCredential) pollUsage() {} +func (c *testCredential) lastUpdatedTime() time.Time { return time.Now() } +func (c *testCredential) pollBackoff(time.Duration) time.Duration { return 0 } +func (c *testCredential) usageTrackerOrNil() *AggregatedUsage { return nil } +func (c *testCredential) httpClient() *http.Client { return nil } +func (c *testCredential) close() {} +func (c *testCredential) ocmDialer() N.Dialer { return nil } +func (c *testCredential) ocmIsAPIKeyMode() bool { return false } +func (c *testCredential) ocmGetAccountID() string { return "" } +func (c *testCredential) ocmGetBaseURL() string { return "" } + +func slicesCloneSnapshots(snapshots []rateLimitSnapshot) []rateLimitSnapshot { + if len(snapshots) == 0 { + return nil + } + cloned := make([]rateLimitSnapshot, 0, len(snapshots)) + for _, snapshot := range snapshots { + cloned = append(cloned, cloneRateLimitSnapshot(snapshot)) + } + return cloned +} + +type testProvider struct { + credentials []Credential +} + +func (p *testProvider) selectCredential(string, credentialSelection) (Credential, bool, error) { + return nil, false, nil +} +func (p *testProvider) onRateLimited(string, Credential, time.Time, credentialSelection) Credential { + return nil +} +func (p *testProvider) linkProviderInterrupt(Credential, credentialSelection, func()) func() bool { + return func() bool { return true } +} +func (p *testProvider) pollIfStale() {} +func (p *testProvider) pollCredentialIfStale(Credential) {} +func (p *testProvider) allCredentials() []Credential { return p.credentials } +func (p *testProvider) close() {} + +func TestComputeAggregatedUtilizationPreservesStoredSnapshots(t *testing.T) { + t.Parallel() + + service := &Service{} + status := service.computeAggregatedUtilization(&testProvider{credentials: []Credential{ + &testCredential{ + tag: "a", + available: true, + usable: false, + hasData: true, + weight: 1, + activeLimit: "codex", + availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)}, + snapshots: []rateLimitSnapshot{ + { + LimitID: "codex", + Primary: &rateLimitWindow{UsedPercent: 44, WindowMinutes: 300, ResetAt: time.Now().Add(time.Hour).Unix()}, + Secondary: &rateLimitWindow{UsedPercent: 12, WindowMinutes: 10080, ResetAt: time.Now().Add(24 * time.Hour).Unix()}, + }, + }, + }, + }}, nil) + + if status.fiveHourUtilization != 44 || status.weeklyUtilization != 12 { + t.Fatalf("expected stored snapshot utilization, got 5h=%v weekly=%v", status.fiveHourUtilization, status.weeklyUtilization) + } + if status.availability.State != availabilityStateRateLimited { + t.Fatalf("expected rate-limited availability, got %#v", status.availability) + } +} + +func TestRewriteResponseHeadersIncludesAdditionalLimitFamiliesAndCredits(t *testing.T) { + t.Parallel() + + service := &Service{} + headers := make(http.Header) + service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{ + &testCredential{ + tag: "a", + available: true, + usable: true, + hasData: true, + weight: 1, + activeLimit: "codex_other", + availability: availabilityStatus{State: availabilityStateUsable}, + snapshots: []rateLimitSnapshot{ + { + LimitID: "codex", + Primary: &rateLimitWindow{UsedPercent: 20, WindowMinutes: 300, ResetAt: time.Now().Add(time.Hour).Unix()}, + Secondary: &rateLimitWindow{UsedPercent: 40, WindowMinutes: 10080, ResetAt: time.Now().Add(24 * time.Hour).Unix()}, + Credits: &creditsSnapshot{HasCredits: true, Unlimited: false, Balance: "12"}, + }, + { + LimitID: "codex_other", + LimitName: "codex-other", + Primary: &rateLimitWindow{UsedPercent: 60, WindowMinutes: 60, ResetAt: time.Now().Add(30 * time.Minute).Unix()}, + }, + }, + }, + }}, nil) + + if headers.Get("x-codex-active-limit") != "codex-other" { + t.Fatalf("expected active limit header, got %q", headers.Get("x-codex-active-limit")) + } + if headers.Get("x-codex-other-primary-used-percent") == "" { + t.Fatal("expected additional rate-limit family header") + } + if headers.Get("x-codex-credits-balance") != "12" { + t.Fatalf("expected credits balance header, got %q", headers.Get("x-codex-credits-balance")) + } +} + +func TestHandleWebSocketErrorEventConnectionLimitDoesNotUseRateLimitPath(t *testing.T) { + t.Parallel() + + credential := &testCredential{availability: availabilityStatus{State: availabilityStateUsable}} + service := &Service{} + service.handleWebSocketErrorEvent([]byte(`{"type":"error","status_code":400,"error":{"code":"websocket_connection_limit_reached"}}`), credential) + + if credential.availability.State != availabilityStateTemporarilyBlocked || credential.availability.Reason != availabilityReasonConnectionLimit { + t.Fatalf("expected temporary connection limit block, got %#v", credential.availability) + } +} + +func TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/v1/responses", nil) + provider := &testProvider{credentials: []Credential{ + &testCredential{ + tag: "a", + available: true, + usable: false, + hasData: true, + weight: 1, + availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)}, + snapshots: []rateLimitSnapshot{{LimitID: "codex", Primary: &rateLimitWindow{UsedPercent: 80}}}, + }, + }} + + writeCredentialUnavailableError(recorder, request, provider, provider.credentials[0], credentialSelection{}, "all credentials rate-limited") + + if recorder.Code != http.StatusTooManyRequests { + t.Fatalf("expected 429, got %d", recorder.Code) + } + var body map[string]map[string]string + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatal(err) + } + if body["error"]["type"] != "usage_limit_reached" { + t.Fatalf("expected usage_limit_reached type, got %#v", body) + } +} diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 96d7e58d3a..4c552cbf5b 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -430,9 +430,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe s.handleWebSocketRateLimitsEvent(data, selectedCredential) continue case "error": - if event.StatusCode == http.StatusTooManyRequests { - s.handleWebSocketErrorRateLimited(data, selectedCredential) - } + s.handleWebSocketErrorEvent(data, selectedCredential) case "response.completed": if usageTracker != nil { select { @@ -460,17 +458,22 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential Credential) { var rateLimitsEvent struct { - RateLimits struct { + MeteredLimitName string `json:"metered_limit_name"` + LimitName string `json:"limit_name"` + RateLimits struct { Primary *struct { - UsedPercent float64 `json:"used_percent"` - ResetAt int64 `json:"reset_at"` + UsedPercent float64 `json:"used_percent"` + WindowMinutes int64 `json:"window_minutes"` + ResetAt int64 `json:"reset_at"` } `json:"primary"` Secondary *struct { - UsedPercent float64 `json:"used_percent"` - ResetAt int64 `json:"reset_at"` + UsedPercent float64 `json:"used_percent"` + WindowMinutes int64 `json:"window_minutes"` + ResetAt int64 `json:"reset_at"` } `json:"secondary"` } `json:"rate_limits"` - PlanWeight float64 `json:"plan_weight"` + Credits *creditsSnapshot `json:"credits"` + PlanWeight float64 `json:"plan_weight"` } err := json.Unmarshal(data, &rateLimitsEvent) if err != nil { @@ -478,17 +481,41 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential } headers := make(http.Header) - headers.Set("x-codex-active-limit", "codex") + limitID := rateLimitsEvent.MeteredLimitName + if limitID == "" { + limitID = rateLimitsEvent.LimitName + } + if limitID == "" { + limitID = "codex" + } + headerLimit := headerLimitID(limitID) + headers.Set("x-codex-active-limit", headerLimit) if w := rateLimitsEvent.RateLimits.Primary; w != nil { - headers.Set("x-codex-primary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64)) + headers.Set("x-"+headerLimit+"-primary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64)) + if w.WindowMinutes > 0 { + headers.Set("x-"+headerLimit+"-primary-window-minutes", strconv.FormatInt(w.WindowMinutes, 10)) + } if w.ResetAt > 0 { - headers.Set("x-codex-primary-reset-at", strconv.FormatInt(w.ResetAt, 10)) + headers.Set("x-"+headerLimit+"-primary-reset-at", strconv.FormatInt(w.ResetAt, 10)) } } if w := rateLimitsEvent.RateLimits.Secondary; w != nil { - headers.Set("x-codex-secondary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64)) + headers.Set("x-"+headerLimit+"-secondary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64)) + if w.WindowMinutes > 0 { + headers.Set("x-"+headerLimit+"-secondary-window-minutes", strconv.FormatInt(w.WindowMinutes, 10)) + } if w.ResetAt > 0 { - headers.Set("x-codex-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10)) + headers.Set("x-"+headerLimit+"-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10)) + } + } + if rateLimitsEvent.LimitName != "" { + headers.Set("x-"+headerLimit+"-limit-name", rateLimitsEvent.LimitName) + } + if rateLimitsEvent.Credits != nil && normalizeStoredLimitID(limitID) == "codex" { + headers.Set("x-codex-credits-has-credits", strconv.FormatBool(rateLimitsEvent.Credits.HasCredits)) + headers.Set("x-codex-credits-unlimited", strconv.FormatBool(rateLimitsEvent.Credits.Unlimited)) + if rateLimitsEvent.Credits.Balance != "" { + headers.Set("x-codex-credits-balance", rateLimitsEvent.Credits.Balance) } } if rateLimitsEvent.PlanWeight > 0 { @@ -497,14 +524,25 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential selectedCredential.updateStateFromHeaders(headers) } -func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredential Credential) { +func (s *Service) handleWebSocketErrorEvent(data []byte, selectedCredential Credential) { var errorEvent struct { - Headers map[string]string `json:"headers"` + StatusCode int `json:"status_code"` + Headers map[string]string `json:"headers"` + Error struct { + Code string `json:"code"` + } `json:"error"` } err := json.Unmarshal(data, &errorEvent) if err != nil { return } + if errorEvent.StatusCode == http.StatusBadRequest && errorEvent.Error.Code == "websocket_connection_limit_reached" { + selectedCredential.markTemporarilyBlocked(availabilityReasonConnectionLimit, time.Now().Add(time.Minute)) + return + } + if errorEvent.StatusCode != http.StatusTooManyRequests { + return + } headers := make(http.Header) for key, value := range errorEvent.Headers { headers.Set(key, value) @@ -515,10 +553,14 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia } func writeWebSocketAggregatedStatus(clientConn net.Conn, clientWriteAccess *sync.Mutex, status aggregatedStatus) error { - data := buildSyntheticRateLimitsEvent(status) clientWriteAccess.Lock() defer clientWriteAccess.Unlock() - return wsutil.WriteServerMessage(clientConn, ws.OpText, data) + for _, data := range buildSyntheticRateLimitsEvents(status) { + if err := wsutil.WriteServerMessage(clientConn, ws.OpText, data); err != nil { + return err + } + } + return nil } func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, sessionClosed <-chan struct{}, firstRealRequest <-chan struct{}, provider credentialProvider, userConfig *option.OCMUser) { @@ -573,34 +615,106 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn } } -func buildSyntheticRateLimitsEvent(status aggregatedStatus) []byte { +func buildSyntheticRateLimitsEvents(status aggregatedStatus) [][]byte { type rateLimitWindow struct { - UsedPercent float64 `json:"used_percent"` - ResetAt int64 `json:"reset_at,omitempty"` + UsedPercent float64 `json:"used_percent"` + WindowMinutes int64 `json:"window_minutes,omitempty"` + ResetAt int64 `json:"reset_at,omitempty"` + } + type creditsEvent struct { + HasCredits bool `json:"has_credits"` + Unlimited bool `json:"unlimited"` + Balance string `json:"balance,omitempty"` } - event := struct { + type eventPayload struct { Type string `json:"type"` RateLimits struct { Primary *rateLimitWindow `json:"primary,omitempty"` Secondary *rateLimitWindow `json:"secondary,omitempty"` } `json:"rate_limits"` - LimitName string `json:"limit_name"` - PlanWeight float64 `json:"plan_weight,omitempty"` - }{ - Type: "codex.rate_limits", - LimitName: "codex", - PlanWeight: status.totalWeight, - } - event.RateLimits.Primary = &rateLimitWindow{ + MeteredLimitName string `json:"metered_limit_name,omitempty"` + LimitName string `json:"limit_name,omitempty"` + Credits *creditsEvent `json:"credits,omitempty"` + PlanWeight float64 `json:"plan_weight,omitempty"` + } + buildEvent := func(snapshot rateLimitSnapshot, primary *rateLimitWindow, secondary *rateLimitWindow) []byte { + event := eventPayload{ + Type: "codex.rate_limits", + MeteredLimitName: snapshot.LimitID, + LimitName: snapshot.LimitName, + PlanWeight: status.totalWeight, + } + if event.MeteredLimitName == "" { + event.MeteredLimitName = "codex" + } + if event.LimitName == "" { + event.LimitName = strings.ReplaceAll(event.MeteredLimitName, "_", "-") + } + event.RateLimits.Primary = primary + event.RateLimits.Secondary = secondary + if snapshot.Credits != nil { + event.Credits = &creditsEvent{ + HasCredits: snapshot.Credits.HasCredits, + Unlimited: snapshot.Credits.Unlimited, + Balance: snapshot.Credits.Balance, + } + } + data, _ := json.Marshal(event) + return data + } + defaultPrimary := &rateLimitWindow{ UsedPercent: status.fiveHourUtilization, ResetAt: resetToEpoch(status.fiveHourReset), } - event.RateLimits.Secondary = &rateLimitWindow{ + defaultSecondary := &rateLimitWindow{ UsedPercent: status.weeklyUtilization, ResetAt: resetToEpoch(status.weeklyReset), } - data, _ := json.Marshal(event) - return data + events := make([][]byte, 0, 1+len(status.limits)) + if snapshot := findSnapshotByLimitID(status.limits, "codex"); snapshot != nil { + primary := defaultPrimary + if snapshot.Primary != nil { + primary = &rateLimitWindow{ + UsedPercent: snapshot.Primary.UsedPercent, + WindowMinutes: snapshot.Primary.WindowMinutes, + ResetAt: snapshot.Primary.ResetAt, + } + } + secondary := defaultSecondary + if snapshot.Secondary != nil { + secondary = &rateLimitWindow{ + UsedPercent: snapshot.Secondary.UsedPercent, + WindowMinutes: snapshot.Secondary.WindowMinutes, + ResetAt: snapshot.Secondary.ResetAt, + } + } + events = append(events, buildEvent(*snapshot, primary, secondary)) + } else { + events = append(events, buildEvent(rateLimitSnapshot{LimitID: "codex", LimitName: "codex"}, defaultPrimary, defaultSecondary)) + } + for _, snapshot := range status.limits { + if snapshot.LimitID == "codex" { + continue + } + var primary *rateLimitWindow + if snapshot.Primary != nil { + primary = &rateLimitWindow{ + UsedPercent: snapshot.Primary.UsedPercent, + WindowMinutes: snapshot.Primary.WindowMinutes, + ResetAt: snapshot.Primary.ResetAt, + } + } + var secondary *rateLimitWindow + if snapshot.Secondary != nil { + secondary = &rateLimitWindow{ + UsedPercent: snapshot.Secondary.UsedPercent, + WindowMinutes: snapshot.Secondary.WindowMinutes, + ResetAt: snapshot.Secondary.ResetAt, + } + } + events = append(events, buildEvent(snapshot, primary, secondary)) + } + return events } func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) { From 6721dff48a558750270b0d1804c217e9f126b648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 12:33:53 +0800 Subject: [PATCH 83/96] fix(ocm): send rate limit status immediately on WebSocket connect Codex CLI ignores x-codex-* headers in the WebSocket upgrade response and only reads rate limits from in-band codex.rate_limits events. Previously, the first synthetic event was gated by firstRealRequest (after warmup), delaying usage display. Now send aggregated status right after subscribing, so the client sees rate limits before the first turn begins. --- service/ocm/service_websocket.go | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 4c552cbf5b..4f76179ef6 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -570,8 +570,11 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn } defer s.statusObserver.UnSubscribe(subscription) - var last aggregatedStatus - hasLast := false + last := s.computeAggregatedUtilization(provider, userConfig) + err = writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, last) + if err != nil { + return + } for { select { @@ -582,13 +585,6 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn case <-sessionClosed: return case <-firstRealRequest: - current := s.computeAggregatedUtilization(provider, userConfig) - err = writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, current) - if err != nil { - return - } - last = current - hasLast = true firstRealRequest = nil case <-subscription: for { @@ -599,9 +595,6 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn } } drained: - if !hasLast { - continue - } current := s.computeAggregatedUtilization(provider, userConfig) if current.equal(last) { continue From 1774d98793f4b3e5cbfa98689387b97dccc391cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 14:01:24 +0800 Subject: [PATCH 84/96] fix(ccm,ocm): restore fixed usage polling Remove the poll_interval config surface from CCM and OCM so both services fall back to the built-in 1h polling cadence again. Also isolate CCM credential lock mocking per test instance so the access-token refresh tests stop racing on shared global state. --- docs/configuration/service/ccm.md | 13 +++---------- docs/configuration/service/ccm.zh.md | 13 +++---------- docs/configuration/service/ocm.md | 13 +++---------- docs/configuration/service/ocm.zh.md | 13 +++---------- option/ccm.go | 10 ++++------ option/ocm.go | 10 ++++------ service/ccm/credential_builder.go | 3 +-- service/ccm/credential_default.go | 10 +++++++--- service/ccm/credential_default_test.go | 6 +----- service/ccm/credential_external.go | 11 ++--------- service/ccm/credential_provider.go | 11 +++-------- service/ccm/test_helpers_test.go | 1 + service/ocm/credential_builder.go | 3 +-- service/ocm/credential_external.go | 11 ++--------- service/ocm/credential_provider.go | 11 +++-------- 15 files changed, 41 insertions(+), 98 deletions(-) diff --git a/docs/configuration/service/ccm.md b/docs/configuration/service/ccm.md index 5823901395..d037fb52ad 100644 --- a/docs/configuration/service/ccm.md +++ b/docs/configuration/service/ccm.md @@ -104,8 +104,7 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de "tag": "pool", "type": "balancer", "strategy": "", - "credentials": ["a", "b"], - "poll_interval": "60s" + "credentials": ["a", "b"] } ``` @@ -113,7 +112,6 @@ Assigns sessions to default credentials based on the selected strategy. Sessions - `strategy`: Selection strategy. One of `least_used` `round_robin` `random` `fallback`. `least_used` will be used by default. - `credentials`: ==Required== List of default credential tags. -- `poll_interval`: How often to poll upstream usage API. Default `60s`. ##### Fallback Strategy @@ -122,15 +120,13 @@ Assigns sessions to default credentials based on the selected strategy. Sessions "tag": "backup", "type": "balancer", "strategy": "fallback", - "credentials": ["a", "b"], - "poll_interval": "30s" + "credentials": ["a", "b"] } ``` A balancer with `strategy: "fallback"` uses credentials in order. It falls through to the next when the current one is exhausted. - `credentials`: ==Required== Ordered list of default credential tags. -- `poll_interval`: How often to poll upstream usage API. Default `60s`. ##### External Credential @@ -144,8 +140,7 @@ A balancer with `strategy: "fallback"` uses credentials in order. It falls throu "token": "", "reverse": false, "detour": "", - "usages_path": "", - "poll_interval": "30m" + "usages_path": "" } ``` @@ -158,7 +153,6 @@ Proxies requests through a remote CCM instance instead of using a local OAuth cr - `reverse`: Enable connector mode. Requires `url`. A connector dials out to `/ccm/v1/reverse` on the remote instance and cannot serve local requests directly. When `url` is set without `reverse`, the credential proxies requests through the remote instance normally and prefers an established reverse connection when one is available. - `detour`: Outbound tag for connecting to the remote instance. - `usages_path`: Optional usage tracking file. -- `poll_interval`: How often to poll the remote status endpoint. Default `30m`. #### usages_path @@ -290,7 +284,6 @@ claude { "tag": "pool", "type": "balancer", - "poll_interval": "60s", "credentials": ["a", "b"] } ], diff --git a/docs/configuration/service/ccm.zh.md b/docs/configuration/service/ccm.zh.md index 586cb5bb15..8fe2e1373a 100644 --- a/docs/configuration/service/ccm.zh.md +++ b/docs/configuration/service/ccm.zh.md @@ -104,8 +104,7 @@ Claude Code OAuth 凭据文件的路径。 "tag": "pool", "type": "balancer", "strategy": "", - "credentials": ["a", "b"], - "poll_interval": "60s" + "credentials": ["a", "b"] } ``` @@ -113,7 +112,6 @@ Claude Code OAuth 凭据文件的路径。 - `strategy`:选择策略。可选值:`least_used` `round_robin` `random` `fallback`。默认使用 `least_used`。 - `credentials`:==必填== 默认凭据标签列表。 -- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 ##### 回退策略 @@ -122,15 +120,13 @@ Claude Code OAuth 凭据文件的路径。 "tag": "backup", "type": "balancer", "strategy": "fallback", - "credentials": ["a", "b"], - "poll_interval": "30s" + "credentials": ["a", "b"] } ``` 将 `strategy` 设为 `fallback` 的均衡凭据会按顺序使用凭据。当前凭据耗尽后切换到下一个。 - `credentials`:==必填== 有序的默认凭据标签列表。 -- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 ##### 外部凭据 @@ -144,8 +140,7 @@ Claude Code OAuth 凭据文件的路径。 "token": "", "reverse": false, "detour": "", - "usages_path": "", - "poll_interval": "30m" + "usages_path": "" } ``` @@ -158,7 +153,6 @@ Claude Code OAuth 凭据文件的路径。 - `reverse`:启用连接器模式。要求设置 `url`。启用后,此凭据会主动拨出到远程实例的 `/ccm/v1/reverse`,且不能直接为本地请求提供服务。当设置了 `url` 但未启用 `reverse` 时,此凭据会正常通过远程实例转发请求,并在反向连接建立后优先使用该反向连接。 - `detour`:用于连接远程实例的出站标签。 - `usages_path`:可选的使用跟踪文件。 -- `poll_interval`:轮询远程状态端点的间隔。默认 `30m`。 #### usages_path @@ -290,7 +284,6 @@ claude { "tag": "pool", "type": "balancer", - "poll_interval": "60s", "credentials": ["a", "b"] } ], diff --git a/docs/configuration/service/ocm.md b/docs/configuration/service/ocm.md index 43027d9dbb..e47dfdc33a 100644 --- a/docs/configuration/service/ocm.md +++ b/docs/configuration/service/ocm.md @@ -100,8 +100,7 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de "tag": "pool", "type": "balancer", "strategy": "", - "credentials": ["a", "b"], - "poll_interval": "60s" + "credentials": ["a", "b"] } ``` @@ -109,7 +108,6 @@ Assigns sessions to default credentials based on the selected strategy. Sessions - `strategy`: Selection strategy. One of `least_used` `round_robin` `random` `fallback`. `least_used` will be used by default. - `credentials`: ==Required== List of default credential tags. -- `poll_interval`: How often to poll upstream usage API. Default `60s`. ##### Fallback Strategy @@ -118,15 +116,13 @@ Assigns sessions to default credentials based on the selected strategy. Sessions "tag": "backup", "type": "balancer", "strategy": "fallback", - "credentials": ["a", "b"], - "poll_interval": "30s" + "credentials": ["a", "b"] } ``` A balancer with `strategy: "fallback"` uses credentials in order. It falls through to the next when the current one is exhausted. - `credentials`: ==Required== Ordered list of default credential tags. -- `poll_interval`: How often to poll upstream usage API. Default `60s`. ##### External Credential @@ -140,8 +136,7 @@ A balancer with `strategy: "fallback"` uses credentials in order. It falls throu "token": "", "reverse": false, "detour": "", - "usages_path": "", - "poll_interval": "30m" + "usages_path": "" } ``` @@ -154,7 +149,6 @@ Proxies requests through a remote OCM instance instead of using a local OAuth cr - `reverse`: Enable connector mode. Requires `url`. A connector dials out to `/ocm/v1/reverse` on the remote instance and cannot serve local requests directly. When `url` is set without `reverse`, the credential proxies requests through the remote instance normally and prefers an established reverse connection when one is available. - `detour`: Outbound tag for connecting to the remote instance. - `usages_path`: Optional usage tracking file. -- `poll_interval`: How often to poll the remote status endpoint. Default `30m`. #### usages_path @@ -342,7 +336,6 @@ codex --profile ocm { "tag": "pool", "type": "balancer", - "poll_interval": "60s", "credentials": ["a", "b"] } ], diff --git a/docs/configuration/service/ocm.zh.md b/docs/configuration/service/ocm.zh.md index 2d06206f07..76a5067ddb 100644 --- a/docs/configuration/service/ocm.zh.md +++ b/docs/configuration/service/ocm.zh.md @@ -100,8 +100,7 @@ OpenAI OAuth 凭据文件的路径。 "tag": "pool", "type": "balancer", "strategy": "", - "credentials": ["a", "b"], - "poll_interval": "60s" + "credentials": ["a", "b"] } ``` @@ -109,7 +108,6 @@ OpenAI OAuth 凭据文件的路径。 - `strategy`:选择策略。可选值:`least_used` `round_robin` `random` `fallback`。默认使用 `least_used`。 - `credentials`:==必填== 默认凭据标签列表。 -- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 ##### 回退策略 @@ -118,15 +116,13 @@ OpenAI OAuth 凭据文件的路径。 "tag": "backup", "type": "balancer", "strategy": "fallback", - "credentials": ["a", "b"], - "poll_interval": "30s" + "credentials": ["a", "b"] } ``` 将 `strategy` 设为 `fallback` 的均衡凭据会按顺序使用凭据。当前凭据耗尽后切换到下一个。 - `credentials`:==必填== 有序的默认凭据标签列表。 -- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 ##### 外部凭据 @@ -140,8 +136,7 @@ OpenAI OAuth 凭据文件的路径。 "token": "", "reverse": false, "detour": "", - "usages_path": "", - "poll_interval": "30m" + "usages_path": "" } ``` @@ -154,7 +149,6 @@ OpenAI OAuth 凭据文件的路径。 - `reverse`:启用连接器模式。要求设置 `url`。启用后,此凭据会主动拨出到远程实例的 `/ocm/v1/reverse`,且不能直接为本地请求提供服务。当设置了 `url` 但未启用 `reverse` 时,此凭据会正常通过远程实例转发请求,并在反向连接建立后优先使用该反向连接。 - `detour`:用于连接远程实例的出站标签。 - `usages_path`:可选的使用跟踪文件。 -- `poll_interval`:轮询远程状态端点的间隔。默认 `30m`。 #### usages_path @@ -343,7 +337,6 @@ codex --profile ocm { "tag": "pool", "type": "balancer", - "poll_interval": "60s", "credentials": ["a", "b"] } ], diff --git a/option/ccm.go b/option/ccm.go index 481068e617..1079bad0f1 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -89,16 +89,14 @@ type CCMDefaultCredentialOptions struct { type CCMBalancerCredentialOptions struct { Strategy string `json:"strategy,omitempty"` Credentials badoption.Listable[string] `json:"credentials"` - PollInterval badoption.Duration `json:"poll_interval,omitempty"` RebalanceThreshold float64 `json:"rebalance_threshold,omitempty"` } type CCMExternalCredentialOptions struct { URL string `json:"url,omitempty"` ServerOptions - Token string `json:"token"` - Reverse bool `json:"reverse,omitempty"` - Detour string `json:"detour,omitempty"` - UsagesPath string `json:"usages_path,omitempty"` - PollInterval badoption.Duration `json:"poll_interval,omitempty"` + Token string `json:"token"` + Reverse bool `json:"reverse,omitempty"` + Detour string `json:"detour,omitempty"` + UsagesPath string `json:"usages_path,omitempty"` } diff --git a/option/ocm.go b/option/ocm.go index af937560b8..aeb1f75e79 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -88,16 +88,14 @@ type OCMDefaultCredentialOptions struct { type OCMBalancerCredentialOptions struct { Strategy string `json:"strategy,omitempty"` Credentials badoption.Listable[string] `json:"credentials"` - PollInterval badoption.Duration `json:"poll_interval,omitempty"` RebalanceThreshold float64 `json:"rebalance_threshold,omitempty"` } type OCMExternalCredentialOptions struct { URL string `json:"url,omitempty"` ServerOptions - Token string `json:"token"` - Reverse bool `json:"reverse,omitempty"` - Detour string `json:"detour,omitempty"` - UsagesPath string `json:"usages_path,omitempty"` - PollInterval badoption.Duration `json:"poll_interval,omitempty"` + Token string `json:"token"` + Reverse bool `json:"reverse,omitempty"` + Detour string `json:"detour,omitempty"` + UsagesPath string `json:"usages_path,omitempty"` } diff --git a/service/ccm/credential_builder.go b/service/ccm/credential_builder.go index 8ffd42b0dd..4650ef1f58 100644 --- a/service/ccm/credential_builder.go +++ b/service/ccm/credential_builder.go @@ -2,7 +2,6 @@ package ccm import ( "context" - "time" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" @@ -48,7 +47,7 @@ func buildCredentialProviders( if err != nil { return nil, nil, err } - providers[credentialOption.Tag] = newBalancerProvider(subCredentials, credentialOption.BalancerOptions.Strategy, time.Duration(credentialOption.BalancerOptions.PollInterval), credentialOption.BalancerOptions.RebalanceThreshold, logger) + providers[credentialOption.Tag] = newBalancerProvider(subCredentials, credentialOption.BalancerOptions.Strategy, credentialOption.BalancerOptions.RebalanceThreshold, logger) } } diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 5467d3d8d8..60a02c55b5 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -26,8 +26,6 @@ import ( "github.com/sagernet/sing/common/observable" ) -var acquireCredentialLockFunc = acquireCredentialLock - type claudeProfileSnapshot struct { OAuthAccount *claudeOAuthAccount AccountType string @@ -56,6 +54,7 @@ type defaultCredential struct { capWeekly float64 usageTracker *AggregatedUsage forwardHTTPClient *http.Client + acquireLock func(string) (func(), error) logger log.ContextLogger watcher *fswatch.Watcher watcherRetryAt time.Time @@ -122,6 +121,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef cap5h: cap5h, capWeekly: capWeekly, forwardHTTPClient: httpClient, + acquireLock: acquireCredentialLock, logger: logger, requestContext: requestContext, cancelRequests: cancelRequests, @@ -363,7 +363,11 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool { if !c.shouldAttemptRefresh(currentCredentials, force) { return false } - release, err := acquireCredentialLockFunc(c.configDir) + acquireLock := c.acquireLock + if acquireLock == nil { + acquireLock = acquireCredentialLock + } + release, err := acquireLock(c.configDir) if err != nil { c.logger.Debug("acquire credential lock for ", c.tag, ": ", err) return false diff --git a/service/ccm/credential_default_test.go b/service/ccm/credential_default_test.go index a5535aa26b..9435f3db7f 100644 --- a/service/ccm/credential_default_test.go +++ b/service/ccm/credential_default_test.go @@ -31,13 +31,9 @@ func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) { t.Fatal(err) } - originalLockFunc := acquireCredentialLockFunc - acquireCredentialLockFunc = func(string) (func(), error) { + credential.acquireLock = func(string) (func(), error) { return nil, errors.New("locked") } - t.Cleanup(func() { - acquireCredentialLockFunc = originalLockFunc - }) token, err := credential.getAccessToken() if err != nil { diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index e8a0cf1ea6..be1bb7f5bb 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -38,7 +38,6 @@ type externalCredential struct { state credentialState stateAccess sync.RWMutex pollAccess sync.Mutex - pollInterval time.Duration usageTracker *AggregatedUsage logger log.ContextLogger @@ -113,18 +112,12 @@ func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) stri } func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { - pollInterval := time.Duration(options.PollInterval) - if pollInterval <= 0 { - pollInterval = 30 * time.Minute - } - requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) credential := &externalCredential{ tag: tag, token: options.Token, - pollInterval: pollInterval, logger: logger, requestContext: requestContext, cancelRequests: cancelRequests, @@ -355,9 +348,9 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) { } func (c *externalCredential) markUpstreamRejected() { - c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(c.pollInterval)) + c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(defaultPollInterval)) c.stateAccess.Lock() - c.state.upstreamRejectedUntil = time.Now().Add(c.pollInterval) + c.state.upstreamRejectedUntil = time.Now().Add(defaultPollInterval) c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonUpstreamRejected, c.state.upstreamRejectedUntil) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() diff --git a/service/ccm/credential_provider.go b/service/ccm/credential_provider.go index 640ced702a..4f5f2ad32b 100644 --- a/service/ccm/credential_provider.go +++ b/service/ccm/credential_provider.go @@ -112,7 +112,6 @@ type balancerProvider struct { credentials []Credential strategy string roundRobinIndex atomic.Uint64 - pollInterval time.Duration rebalanceThreshold float64 sessionAccess sync.RWMutex sessions map[string]sessionEntry @@ -121,14 +120,10 @@ type balancerProvider struct { logger log.ContextLogger } -func newBalancerProvider(credentials []Credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { - if pollInterval <= 0 { - pollInterval = defaultPollInterval - } +func newBalancerProvider(credentials []Credential, strategy string, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { return &balancerProvider{ credentials: credentials, strategy: strategy, - pollInterval: pollInterval, rebalanceThreshold: rebalanceThreshold, sessions: make(map[string]sessionEntry), credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), @@ -383,14 +378,14 @@ func (p *balancerProvider) pollIfStale() { p.interruptAccess.Unlock() for _, credential := range p.credentials { - if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) { credential.pollUsage() } } } func (p *balancerProvider) pollCredentialIfStale(credential Credential) { - if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) { credential.pollUsage() } } diff --git a/service/ccm/test_helpers_test.go b/service/ccm/test_helpers_test.go index 175dc71add..2f52591592 100644 --- a/service/ccm/test_helpers_test.go +++ b/service/ccm/test_helpers_test.go @@ -85,6 +85,7 @@ func newTestDefaultCredential(t *testing.T, credentialPath string, transport htt cap5h: 99, capWeekly: 99, forwardHTTPClient: &http.Client{Transport: transport}, + acquireLock: acquireCredentialLock, logger: log.NewNOPFactory().Logger(), requestContext: requestContext, cancelRequests: cancelRequests, diff --git a/service/ocm/credential_builder.go b/service/ocm/credential_builder.go index fb69c62c93..57a34fb7c3 100644 --- a/service/ocm/credential_builder.go +++ b/service/ocm/credential_builder.go @@ -2,7 +2,6 @@ package ocm import ( "context" - "time" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" @@ -48,7 +47,7 @@ func buildOCMCredentialProviders( if err != nil { return nil, nil, err } - providers[credentialOption.Tag] = newBalancerProvider(subCredentials, credentialOption.BalancerOptions.Strategy, time.Duration(credentialOption.BalancerOptions.PollInterval), credentialOption.BalancerOptions.RebalanceThreshold, logger) + providers[credentialOption.Tag] = newBalancerProvider(subCredentials, credentialOption.BalancerOptions.Strategy, credentialOption.BalancerOptions.RebalanceThreshold, logger) } } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 7553cdf439..765c3a27b8 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -40,7 +40,6 @@ type externalCredential struct { state credentialState stateAccess sync.RWMutex pollAccess sync.Mutex - pollInterval time.Duration usageTracker *AggregatedUsage logger log.ContextLogger @@ -131,18 +130,12 @@ func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) stri } func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { - pollInterval := time.Duration(options.PollInterval) - if pollInterval <= 0 { - pollInterval = 30 * time.Minute - } - requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) credential := &externalCredential{ tag: tag, token: options.Token, - pollInterval: pollInterval, logger: logger, requestContext: requestContext, cancelRequests: cancelRequests, @@ -377,9 +370,9 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) { } func (c *externalCredential) markUpstreamRejected() { - c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(c.pollInterval)) + c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(defaultPollInterval)) c.stateAccess.Lock() - c.state.upstreamRejectedUntil = time.Now().Add(c.pollInterval) + c.state.upstreamRejectedUntil = time.Now().Add(defaultPollInterval) c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonUpstreamRejected, c.state.upstreamRejectedUntil) shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() diff --git a/service/ocm/credential_provider.go b/service/ocm/credential_provider.go index 714e44ab7c..b296543cc1 100644 --- a/service/ocm/credential_provider.go +++ b/service/ocm/credential_provider.go @@ -112,7 +112,6 @@ type balancerProvider struct { credentials []Credential strategy string roundRobinIndex atomic.Uint64 - pollInterval time.Duration rebalanceThreshold float64 sessionAccess sync.RWMutex sessions map[string]sessionEntry @@ -125,14 +124,10 @@ func compositeCredentialSelectable(credential Credential) bool { return !credential.ocmIsAPIKeyMode() } -func newBalancerProvider(credentials []Credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { - if pollInterval <= 0 { - pollInterval = defaultPollInterval - } +func newBalancerProvider(credentials []Credential, strategy string, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { return &balancerProvider{ credentials: credentials, strategy: strategy, - pollInterval: pollInterval, rebalanceThreshold: rebalanceThreshold, sessions: make(map[string]sessionEntry), credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), @@ -410,14 +405,14 @@ func (p *balancerProvider) pollIfStale() { p.interruptAccess.Unlock() for _, credential := range p.credentials { - if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) { credential.pollUsage() } } } func (p *balancerProvider) pollCredentialIfStale(credential Credential) { - if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) { credential.pollUsage() } } From ca60f93184aeb1ca0d18c4a538c66d900bc9f086 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 15:42:51 +0800 Subject: [PATCH 85/96] fix(ocm): inject synthetic rate limits inline when intercepting upstream events The initial synthetic event from 6721dff48 arrives before the Codex CLI's response stream reader is active. Additionally, the shouldEmit gate in updateStateFromHeaders suppresses the async replacement when values haven't changed. Send aggregated status inline in proxyWebSocketUpstreamToClient so the client receives it at the exact protocol position it expects. --- service/ocm/service_websocket.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 4f76179ef6..a441bdc8dd 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -342,7 +342,7 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, modelChannel, username, weeklyCycleHint) + s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, modelChannel, username, weeklyCycleHint, provider, userConfig) }() go func() { defer waitGroup.Done() @@ -407,7 +407,7 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn } } -func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint, provider credentialProvider, userConfig *option.OCMUser) { usageTracker := selectedCredential.usageTrackerOrNil() var requestModel string for { @@ -428,6 +428,8 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe switch event.Type { case "codex.rate_limits": s.handleWebSocketRateLimitsEvent(data, selectedCredential) + status := s.computeAggregatedUtilization(provider, userConfig) + writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, status) continue case "error": s.handleWebSocketErrorEvent(data, selectedCredential) From 6f45ea9c270c4def4f0ab4e4e727d13ad4c65b65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 16:03:39 +0800 Subject: [PATCH 86/96] Revert "fix(ocm): inject synthetic rate limits inline when intercepting upstream events" This reverts commit ca60f93184aeb1ca0d18c4a538c66d900bc9f086. --- service/ocm/service_websocket.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index a441bdc8dd..4f76179ef6 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -342,7 +342,7 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, modelChannel, username, weeklyCycleHint, provider, userConfig) + s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, modelChannel, username, weeklyCycleHint) }() go func() { defer waitGroup.Done() @@ -407,7 +407,7 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn } } -func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint, provider credentialProvider, userConfig *option.OCMUser) { +func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { usageTracker := selectedCredential.usageTrackerOrNil() var requestModel string for { @@ -428,8 +428,6 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe switch event.Type { case "codex.rate_limits": s.handleWebSocketRateLimitsEvent(data, selectedCredential) - status := s.computeAggregatedUtilization(provider, userConfig) - writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, status) continue case "error": s.handleWebSocketErrorEvent(data, selectedCredential) From e1c966731991997b4b36c4a89403ded3f52d5b76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 16:03:46 +0800 Subject: [PATCH 87/96] Revert "fix(ocm): send rate limit status immediately on WebSocket connect" This reverts commit 6721dff48a558750270b0d1804c217e9f126b648. --- service/ocm/service_websocket.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 4f76179ef6..4c552cbf5b 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -570,11 +570,8 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn } defer s.statusObserver.UnSubscribe(subscription) - last := s.computeAggregatedUtilization(provider, userConfig) - err = writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, last) - if err != nil { - return - } + var last aggregatedStatus + hasLast := false for { select { @@ -585,6 +582,13 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn case <-sessionClosed: return case <-firstRealRequest: + current := s.computeAggregatedUtilization(provider, userConfig) + err = writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, current) + if err != nil { + return + } + last = current + hasLast = true firstRealRequest = nil case <-subscription: for { @@ -595,6 +599,9 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn } } drained: + if !hasLast { + continue + } current := s.computeAggregatedUtilization(provider, userConfig) if current.equal(last) { continue From e49d0685ad8f3a749dd274b3156bc43a3f98abba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 21:37:12 +0800 Subject: [PATCH 88/96] ccm: Fix token refresh --- service/ccm/credential_oauth.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/service/ccm/credential_oauth.go b/service/ccm/credential_oauth.go index 03fcb4023a..6a1394b880 100644 --- a/service/ccm/credential_oauth.go +++ b/service/ccm/credential_oauth.go @@ -50,7 +50,10 @@ func resolveRefreshScopes(stored []string) string { return strings.Join(stored, " ") } -const ccmUserAgentFallback = "claude-code/2.1.72" +const ( + ccmRefreshUserAgent = "axios/1.13.6" + ccmUserAgentFallback = "claude-code/2.1.85" +) var ( ccmUserAgentOnce sync.Once @@ -215,7 +218,7 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau return nil, err } request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) + request.Header.Set("User-Agent", ccmRefreshUserAgent) return request, nil }) if err != nil { From cd5007ffbbd803be7c58a4ae69693da0d0211d64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 22:16:02 +0800 Subject: [PATCH 89/96] fix(ccm,ocm): track external credential poll failures and re-poll on user connect External credentials now properly increment consecutivePollFailures on poll errors (matching defaultCredential behavior), marking the credential as temporarily blocked. When a user with external_credential connects and the credential is not usable, a forced poll is triggered to check recovery. --- service/ccm/credential_external.go | 23 ++++++++++++++--------- service/ccm/service_handler.go | 8 ++++++++ service/ocm/credential_external.go | 23 ++++++++++++++--------- service/ocm/service_handler.go | 8 ++++++++ 4 files changed, 44 insertions(+), 18 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index be1bb7f5bb..11ddc8dad6 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -614,7 +614,7 @@ func (c *externalCredential) pollUsage() { response, err := c.doPollUsageRequest(ctx) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": ", err) - c.clearPollFailures() + c.incrementPollFailures() return } defer response.Body.Close() @@ -622,35 +622,35 @@ func (c *externalCredential) pollUsage() { if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - c.clearPollFailures() + c.incrementPollFailures() return } body, err := io.ReadAll(response.Body) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": read body: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } var rawFields map[string]json.RawMessage err = json.Unmarshal(body, &rawFields) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil || rawFields["plan_weight"] == nil { c.logger.Error("poll usage for ", c.tag, ": invalid response") - c.clearPollFailures() + c.incrementPollFailures() return } var statusResponse statusPayload err = json.Unmarshal(body, &statusResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } @@ -943,11 +943,16 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati return baseInterval } -func (c *externalCredential) clearPollFailures() { +func (c *externalCredential) incrementPollFailures() { c.stateAccess.Lock() - c.state.consecutivePollFailures = 0 - c.checkTransitionLocked() + c.state.consecutivePollFailures++ + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{}) + shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + c.emitStatusUpdate() } func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index ad5dfcaee0..dd8c71d2d5 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -218,6 +218,14 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale() + if userConfig != nil && userConfig.ExternalCredential != "" { + for _, credential := range s.allCredentials { + if credential.tagName() == userConfig.ExternalCredential && !credential.isUsable() { + credential.pollUsage() + break + } + } + } s.cleanSessionModels() anthropicBetaHeader := r.Header.Get("anthropic-beta") diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 765c3a27b8..67b9b2a1b5 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -656,7 +656,7 @@ func (c *externalCredential) pollUsage() { response, err := c.doPollUsageRequest(ctx) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": ", err) - c.clearPollFailures() + c.incrementPollFailures() return } defer response.Body.Close() @@ -664,35 +664,35 @@ func (c *externalCredential) pollUsage() { if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - c.clearPollFailures() + c.incrementPollFailures() return } body, err := io.ReadAll(response.Body) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": read body: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } var rawFields map[string]json.RawMessage err = json.Unmarshal(body, &rawFields) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } if rawFields["limits"] == nil && (rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil || rawFields["plan_weight"] == nil) { c.logger.Error("poll usage for ", c.tag, ": invalid response") - c.clearPollFailures() + c.incrementPollFailures() return } var statusResponse statusPayload err = json.Unmarshal(body, &statusResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } @@ -985,11 +985,16 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati return baseInterval } -func (c *externalCredential) clearPollFailures() { +func (c *externalCredential) incrementPollFailures() { c.stateAccess.Lock() - c.state.consecutivePollFailures = 0 - c.checkTransitionLocked() + c.state.consecutivePollFailures++ + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{}) + shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + c.emitStatusUpdate() } func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index c2e90a582d..cfb34e15b0 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -132,6 +132,14 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale() + if userConfig != nil && userConfig.ExternalCredential != "" { + for _, credential := range s.allCredentials { + if credential.tagName() == userConfig.ExternalCredential && !credential.isUsable() { + credential.pollUsage() + break + } + } + } selection := credentialSelectionForUser(userConfig) From d9c298af1ee7ce1ef733254cb5a08171d79095cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 23:42:14 +0800 Subject: [PATCH 90/96] fix(ccm,ocm): remove upstream rate limit header forwarding, compute locally Strip all upstream rate limit headers and compute unified-status, representative-claim, reset times, and surpassed-threshold from aggregated utilization data. Never expose per-account overage or fallback information. Remove per-credential unified state storage, snapshot aggregation, and WebSocket synthetic rate limit events. --- service/ccm/credential.go | 19 -- service/ccm/credential_default.go | 21 -- service/ccm/credential_external.go | 78 ------- service/ccm/rate_limit_state.go | 48 ----- service/ccm/service_status.go | 169 +++++---------- service/ccm/service_status_test.go | 111 +++++++--- service/ocm/credential.go | 2 - service/ocm/credential_default.go | 20 -- service/ocm/credential_external.go | 99 ++------- service/ocm/rate_limit_state.go | 20 -- service/ocm/service_status.go | 319 ++++++----------------------- service/ocm/service_status_test.go | 90 -------- service/ocm/service_websocket.go | 172 +--------------- 13 files changed, 210 insertions(+), 958 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index e3788c43fc..48b97b95b7 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -63,13 +63,6 @@ type credentialState struct { availabilityReason availabilityReason availabilityResetAt time.Time lastKnownDataAt time.Time - unifiedStatus unifiedRateLimitStatus - unifiedResetAt time.Time - representativeClaim string - unifiedFallbackAvailable bool - overageStatus string - overageResetAt time.Time - overageDisabledReason string accountUUID string accountType string rateLimitTier string @@ -125,7 +118,6 @@ type Credential interface { markRateLimited(resetAt time.Time) markUpstreamRejected() availabilityStatus() availabilityStatus - unifiedRateLimitState() unifiedRateLimitInfo earliestReset() time.Time unavailableError() error @@ -252,17 +244,6 @@ func (s credentialState) currentAvailability() availabilityStatus { } } -func (s credentialState) currentUnifiedRateLimit() unifiedRateLimitInfo { - return unifiedRateLimitInfo{ - Status: s.unifiedStatus, - ResetAt: s.unifiedResetAt, - RepresentativeClaim: s.representativeClaim, - FallbackAvailable: s.unifiedFallbackAvailable, - OverageStatus: s.overageStatus, - OverageResetAt: s.overageResetAt, - OverageDisabledReason: s.overageDisabledReason, - }.normalized() -} func parseRateLimitResetFromHeaders(headers http.Header) time.Time { claim := headers.Get("anthropic-ratelimit-unified-representative-claim") diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 60a02c55b5..bf8404f880 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -629,19 +629,6 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { c.state.lastUpdated = time.Now() c.state.noteSnapshotData() } - if unifiedStatus := unifiedRateLimitStatus(headers.Get("anthropic-ratelimit-unified-status")); unifiedStatus != "" { - c.state.unifiedStatus = unifiedStatus - } - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-reset"); exists { - c.state.unifiedResetAt = value - } - c.state.representativeClaim = headers.Get("anthropic-ratelimit-unified-representative-claim") - c.state.unifiedFallbackAvailable = headers.Get("anthropic-ratelimit-unified-fallback") == "available" - c.state.overageStatus = headers.Get("anthropic-ratelimit-unified-overage-status") - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-overage-reset"); exists { - c.state.overageResetAt = value - } - c.state.overageDisabledReason = headers.Get("anthropic-ratelimit-unified-overage-disabled-reason") if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { @@ -666,8 +653,6 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) { c.state.hardRateLimited = true c.state.rateLimitResetAt = resetAt c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt) - c.state.unifiedStatus = unifiedRateLimitStatusRejected - c.state.unifiedResetAt = resetAt shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -800,12 +785,6 @@ func (c *defaultCredential) availabilityStatus() availabilityStatus { return c.state.currentAvailability() } -func (c *defaultCredential) unifiedRateLimitState() unifiedRateLimitInfo { - c.stateAccess.RLock() - defer c.stateAccess.RUnlock() - return c.state.currentUnifiedRateLimit() -} - func (c *defaultCredential) unavailableError() error { c.stateAccess.RLock() defer c.stateAccess.RUnlock() diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 11ddc8dad6..f57ead1581 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -337,8 +337,6 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) { c.state.hardRateLimited = true c.state.rateLimitResetAt = resetAt c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt) - c.state.unifiedStatus = unifiedRateLimitStatusRejected - c.state.unifiedResetAt = resetAt shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() if shouldInterrupt { @@ -492,19 +490,6 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.lastUpdated = time.Now() c.state.noteSnapshotData() } - if unifiedStatus := unifiedRateLimitStatus(headers.Get("anthropic-ratelimit-unified-status")); unifiedStatus != "" { - c.state.unifiedStatus = unifiedStatus - } - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-reset"); exists { - c.state.unifiedResetAt = value - } - c.state.representativeClaim = headers.Get("anthropic-ratelimit-unified-representative-claim") - c.state.unifiedFallbackAvailable = headers.Get("anthropic-ratelimit-unified-fallback") == "available" - c.state.overageStatus = headers.Get("anthropic-ratelimit-unified-overage-status") - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-overage-reset"); exists { - c.state.overageResetAt = value - } - c.state.overageDisabledReason = headers.Get("anthropic-ratelimit-unified-overage-disabled-reason") if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { resetSuffix := "" if !c.state.weeklyReset.IsZero() { @@ -662,11 +647,6 @@ func (c *externalCredential) pollUsage() { c.state.upstreamRejectedUntil = time.Time{} c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization - c.state.unifiedStatus = unifiedRateLimitStatus(statusResponse.UnifiedStatus) - c.state.representativeClaim = statusResponse.RepresentativeClaim - c.state.unifiedFallbackAvailable = statusResponse.FallbackAvailable - c.state.overageStatus = statusResponse.OverageStatus - c.state.overageDisabledReason = statusResponse.OverageDisabledReason if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } @@ -676,30 +656,6 @@ func (c *externalCredential) pollUsage() { if statusResponse.WeeklyReset > 0 { c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) } - if statusResponse.UnifiedReset > 0 { - c.state.unifiedResetAt = time.Unix(statusResponse.UnifiedReset, 0) - } - if statusResponse.OverageReset > 0 { - c.state.overageResetAt = time.Unix(statusResponse.OverageReset, 0) - } - if statusResponse.Availability != nil { - switch availabilityState(statusResponse.Availability.State) { - case availabilityStateRateLimited: - c.state.hardRateLimited = true - if statusResponse.Availability.ResetAt > 0 { - c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0) - } - case availabilityStateTemporarilyBlocked: - resetAt := time.Time{} - if statusResponse.Availability.ResetAt > 0 { - resetAt = time.Unix(statusResponse.Availability.ResetAt, 0) - } - c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt) - if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() { - c.state.upstreamRejectedUntil = resetAt - } - } - } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } @@ -800,11 +756,6 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr c.state.upstreamRejectedUntil = time.Time{} c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization - c.state.unifiedStatus = unifiedRateLimitStatus(statusResponse.UnifiedStatus) - c.state.representativeClaim = statusResponse.RepresentativeClaim - c.state.unifiedFallbackAvailable = statusResponse.FallbackAvailable - c.state.overageStatus = statusResponse.OverageStatus - c.state.overageDisabledReason = statusResponse.OverageDisabledReason if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } @@ -814,30 +765,6 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr if statusResponse.WeeklyReset > 0 { c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) } - if statusResponse.UnifiedReset > 0 { - c.state.unifiedResetAt = time.Unix(statusResponse.UnifiedReset, 0) - } - if statusResponse.OverageReset > 0 { - c.state.overageResetAt = time.Unix(statusResponse.OverageReset, 0) - } - if statusResponse.Availability != nil { - switch availabilityState(statusResponse.Availability.State) { - case availabilityStateRateLimited: - c.state.hardRateLimited = true - if statusResponse.Availability.ResetAt > 0 { - c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0) - } - case availabilityStateTemporarilyBlocked: - resetAt := time.Time{} - if statusResponse.Availability.ResetAt > 0 { - resetAt = time.Unix(statusResponse.Availability.ResetAt, 0) - } - c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt) - if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() { - c.state.upstreamRejectedUntil = resetAt - } - } - } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } @@ -921,11 +848,6 @@ func (c *externalCredential) availabilityStatus() availabilityStatus { return c.state.currentAvailability() } -func (c *externalCredential) unifiedRateLimitState() unifiedRateLimitInfo { - c.stateAccess.RLock() - defer c.stateAccess.RUnlock() - return c.state.currentUnifiedRateLimit() -} func (c *externalCredential) markUsageStreamUpdated() { c.stateAccess.Lock() diff --git a/service/ccm/rate_limit_state.go b/service/ccm/rate_limit_state.go index ab584419fb..696fe81429 100644 --- a/service/ccm/rate_limit_state.go +++ b/service/ccm/rate_limit_state.go @@ -29,12 +29,6 @@ type availabilityStatus struct { ResetAt time.Time } -type availabilityPayload struct { - State string `json:"state"` - Reason string `json:"reason,omitempty"` - ResetAt int64 `json:"reset_at,omitempty"` -} - func (s availabilityStatus) normalized() availabilityStatus { if s.State == "" { s.State = availabilityStateUnknown @@ -45,48 +39,6 @@ func (s availabilityStatus) normalized() availabilityStatus { return s } -func (s availabilityStatus) toPayload() *availabilityPayload { - s = s.normalized() - if s.State == "" { - return nil - } - payload := &availabilityPayload{ - State: string(s.State), - } - if s.Reason != "" && s.Reason != availabilityReasonUnknown { - payload.Reason = string(s.Reason) - } - if !s.ResetAt.IsZero() { - payload.ResetAt = s.ResetAt.Unix() - } - return payload -} - -type unifiedRateLimitStatus string - -const ( - unifiedRateLimitStatusAllowed unifiedRateLimitStatus = "allowed" - unifiedRateLimitStatusAllowedWarning unifiedRateLimitStatus = "allowed_warning" - unifiedRateLimitStatusRejected unifiedRateLimitStatus = "rejected" -) - -type unifiedRateLimitInfo struct { - Status unifiedRateLimitStatus - ResetAt time.Time - RepresentativeClaim string - FallbackAvailable bool - OverageStatus string - OverageResetAt time.Time - OverageDisabledReason string -} - -func (s unifiedRateLimitInfo) normalized() unifiedRateLimitInfo { - if s.Status == "" { - s.Status = unifiedRateLimitStatusAllowed - } - return s -} - func claudeWindowProgress(resetAt time.Time, windowSeconds float64, now time.Time) float64 { if resetAt.IsZero() || windowSeconds <= 0 { return 0 diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index 11ae3fd3ad..f5bd9bc63c 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -13,19 +13,11 @@ import ( ) type statusPayload struct { - FiveHourUtilization float64 `json:"five_hour_utilization"` - FiveHourReset int64 `json:"five_hour_reset"` - WeeklyUtilization float64 `json:"weekly_utilization"` - WeeklyReset int64 `json:"weekly_reset"` - PlanWeight float64 `json:"plan_weight"` - UnifiedStatus string `json:"unified_status,omitempty"` - UnifiedReset int64 `json:"unified_reset,omitempty"` - RepresentativeClaim string `json:"representative_claim,omitempty"` - FallbackAvailable bool `json:"fallback_available,omitempty"` - OverageStatus string `json:"overage_status,omitempty"` - OverageReset int64 `json:"overage_reset,omitempty"` - OverageDisabledReason string `json:"overage_disabled_reason,omitempty"` - Availability *availabilityPayload `json:"availability,omitempty"` + FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` + WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` + PlanWeight float64 `json:"plan_weight"` } type aggregatedStatus struct { @@ -34,7 +26,6 @@ type aggregatedStatus struct { totalWeight float64 fiveHourReset time.Time weeklyReset time.Time - unifiedRateLimit unifiedRateLimitInfo availability availabilityStatus } @@ -50,27 +41,17 @@ func (s aggregatedStatus) equal(other aggregatedStatus) bool { } func (s aggregatedStatus) toPayload() statusPayload { - unified := s.unifiedRateLimit.normalized() return statusPayload{ - FiveHourUtilization: s.fiveHourUtilization, - FiveHourReset: resetToEpoch(s.fiveHourReset), - WeeklyUtilization: s.weeklyUtilization, - WeeklyReset: resetToEpoch(s.weeklyReset), - PlanWeight: s.totalWeight, - UnifiedStatus: string(unified.Status), - UnifiedReset: resetToEpoch(unified.ResetAt), - RepresentativeClaim: unified.RepresentativeClaim, - FallbackAvailable: unified.FallbackAvailable, - OverageStatus: unified.OverageStatus, - OverageReset: resetToEpoch(unified.OverageResetAt), - OverageDisabledReason: unified.OverageDisabledReason, - Availability: s.availability.toPayload(), + FiveHourUtilization: s.fiveHourUtilization, + FiveHourReset: resetToEpoch(s.fiveHourReset), + WeeklyUtilization: s.weeklyUtilization, + WeeklyReset: resetToEpoch(s.weeklyReset), + PlanWeight: s.totalWeight, } } type aggregateInput struct { availability availabilityStatus - unified unifiedRateLimitInfo } func aggregateAvailability(inputs []aggregateInput) availabilityStatus { @@ -133,7 +114,9 @@ func aggregateAvailability(inputs []aggregateInput) availabilityStatus { } } -func chooseRepresentativeClaim(status unifiedRateLimitStatus, fiveHourUtilization float64, fiveHourReset time.Time, weeklyUtilization float64, weeklyReset time.Time, now time.Time) string { +func chooseRepresentativeClaim(fiveHourUtilization float64, fiveHourReset time.Time, weeklyUtilization float64, weeklyReset time.Time, now time.Time) string { + fiveHourWarning := claudeFiveHourWarning(fiveHourUtilization, fiveHourReset, now) + weeklyWarning := claudeWeeklyWarning(weeklyUtilization, weeklyReset, now) type claimCandidate struct { name string priority int @@ -142,15 +125,15 @@ func chooseRepresentativeClaim(status unifiedRateLimitStatus, fiveHourUtilizatio candidateFor := func(name string, utilization float64, warning bool) claimCandidate { priority := 0 switch { - case status == unifiedRateLimitStatusRejected && utilization >= 100: + case utilization >= 100: priority = 2 case warning: priority = 1 } return claimCandidate{name: name, priority: priority, utilization: utilization} } - five := candidateFor("5h", fiveHourUtilization, claudeFiveHourWarning(fiveHourUtilization, fiveHourReset, now)) - weekly := candidateFor("7d", weeklyUtilization, claudeWeeklyWarning(weeklyUtilization, weeklyReset, now)) + five := candidateFor("5h", fiveHourUtilization, fiveHourWarning) + weekly := candidateFor("7d", weeklyUtilization, weeklyWarning) switch { case five.priority > weekly.priority: return five.name @@ -169,53 +152,6 @@ func chooseRepresentativeClaim(status unifiedRateLimitStatus, fiveHourUtilizatio } } -func aggregateUnifiedRateLimit(inputs []aggregateInput, fiveHourUtilization float64, fiveHourReset time.Time, weeklyUtilization float64, weeklyReset time.Time, availability availabilityStatus) unifiedRateLimitInfo { - now := time.Now() - info := unifiedRateLimitInfo{} - usableCount := 0 - for _, input := range inputs { - if input.availability.State == availabilityStateUsable { - usableCount++ - } - if input.unified.OverageStatus != "" && info.OverageStatus == "" { - info.OverageStatus = input.unified.OverageStatus - info.OverageResetAt = input.unified.OverageResetAt - info.OverageDisabledReason = input.unified.OverageDisabledReason - } - if input.unified.Status == unifiedRateLimitStatusRejected { - info.Status = unifiedRateLimitStatusRejected - if !input.unified.ResetAt.IsZero() && (info.ResetAt.IsZero() || input.unified.ResetAt.Before(info.ResetAt)) { - info.ResetAt = input.unified.ResetAt - info.RepresentativeClaim = input.unified.RepresentativeClaim - } - } - } - if info.Status == "" { - switch { - case availability.State == availabilityStateRateLimited || fiveHourUtilization >= 100 || weeklyUtilization >= 100: - info.Status = unifiedRateLimitStatusRejected - info.ResetAt = availability.ResetAt - case claudeFiveHourWarning(fiveHourUtilization, fiveHourReset, now) || claudeWeeklyWarning(weeklyUtilization, weeklyReset, now): - info.Status = unifiedRateLimitStatusAllowedWarning - default: - info.Status = unifiedRateLimitStatusAllowed - } - } - info.FallbackAvailable = usableCount > 0 && len(inputs) > 1 - if info.RepresentativeClaim == "" { - info.RepresentativeClaim = chooseRepresentativeClaim(info.Status, fiveHourUtilization, fiveHourReset, weeklyUtilization, weeklyReset, now) - } - if info.ResetAt.IsZero() { - switch info.RepresentativeClaim { - case "7d": - info.ResetAt = weeklyReset - default: - info.ResetAt = fiveHourReset - } - } - return info.normalized() -} - func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") @@ -350,7 +286,6 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user } visibleInputs = append(visibleInputs, aggregateInput{ availability: credential.availabilityStatus(), - unified: credential.unifiedRateLimitState(), }) if !credential.hasSnapshotData() { continue @@ -393,7 +328,6 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user result.fiveHourUtilization = 100 result.weeklyUtilization = 100 } - result.unifiedRateLimit = aggregateUnifiedRateLimit(visibleInputs, result.fiveHourUtilization, result.fiveHourReset, result.weeklyUtilization, result.weeklyReset, availability) return result } result := aggregatedStatus{ @@ -410,66 +344,55 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour))) } - result.unifiedRateLimit = aggregateUnifiedRateLimit(visibleInputs, result.fiveHourUtilization, result.fiveHourReset, result.weeklyUtilization, result.weeklyReset, availability) return result } func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) { + for key := range headers { + if strings.HasPrefix(strings.ToLower(key), "anthropic-ratelimit-unified-") { + headers.Del(key) + } + } status := s.computeAggregatedUtilization(provider, userConfig) + now := time.Now() headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(status.fiveHourUtilization/100, 'f', 6, 64)) headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(status.weeklyUtilization/100, 'f', 6, 64)) if !status.fiveHourReset.IsZero() { headers.Set("anthropic-ratelimit-unified-5h-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) - } else { - headers.Del("anthropic-ratelimit-unified-5h-reset") } if !status.weeklyReset.IsZero() { headers.Set("anthropic-ratelimit-unified-7d-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10)) - } else { - headers.Del("anthropic-ratelimit-unified-7d-reset") } if status.totalWeight > 0 { headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) } - headers.Set("anthropic-ratelimit-unified-status", string(status.unifiedRateLimit.normalized().Status)) - if !status.unifiedRateLimit.ResetAt.IsZero() { - headers.Set("anthropic-ratelimit-unified-reset", strconv.FormatInt(status.unifiedRateLimit.ResetAt.Unix(), 10)) - } else { - headers.Del("anthropic-ratelimit-unified-reset") - } - if status.unifiedRateLimit.RepresentativeClaim != "" { - headers.Set("anthropic-ratelimit-unified-representative-claim", status.unifiedRateLimit.RepresentativeClaim) - } else { - headers.Del("anthropic-ratelimit-unified-representative-claim") - } - if status.unifiedRateLimit.FallbackAvailable { - headers.Set("anthropic-ratelimit-unified-fallback", "available") - } else { - headers.Del("anthropic-ratelimit-unified-fallback") - } - if status.unifiedRateLimit.OverageStatus != "" { - headers.Set("anthropic-ratelimit-unified-overage-status", status.unifiedRateLimit.OverageStatus) - } else { - headers.Del("anthropic-ratelimit-unified-overage-status") - } - if !status.unifiedRateLimit.OverageResetAt.IsZero() { - headers.Set("anthropic-ratelimit-unified-overage-reset", strconv.FormatInt(status.unifiedRateLimit.OverageResetAt.Unix(), 10)) - } else { - headers.Del("anthropic-ratelimit-unified-overage-reset") - } - if status.unifiedRateLimit.OverageDisabledReason != "" { - headers.Set("anthropic-ratelimit-unified-overage-disabled-reason", status.unifiedRateLimit.OverageDisabledReason) - } else { - headers.Del("anthropic-ratelimit-unified-overage-disabled-reason") + fiveHourWarning := claudeFiveHourWarning(status.fiveHourUtilization, status.fiveHourReset, now) + weeklyWarning := claudeWeeklyWarning(status.weeklyUtilization, status.weeklyReset, now) + switch { + case status.fiveHourUtilization >= 100 || status.weeklyUtilization >= 100 || + status.availability.State == availabilityStateRateLimited: + headers.Set("anthropic-ratelimit-unified-status", "rejected") + case fiveHourWarning || weeklyWarning: + headers.Set("anthropic-ratelimit-unified-status", "allowed_warning") + default: + headers.Set("anthropic-ratelimit-unified-status", "allowed") + } + claim := chooseRepresentativeClaim(status.fiveHourUtilization, status.fiveHourReset, status.weeklyUtilization, status.weeklyReset, now) + headers.Set("anthropic-ratelimit-unified-representative-claim", claim) + switch claim { + case "7d": + if !status.weeklyReset.IsZero() { + headers.Set("anthropic-ratelimit-unified-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10)) + } + default: + if !status.fiveHourReset.IsZero() { + headers.Set("anthropic-ratelimit-unified-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) + } } - if claudeFiveHourWarning(status.fiveHourUtilization, status.fiveHourReset, time.Now()) || status.fiveHourUtilization >= 100 { + if fiveHourWarning || status.fiveHourUtilization >= 100 { headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true") - } else { - headers.Del("anthropic-ratelimit-unified-5h-surpassed-threshold") } - if claudeWeeklyWarning(status.weeklyUtilization, status.weeklyReset, time.Now()) || status.weeklyUtilization >= 100 { + if weeklyWarning || status.weeklyUtilization >= 100 { headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "true") - } else { - headers.Del("anthropic-ratelimit-unified-7d-surpassed-threshold") } } diff --git a/service/ccm/service_status_test.go b/service/ccm/service_status_test.go index 9aef16de39..c2dbea1a21 100644 --- a/service/ccm/service_status_test.go +++ b/service/ccm/service_status_test.go @@ -24,28 +24,26 @@ type testCredential struct { fiveReset time.Time weeklyReset time.Time availability availabilityStatus - unified unifiedRateLimitInfo } -func (c *testCredential) tagName() string { return c.tag } -func (c *testCredential) isAvailable() bool { return c.available } -func (c *testCredential) isUsable() bool { return c.usable } -func (c *testCredential) isExternal() bool { return c.external } -func (c *testCredential) hasSnapshotData() bool { return c.hasData } -func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour } -func (c *testCredential) weeklyUtilization() float64 { return c.weekly } -func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV } -func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV } -func (c *testCredential) planWeight() float64 { return c.weight } -func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset } -func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset } -func (c *testCredential) markRateLimited(time.Time) {} -func (c *testCredential) markUpstreamRejected() {} -func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability } -func (c *testCredential) unifiedRateLimitState() unifiedRateLimitInfo { return c.unified } -func (c *testCredential) earliestReset() time.Time { return c.fiveReset } -func (c *testCredential) unavailableError() error { return nil } -func (c *testCredential) getAccessToken() (string, error) { return "", nil } +func (c *testCredential) tagName() string { return c.tag } +func (c *testCredential) isAvailable() bool { return c.available } +func (c *testCredential) isUsable() bool { return c.usable } +func (c *testCredential) isExternal() bool { return c.external } +func (c *testCredential) hasSnapshotData() bool { return c.hasData } +func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour } +func (c *testCredential) weeklyUtilization() float64 { return c.weekly } +func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV } +func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV } +func (c *testCredential) planWeight() float64 { return c.weight } +func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset } +func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset } +func (c *testCredential) markRateLimited(time.Time) {} +func (c *testCredential) markUpstreamRejected() {} +func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability } +func (c *testCredential) earliestReset() time.Time { return c.fiveReset } +func (c *testCredential) unavailableError() error { return nil } +func (c *testCredential) getAccessToken() (string, error) { return "", nil } func (c *testCredential) buildProxyRequest(context.Context, *http.Request, []byte, http.Header) (*http.Request, error) { return nil, nil } @@ -98,22 +96,18 @@ func TestComputeAggregatedUtilizationPreservesSnapshotForRateLimitedCredential(t fiveReset: reset, weeklyReset: reset.Add(2 * time.Hour), availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: reset}, - unified: unifiedRateLimitInfo{Status: unifiedRateLimitStatusRejected, ResetAt: reset, RepresentativeClaim: "5h"}, }, }}, nil) if status.fiveHourUtilization != 42 || status.weeklyUtilization != 18 { t.Fatalf("expected preserved utilization, got 5h=%v weekly=%v", status.fiveHourUtilization, status.weeklyUtilization) } - if status.unifiedRateLimit.Status != unifiedRateLimitStatusRejected { - t.Fatalf("expected rejected unified status, got %q", status.unifiedRateLimit.Status) - } if status.availability.State != availabilityStateRateLimited { t.Fatalf("expected rate-limited availability, got %#v", status.availability) } } -func TestRewriteResponseHeadersIncludesUnifiedHeaders(t *testing.T) { +func TestRewriteResponseHeadersComputesUnifiedStatus(t *testing.T) { t.Parallel() reset := time.Now().Add(80 * time.Minute) @@ -147,6 +141,73 @@ func TestRewriteResponseHeadersIncludesUnifiedHeaders(t *testing.T) { } } +func TestRewriteResponseHeadersStripsUpstreamHeaders(t *testing.T) { + t.Parallel() + + service := &Service{} + headers := make(http.Header) + headers.Set("anthropic-ratelimit-unified-overage-status", "rejected") + headers.Set("anthropic-ratelimit-unified-overage-disabled-reason", "org_level_disabled") + headers.Set("anthropic-ratelimit-unified-fallback", "available") + service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{ + &testCredential{ + tag: "a", + available: true, + usable: true, + hasData: true, + fiveHour: 10, + weekly: 5, + fiveHourCapV: 100, + weeklyCapV: 100, + weight: 1, + fiveReset: time.Now().Add(3 * time.Hour), + weeklyReset: time.Now().Add(5 * 24 * time.Hour), + availability: availabilityStatus{State: availabilityStateUsable}, + }, + }}, nil) + + if headers.Get("anthropic-ratelimit-unified-overage-status") != "" { + t.Fatalf("expected overage-status stripped, got %q", headers.Get("anthropic-ratelimit-unified-overage-status")) + } + if headers.Get("anthropic-ratelimit-unified-overage-disabled-reason") != "" { + t.Fatalf("expected overage-disabled-reason stripped, got %q", headers.Get("anthropic-ratelimit-unified-overage-disabled-reason")) + } + if headers.Get("anthropic-ratelimit-unified-fallback") != "" { + t.Fatalf("expected fallback stripped, got %q", headers.Get("anthropic-ratelimit-unified-fallback")) + } + if headers.Get("anthropic-ratelimit-unified-status") != "allowed" { + t.Fatalf("expected allowed status, got %q", headers.Get("anthropic-ratelimit-unified-status")) + } +} + +func TestRewriteResponseHeadersRejectedOnHardRateLimit(t *testing.T) { + t.Parallel() + + reset := time.Now().Add(10 * time.Minute) + service := &Service{} + headers := make(http.Header) + service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{ + &testCredential{ + tag: "a", + available: true, + usable: false, + hasData: true, + fiveHour: 50, + weekly: 20, + fiveHourCapV: 100, + weeklyCapV: 100, + weight: 1, + fiveReset: reset, + weeklyReset: time.Now().Add(5 * 24 * time.Hour), + availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: reset}, + }, + }}, nil) + + if headers.Get("anthropic-ratelimit-unified-status") != "rejected" { + t.Fatalf("expected rejected (hard rate limited), got %q", headers.Get("anthropic-ratelimit-unified-status")) + } +} + func TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials(t *testing.T) { t.Parallel() diff --git a/service/ocm/credential.go b/service/ocm/credential.go index 2e2589366f..6070f1a8e9 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -120,8 +120,6 @@ type Credential interface { markUpstreamRejected() markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) availabilityStatus() availabilityStatus - rateLimitSnapshots() []rateLimitSnapshot - activeLimitID() string earliestReset() time.Time unavailableError() error diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index 977a545e64..56ffbba207 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -544,26 +544,6 @@ func (c *defaultCredential) availabilityStatus() availabilityStatus { return c.state.currentAvailability() } -func (c *defaultCredential) rateLimitSnapshots() []rateLimitSnapshot { - c.stateAccess.RLock() - defer c.stateAccess.RUnlock() - if len(c.state.rateLimitSnapshots) == 0 { - return nil - } - snapshots := make([]rateLimitSnapshot, 0, len(c.state.rateLimitSnapshots)) - for _, snapshot := range c.state.rateLimitSnapshots { - snapshots = append(snapshots, cloneRateLimitSnapshot(snapshot)) - } - sortRateLimitSnapshots(snapshots) - return snapshots -} - -func (c *defaultCredential) activeLimitID() string { - c.stateAccess.RLock() - defer c.stateAccess.RUnlock() - return c.state.activeLimitID -} - func (c *defaultCredential) unavailableError() error { c.stateAccess.RLock() defer c.stateAccess.RUnlock() diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 67b9b2a1b5..222a22e94d 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -702,38 +702,16 @@ func (c *externalCredential) pollUsage() { oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 c.state.upstreamRejectedUntil = time.Time{} - if len(statusResponse.Limits) > 0 { - applyRateLimitSnapshotsLocked(&c.state, statusResponse.Limits, statusResponse.ActiveLimit, statusResponse.PlanWeight, c.state.accountType) - } else { - c.state.fiveHourUtilization = statusResponse.FiveHourUtilization - c.state.weeklyUtilization = statusResponse.WeeklyUtilization - if statusResponse.FiveHourReset > 0 { - c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) - } - if statusResponse.WeeklyReset > 0 { - c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) - } - if statusResponse.PlanWeight > 0 { - c.state.remotePlanWeight = statusResponse.PlanWeight - } + c.state.fiveHourUtilization = statusResponse.FiveHourUtilization + c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) } - if statusResponse.Availability != nil { - switch availabilityState(statusResponse.Availability.State) { - case availabilityStateRateLimited: - c.state.hardRateLimited = true - if statusResponse.Availability.ResetAt > 0 { - c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0) - } - case availabilityStateTemporarilyBlocked: - resetAt := time.Time{} - if statusResponse.Availability.ResetAt > 0 { - resetAt = time.Unix(statusResponse.Availability.ResetAt, 0) - } - c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt) - if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() { - c.state.upstreamRejectedUntil = resetAt - } - } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } + if statusResponse.PlanWeight > 0 { + c.state.remotePlanWeight = statusResponse.PlanWeight } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false @@ -833,38 +811,16 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr oldWeekly := c.state.weeklyUtilization c.state.consecutivePollFailures = 0 c.state.upstreamRejectedUntil = time.Time{} - if len(statusResponse.Limits) > 0 { - applyRateLimitSnapshotsLocked(&c.state, statusResponse.Limits, statusResponse.ActiveLimit, statusResponse.PlanWeight, c.state.accountType) - } else { - c.state.fiveHourUtilization = statusResponse.FiveHourUtilization - c.state.weeklyUtilization = statusResponse.WeeklyUtilization - if statusResponse.FiveHourReset > 0 { - c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) - } - if statusResponse.WeeklyReset > 0 { - c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) - } - if statusResponse.PlanWeight > 0 { - c.state.remotePlanWeight = statusResponse.PlanWeight - } + c.state.fiveHourUtilization = statusResponse.FiveHourUtilization + c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) } - if statusResponse.Availability != nil { - switch availabilityState(statusResponse.Availability.State) { - case availabilityStateRateLimited: - c.state.hardRateLimited = true - if statusResponse.Availability.ResetAt > 0 { - c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0) - } - case availabilityStateTemporarilyBlocked: - resetAt := time.Time{} - if statusResponse.Availability.ResetAt > 0 { - resetAt = time.Unix(statusResponse.Availability.ResetAt, 0) - } - c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt) - if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() { - c.state.upstreamRejectedUntil = resetAt - } - } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } + if statusResponse.PlanWeight > 0 { + c.state.remotePlanWeight = statusResponse.PlanWeight } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false @@ -949,25 +905,6 @@ func (c *externalCredential) availabilityStatus() availabilityStatus { return c.state.currentAvailability() } -func (c *externalCredential) rateLimitSnapshots() []rateLimitSnapshot { - c.stateAccess.RLock() - defer c.stateAccess.RUnlock() - if len(c.state.rateLimitSnapshots) == 0 { - return nil - } - snapshots := make([]rateLimitSnapshot, 0, len(c.state.rateLimitSnapshots)) - for _, snapshot := range c.state.rateLimitSnapshots { - snapshots = append(snapshots, cloneRateLimitSnapshot(snapshot)) - } - sortRateLimitSnapshots(snapshots) - return snapshots -} - -func (c *externalCredential) activeLimitID() string { - c.stateAccess.RLock() - defer c.stateAccess.RUnlock() - return c.state.activeLimitID -} func (c *externalCredential) markUsageStreamUpdated() { c.stateAccess.Lock() diff --git a/service/ocm/rate_limit_state.go b/service/ocm/rate_limit_state.go index f0e4f34b1b..82a01f5a73 100644 --- a/service/ocm/rate_limit_state.go +++ b/service/ocm/rate_limit_state.go @@ -35,12 +35,6 @@ type availabilityStatus struct { ResetAt time.Time } -type availabilityPayload struct { - State string `json:"state"` - Reason string `json:"reason,omitempty"` - ResetAt int64 `json:"reset_at,omitempty"` -} - func (s availabilityStatus) normalized() availabilityStatus { if s.State == "" { s.State = availabilityStateUnknown @@ -51,20 +45,6 @@ func (s availabilityStatus) normalized() availabilityStatus { return s } -func (s availabilityStatus) toPayload() *availabilityPayload { - s = s.normalized() - payload := &availabilityPayload{ - State: string(s.State), - } - if s.Reason != "" && s.Reason != availabilityReasonUnknown { - payload.Reason = string(s.Reason) - } - if !s.ResetAt.IsZero() { - payload.ResetAt = s.ResetAt.Unix() - } - return payload -} - type creditsSnapshot struct { HasCredits bool `json:"has_credits"` Unlimited bool `json:"unlimited"` diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index 92959e5e84..c6897b92f1 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -5,7 +5,6 @@ import ( "encoding/json" "net/http" "reflect" - "slices" "strconv" "strings" "time" @@ -14,14 +13,11 @@ import ( ) type statusPayload struct { - FiveHourUtilization float64 `json:"five_hour_utilization"` - FiveHourReset int64 `json:"five_hour_reset"` - WeeklyUtilization float64 `json:"weekly_utilization"` - WeeklyReset int64 `json:"weekly_reset"` - PlanWeight float64 `json:"plan_weight"` - ActiveLimit string `json:"active_limit,omitempty"` - Limits []rateLimitSnapshot `json:"limits,omitempty"` - Availability *availabilityPayload `json:"availability,omitempty"` + FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` + WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` + PlanWeight float64 `json:"plan_weight"` } type aggregatedStatus struct { @@ -30,8 +26,6 @@ type aggregatedStatus struct { totalWeight float64 fiveHourReset time.Time weeklyReset time.Time - activeLimitID string - limits []rateLimitSnapshot availability availabilityStatus } @@ -53,24 +47,13 @@ func (s aggregatedStatus) toPayload() statusPayload { WeeklyUtilization: s.weeklyUtilization, WeeklyReset: resetToEpoch(s.weeklyReset), PlanWeight: s.totalWeight, - ActiveLimit: s.activeLimitID, - Limits: slices.Clone(s.limits), - Availability: s.availability.toPayload(), } } type aggregateInput struct { - weight float64 - snapshots []rateLimitSnapshot - activeLimit string availability availabilityStatus } -type snapshotContribution struct { - weight float64 - snapshot rateLimitSnapshot -} - func aggregateAvailability(inputs []aggregateInput) availabilityStatus { if len(inputs) == 0 { return availabilityStatus{ @@ -139,167 +122,6 @@ func aggregateAvailability(inputs []aggregateInput) availabilityStatus { } } -func aggregateRateLimitWindow(contributions []snapshotContribution, selector func(rateLimitSnapshot) *rateLimitWindow) *rateLimitWindow { - var totalWeight float64 - var totalRemaining float64 - var totalWindowMinutes float64 - var totalResetHours float64 - var resetWeight float64 - now := time.Now() - for _, contribution := range contributions { - window := selector(contribution.snapshot) - if window == nil { - continue - } - totalWeight += contribution.weight - totalRemaining += (100 - window.UsedPercent) * contribution.weight - if window.WindowMinutes > 0 { - totalWindowMinutes += float64(window.WindowMinutes) * contribution.weight - } - if window.ResetAt > 0 { - resetTime := time.Unix(window.ResetAt, 0) - hours := resetTime.Sub(now).Hours() - if hours > 0 { - totalResetHours += hours * contribution.weight - resetWeight += contribution.weight - } - } - } - if totalWeight == 0 { - return nil - } - window := &rateLimitWindow{ - UsedPercent: 100 - totalRemaining/totalWeight, - } - if totalWindowMinutes > 0 { - window.WindowMinutes = int64(totalWindowMinutes / totalWeight) - } - if resetWeight > 0 { - window.ResetAt = now.Add(time.Duration(totalResetHours / resetWeight * float64(time.Hour))).Unix() - } - return window -} - -func aggregateCredits(contributions []snapshotContribution) *creditsSnapshot { - var hasCredits bool - var unlimited bool - var balanceTotal float64 - var hasBalance bool - for _, contribution := range contributions { - if contribution.snapshot.Credits == nil { - continue - } - hasCredits = hasCredits || contribution.snapshot.Credits.HasCredits - unlimited = unlimited || contribution.snapshot.Credits.Unlimited - if balance := strings.TrimSpace(contribution.snapshot.Credits.Balance); balance != "" { - value, err := strconv.ParseFloat(balance, 64) - if err == nil { - balanceTotal += value - hasBalance = true - } - } - } - if !hasCredits && !unlimited && !hasBalance { - return nil - } - credits := &creditsSnapshot{ - HasCredits: hasCredits, - Unlimited: unlimited, - } - if hasBalance && !unlimited { - credits.Balance = strconv.FormatFloat(balanceTotal, 'f', -1, 64) - } - return credits -} - -func aggregateSnapshots(inputs []aggregateInput) []rateLimitSnapshot { - grouped := make(map[string][]snapshotContribution) - for _, input := range inputs { - for _, snapshot := range input.snapshots { - limitID := snapshot.LimitID - if limitID == "" { - limitID = "codex" - } - grouped[limitID] = append(grouped[limitID], snapshotContribution{ - weight: input.weight, - snapshot: snapshot, - }) - } - } - if len(grouped) == 0 { - return nil - } - aggregated := make([]rateLimitSnapshot, 0, len(grouped)) - for limitID, contributions := range grouped { - snapshot := defaultRateLimitSnapshot(limitID) - var bestPlanWeight float64 - for _, contribution := range contributions { - if contribution.snapshot.LimitName != "" && snapshot.LimitName == "" { - snapshot.LimitName = contribution.snapshot.LimitName - } - if contribution.snapshot.PlanType != "" && contribution.weight >= bestPlanWeight { - bestPlanWeight = contribution.weight - snapshot.PlanType = contribution.snapshot.PlanType - } - } - snapshot.Primary = aggregateRateLimitWindow(contributions, func(snapshot rateLimitSnapshot) *rateLimitWindow { - return snapshot.Primary - }) - snapshot.Secondary = aggregateRateLimitWindow(contributions, func(snapshot rateLimitSnapshot) *rateLimitWindow { - return snapshot.Secondary - }) - snapshot.Credits = aggregateCredits(contributions) - if snapshot.Primary == nil && snapshot.Secondary == nil && snapshot.Credits == nil { - continue - } - aggregated = append(aggregated, snapshot) - } - sortRateLimitSnapshots(aggregated) - return aggregated -} - -func selectActiveLimitID(inputs []aggregateInput, snapshots []rateLimitSnapshot) string { - if len(snapshots) == 0 { - return "" - } - weights := make(map[string]float64) - for _, input := range inputs { - if input.activeLimit == "" { - continue - } - weights[normalizeStoredLimitID(input.activeLimit)] += input.weight - } - var ( - bestID string - bestWeight float64 - ) - for limitID, weight := range weights { - if weight > bestWeight { - bestID = limitID - bestWeight = weight - } - } - if bestID != "" { - return bestID - } - for _, snapshot := range snapshots { - if snapshot.LimitID == "codex" { - return "codex" - } - } - return snapshots[0].LimitID -} - -func findSnapshotByLimitID(snapshots []rateLimitSnapshot, limitID string) *rateLimitSnapshot { - for _, snapshot := range snapshots { - if snapshot.LimitID == limitID { - snapshotCopy := snapshot - return &snapshotCopy - } - } - return nil -} - func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") @@ -420,7 +242,10 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) aggregatedStatus { inputs := make([]aggregateInput, 0, len(provider.allCredentials())) - var totalWeight float64 + var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 + now := time.Now() + var totalWeightedHoursUntil5hReset, total5hResetWeight float64 + var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64 var hasSnapshotData bool for _, credential := range provider.allCredentials() { if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential { @@ -429,61 +254,70 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user if userConfig != nil && !userConfig.AllowExternalUsage && credential.isExternal() { continue } - input := aggregateInput{ - weight: credential.planWeight(), - snapshots: credential.rateLimitSnapshots(), - activeLimit: credential.activeLimitID(), + inputs = append(inputs, aggregateInput{ availability: credential.availabilityStatus(), + }) + if !credential.hasSnapshotData() { + continue } - inputs = append(inputs, input) - if credential.hasSnapshotData() { - hasSnapshotData = true + hasSnapshotData = true + weight := credential.planWeight() + remaining5h := credential.fiveHourCap() - credential.fiveHourUtilization() + if remaining5h < 0 { + remaining5h = 0 } - totalWeight += input.weight - } - limits := aggregateSnapshots(inputs) - result := aggregatedStatus{ - totalWeight: totalWeight, - availability: aggregateAvailability(inputs), - limits: limits, - activeLimitID: selectActiveLimitID(inputs, limits), - } - if legacy := findSnapshotByLimitID(result.limits, "codex"); legacy != nil { - if legacy.Primary != nil { - result.fiveHourUtilization = legacy.Primary.UsedPercent - if legacy.Primary.ResetAt > 0 { - result.fiveHourReset = time.Unix(legacy.Primary.ResetAt, 0) - } + remainingWeekly := credential.weeklyCap() - credential.weeklyUtilization() + if remainingWeekly < 0 { + remainingWeekly = 0 } - if legacy.Secondary != nil { - result.weeklyUtilization = legacy.Secondary.UsedPercent - if legacy.Secondary.ResetAt > 0 { - result.weeklyReset = time.Unix(legacy.Secondary.ResetAt, 0) + totalWeightedRemaining5h += remaining5h * weight + totalWeightedRemainingWeekly += remainingWeekly * weight + totalWeight += weight + + fiveHourReset := credential.fiveHourResetTime() + if !fiveHourReset.IsZero() { + hours := fiveHourReset.Sub(now).Hours() + if hours > 0 { + totalWeightedHoursUntil5hReset += hours * weight + total5hResetWeight += weight } } - } else if legacy := findSnapshotByLimitID(result.limits, result.activeLimitID); legacy != nil { - if legacy.Primary != nil { - result.fiveHourUtilization = legacy.Primary.UsedPercent - if legacy.Primary.ResetAt > 0 { - result.fiveHourReset = time.Unix(legacy.Primary.ResetAt, 0) + weeklyReset := credential.weeklyResetTime() + if !weeklyReset.IsZero() { + hours := weeklyReset.Sub(now).Hours() + if hours > 0 { + totalWeightedHoursUntilWeeklyReset += hours * weight + totalWeeklyResetWeight += weight } } - if legacy.Secondary != nil { - result.weeklyUtilization = legacy.Secondary.UsedPercent - if legacy.Secondary.ResetAt > 0 { - result.weeklyReset = time.Unix(legacy.Secondary.ResetAt, 0) - } + } + availability := aggregateAvailability(inputs) + if totalWeight == 0 { + result := aggregatedStatus{availability: availability} + if !hasSnapshotData { + result.fiveHourUtilization = 100 + result.weeklyUtilization = 100 } + return result } - if len(result.limits) == 0 && !hasSnapshotData { - result.fiveHourUtilization = 100 - result.weeklyUtilization = 100 + result := aggregatedStatus{ + fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight, + weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight: totalWeight, + availability: availability, + } + if total5hResetWeight > 0 { + avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight + result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour))) + } + if totalWeeklyResetWeight > 0 { + avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight + result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour))) } return result } func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) { - status := s.computeAggregatedUtilization(provider, userConfig) for key := range headers { lowerKey := strings.ToLower(key) if lowerKey == "x-codex-active-limit" || @@ -498,51 +332,16 @@ func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentia headers.Del(key) } } - headers.Set("x-codex-active-limit", headerLimitID(status.activeLimitID)) + status := s.computeAggregatedUtilization(provider, userConfig) headers.Set("x-codex-primary-used-percent", strconv.FormatFloat(status.fiveHourUtilization, 'f', 2, 64)) headers.Set("x-codex-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64)) if !status.fiveHourReset.IsZero() { headers.Set("x-codex-primary-reset-at", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) - } else { - headers.Del("x-codex-primary-reset-at") } if !status.weeklyReset.IsZero() { headers.Set("x-codex-secondary-reset-at", strconv.FormatInt(status.weeklyReset.Unix(), 10)) - } else { - headers.Del("x-codex-secondary-reset-at") } if status.totalWeight > 0 { headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) } - for _, snapshot := range status.limits { - prefix := "x-" + headerLimitID(snapshot.LimitID) - if snapshot.Primary != nil { - headers.Set(prefix+"-primary-used-percent", strconv.FormatFloat(snapshot.Primary.UsedPercent, 'f', 2, 64)) - if snapshot.Primary.WindowMinutes > 0 { - headers.Set(prefix+"-primary-window-minutes", strconv.FormatInt(snapshot.Primary.WindowMinutes, 10)) - } - if snapshot.Primary.ResetAt > 0 { - headers.Set(prefix+"-primary-reset-at", strconv.FormatInt(snapshot.Primary.ResetAt, 10)) - } - } - if snapshot.Secondary != nil { - headers.Set(prefix+"-secondary-used-percent", strconv.FormatFloat(snapshot.Secondary.UsedPercent, 'f', 2, 64)) - if snapshot.Secondary.WindowMinutes > 0 { - headers.Set(prefix+"-secondary-window-minutes", strconv.FormatInt(snapshot.Secondary.WindowMinutes, 10)) - } - if snapshot.Secondary.ResetAt > 0 { - headers.Set(prefix+"-secondary-reset-at", strconv.FormatInt(snapshot.Secondary.ResetAt, 10)) - } - } - if snapshot.LimitName != "" { - headers.Set(prefix+"-limit-name", snapshot.LimitName) - } - if snapshot.LimitID == "codex" && snapshot.Credits != nil { - headers.Set("x-codex-credits-has-credits", strconv.FormatBool(snapshot.Credits.HasCredits)) - headers.Set("x-codex-credits-unlimited", strconv.FormatBool(snapshot.Credits.Unlimited)) - if snapshot.Credits.Balance != "" { - headers.Set("x-codex-credits-balance", snapshot.Credits.Balance) - } - } - } } diff --git a/service/ocm/service_status_test.go b/service/ocm/service_status_test.go index ba7e9324ad..c3187d5394 100644 --- a/service/ocm/service_status_test.go +++ b/service/ocm/service_status_test.go @@ -26,8 +26,6 @@ type testCredential struct { fiveReset time.Time weeklyReset time.Time availability availabilityStatus - activeLimit string - snapshots []rateLimitSnapshot } func (c *testCredential) tagName() string { return c.tag } @@ -48,10 +46,6 @@ func (c *testCredential) markTemporarilyBlocked(reason availabilityReason, reset c.availability = availabilityStatus{State: availabilityStateTemporarilyBlocked, Reason: reason, ResetAt: resetAt} } func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability } -func (c *testCredential) rateLimitSnapshots() []rateLimitSnapshot { - return slicesCloneSnapshots(c.snapshots) -} -func (c *testCredential) activeLimitID() string { return c.activeLimit } func (c *testCredential) earliestReset() time.Time { return c.fiveReset } func (c *testCredential) unavailableError() error { return nil } func (c *testCredential) getAccessToken() (string, error) { return "", nil } @@ -75,17 +69,6 @@ func (c *testCredential) ocmIsAPIKeyMode() bool func (c *testCredential) ocmGetAccountID() string { return "" } func (c *testCredential) ocmGetBaseURL() string { return "" } -func slicesCloneSnapshots(snapshots []rateLimitSnapshot) []rateLimitSnapshot { - if len(snapshots) == 0 { - return nil - } - cloned := make([]rateLimitSnapshot, 0, len(snapshots)) - for _, snapshot := range snapshots { - cloned = append(cloned, cloneRateLimitSnapshot(snapshot)) - } - return cloned -} - type testProvider struct { credentials []Credential } @@ -104,78 +87,6 @@ func (p *testProvider) pollCredentialIfStale(Credential) {} func (p *testProvider) allCredentials() []Credential { return p.credentials } func (p *testProvider) close() {} -func TestComputeAggregatedUtilizationPreservesStoredSnapshots(t *testing.T) { - t.Parallel() - - service := &Service{} - status := service.computeAggregatedUtilization(&testProvider{credentials: []Credential{ - &testCredential{ - tag: "a", - available: true, - usable: false, - hasData: true, - weight: 1, - activeLimit: "codex", - availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)}, - snapshots: []rateLimitSnapshot{ - { - LimitID: "codex", - Primary: &rateLimitWindow{UsedPercent: 44, WindowMinutes: 300, ResetAt: time.Now().Add(time.Hour).Unix()}, - Secondary: &rateLimitWindow{UsedPercent: 12, WindowMinutes: 10080, ResetAt: time.Now().Add(24 * time.Hour).Unix()}, - }, - }, - }, - }}, nil) - - if status.fiveHourUtilization != 44 || status.weeklyUtilization != 12 { - t.Fatalf("expected stored snapshot utilization, got 5h=%v weekly=%v", status.fiveHourUtilization, status.weeklyUtilization) - } - if status.availability.State != availabilityStateRateLimited { - t.Fatalf("expected rate-limited availability, got %#v", status.availability) - } -} - -func TestRewriteResponseHeadersIncludesAdditionalLimitFamiliesAndCredits(t *testing.T) { - t.Parallel() - - service := &Service{} - headers := make(http.Header) - service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{ - &testCredential{ - tag: "a", - available: true, - usable: true, - hasData: true, - weight: 1, - activeLimit: "codex_other", - availability: availabilityStatus{State: availabilityStateUsable}, - snapshots: []rateLimitSnapshot{ - { - LimitID: "codex", - Primary: &rateLimitWindow{UsedPercent: 20, WindowMinutes: 300, ResetAt: time.Now().Add(time.Hour).Unix()}, - Secondary: &rateLimitWindow{UsedPercent: 40, WindowMinutes: 10080, ResetAt: time.Now().Add(24 * time.Hour).Unix()}, - Credits: &creditsSnapshot{HasCredits: true, Unlimited: false, Balance: "12"}, - }, - { - LimitID: "codex_other", - LimitName: "codex-other", - Primary: &rateLimitWindow{UsedPercent: 60, WindowMinutes: 60, ResetAt: time.Now().Add(30 * time.Minute).Unix()}, - }, - }, - }, - }}, nil) - - if headers.Get("x-codex-active-limit") != "codex-other" { - t.Fatalf("expected active limit header, got %q", headers.Get("x-codex-active-limit")) - } - if headers.Get("x-codex-other-primary-used-percent") == "" { - t.Fatal("expected additional rate-limit family header") - } - if headers.Get("x-codex-credits-balance") != "12" { - t.Fatalf("expected credits balance header, got %q", headers.Get("x-codex-credits-balance")) - } -} - func TestHandleWebSocketErrorEventConnectionLimitDoesNotUseRateLimitPath(t *testing.T) { t.Parallel() @@ -201,7 +112,6 @@ func TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials(t *t hasData: true, weight: 1, availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)}, - snapshots: []rateLimitSnapshot{{LimitID: "codex", Primary: &rateLimitWindow{UsedPercent: 80}}}, }, }} diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 4c552cbf5b..6691a49af6 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -333,7 +333,7 @@ func (s *Service) handleWebSocket( var firstRealRequestOnce sync.Once var waitGroup sync.WaitGroup - waitGroup.Add(3) + waitGroup.Add(2) go func() { defer waitGroup.Done() defer session.Close() @@ -344,11 +344,6 @@ func (s *Service) handleWebSocket( defer session.Close() s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, modelChannel, username, weeklyCycleHint) }() - go func() { - defer waitGroup.Done() - defer session.Close() - s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, session.closed, firstRealRequest, provider, userConfig) - }() waitGroup.Wait() } @@ -552,171 +547,6 @@ func (s *Service) handleWebSocketErrorEvent(data []byte, selectedCredential Cred selectedCredential.markRateLimited(resetAt) } -func writeWebSocketAggregatedStatus(clientConn net.Conn, clientWriteAccess *sync.Mutex, status aggregatedStatus) error { - clientWriteAccess.Lock() - defer clientWriteAccess.Unlock() - for _, data := range buildSyntheticRateLimitsEvents(status) { - if err := wsutil.WriteServerMessage(clientConn, ws.OpText, data); err != nil { - return err - } - } - return nil -} - -func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, sessionClosed <-chan struct{}, firstRealRequest <-chan struct{}, provider credentialProvider, userConfig *option.OCMUser) { - subscription, done, err := s.statusObserver.Subscribe() - if err != nil { - return - } - defer s.statusObserver.UnSubscribe(subscription) - - var last aggregatedStatus - hasLast := false - - for { - select { - case <-ctx.Done(): - return - case <-done: - return - case <-sessionClosed: - return - case <-firstRealRequest: - current := s.computeAggregatedUtilization(provider, userConfig) - err = writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, current) - if err != nil { - return - } - last = current - hasLast = true - firstRealRequest = nil - case <-subscription: - for { - select { - case <-subscription: - default: - goto drained - } - } - drained: - if !hasLast { - continue - } - current := s.computeAggregatedUtilization(provider, userConfig) - if current.equal(last) { - continue - } - last = current - err = writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, current) - if err != nil { - return - } - } - } -} - -func buildSyntheticRateLimitsEvents(status aggregatedStatus) [][]byte { - type rateLimitWindow struct { - UsedPercent float64 `json:"used_percent"` - WindowMinutes int64 `json:"window_minutes,omitempty"` - ResetAt int64 `json:"reset_at,omitempty"` - } - type creditsEvent struct { - HasCredits bool `json:"has_credits"` - Unlimited bool `json:"unlimited"` - Balance string `json:"balance,omitempty"` - } - type eventPayload struct { - Type string `json:"type"` - RateLimits struct { - Primary *rateLimitWindow `json:"primary,omitempty"` - Secondary *rateLimitWindow `json:"secondary,omitempty"` - } `json:"rate_limits"` - MeteredLimitName string `json:"metered_limit_name,omitempty"` - LimitName string `json:"limit_name,omitempty"` - Credits *creditsEvent `json:"credits,omitempty"` - PlanWeight float64 `json:"plan_weight,omitempty"` - } - buildEvent := func(snapshot rateLimitSnapshot, primary *rateLimitWindow, secondary *rateLimitWindow) []byte { - event := eventPayload{ - Type: "codex.rate_limits", - MeteredLimitName: snapshot.LimitID, - LimitName: snapshot.LimitName, - PlanWeight: status.totalWeight, - } - if event.MeteredLimitName == "" { - event.MeteredLimitName = "codex" - } - if event.LimitName == "" { - event.LimitName = strings.ReplaceAll(event.MeteredLimitName, "_", "-") - } - event.RateLimits.Primary = primary - event.RateLimits.Secondary = secondary - if snapshot.Credits != nil { - event.Credits = &creditsEvent{ - HasCredits: snapshot.Credits.HasCredits, - Unlimited: snapshot.Credits.Unlimited, - Balance: snapshot.Credits.Balance, - } - } - data, _ := json.Marshal(event) - return data - } - defaultPrimary := &rateLimitWindow{ - UsedPercent: status.fiveHourUtilization, - ResetAt: resetToEpoch(status.fiveHourReset), - } - defaultSecondary := &rateLimitWindow{ - UsedPercent: status.weeklyUtilization, - ResetAt: resetToEpoch(status.weeklyReset), - } - events := make([][]byte, 0, 1+len(status.limits)) - if snapshot := findSnapshotByLimitID(status.limits, "codex"); snapshot != nil { - primary := defaultPrimary - if snapshot.Primary != nil { - primary = &rateLimitWindow{ - UsedPercent: snapshot.Primary.UsedPercent, - WindowMinutes: snapshot.Primary.WindowMinutes, - ResetAt: snapshot.Primary.ResetAt, - } - } - secondary := defaultSecondary - if snapshot.Secondary != nil { - secondary = &rateLimitWindow{ - UsedPercent: snapshot.Secondary.UsedPercent, - WindowMinutes: snapshot.Secondary.WindowMinutes, - ResetAt: snapshot.Secondary.ResetAt, - } - } - events = append(events, buildEvent(*snapshot, primary, secondary)) - } else { - events = append(events, buildEvent(rateLimitSnapshot{LimitID: "codex", LimitName: "codex"}, defaultPrimary, defaultSecondary)) - } - for _, snapshot := range status.limits { - if snapshot.LimitID == "codex" { - continue - } - var primary *rateLimitWindow - if snapshot.Primary != nil { - primary = &rateLimitWindow{ - UsedPercent: snapshot.Primary.UsedPercent, - WindowMinutes: snapshot.Primary.WindowMinutes, - ResetAt: snapshot.Primary.ResetAt, - } - } - var secondary *rateLimitWindow - if snapshot.Secondary != nil { - secondary = &rateLimitWindow{ - UsedPercent: snapshot.Secondary.UsedPercent, - WindowMinutes: snapshot.Secondary.WindowMinutes, - ResetAt: snapshot.Secondary.ResetAt, - } - } - events = append(events, buildEvent(snapshot, primary, secondary)) - } - return events -} - func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) { var streamEvent responses.ResponseStreamEventUnion if json.Unmarshal(data, &streamEvent) != nil { From a87a2b0e2b96268538d9528f9935b003dd2db534 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 28 Mar 2026 02:00:27 +0800 Subject: [PATCH 91/96] ocm: log think level --- service/ocm/request_log.go | 62 +++++++++++++++++++++++++ service/ocm/request_log_test.go | 80 ++++++++++++++++++++++++++++++++ service/ocm/service_handler.go | 29 ++---------- service/ocm/service_websocket.go | 20 ++------ 4 files changed, 151 insertions(+), 40 deletions(-) create mode 100644 service/ocm/request_log.go create mode 100644 service/ocm/request_log_test.go diff --git a/service/ocm/request_log.go b/service/ocm/request_log.go new file mode 100644 index 0000000000..2459381acc --- /dev/null +++ b/service/ocm/request_log.go @@ -0,0 +1,62 @@ +package ocm + +import "encoding/json" + +type requestLogMetadata struct { + Model string + ServiceTier string + ReasoningEffort string +} + +type requestLogReasoning struct { + Effort string `json:"effort"` +} + +type requestLogPayload struct { + Model string `json:"model"` + ServiceTier string `json:"service_tier"` + Reasoning *requestLogReasoning `json:"reasoning"` + ReasoningEffort string `json:"reasoning_effort"` +} + +func (p requestLogPayload) metadata() requestLogMetadata { + metadata := requestLogMetadata{ + Model: p.Model, + ServiceTier: p.ServiceTier, + } + if p.Reasoning != nil { + metadata.ReasoningEffort = p.Reasoning.Effort + } + if metadata.ReasoningEffort == "" { + metadata.ReasoningEffort = p.ReasoningEffort + } + return metadata +} + +func parseRequestLogMetadata(data []byte) requestLogMetadata { + var payload requestLogPayload + if json.Unmarshal(data, &payload) != nil { + return requestLogMetadata{} + } + return payload.metadata() +} + +func buildAssignedCredentialLogParts(credentialTag string, sessionID string, username string, metadata requestLogMetadata) []any { + logParts := []any{"assigned credential ", credentialTag} + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if username != "" { + logParts = append(logParts, " by user ", username) + } + if metadata.Model != "" { + logParts = append(logParts, ", model=", metadata.Model) + } + if metadata.ReasoningEffort != "" { + logParts = append(logParts, ", think=", metadata.ReasoningEffort) + } + if metadata.ServiceTier == "priority" { + logParts = append(logParts, ", fast") + } + return logParts +} diff --git a/service/ocm/request_log_test.go b/service/ocm/request_log_test.go new file mode 100644 index 0000000000..27a2a287e0 --- /dev/null +++ b/service/ocm/request_log_test.go @@ -0,0 +1,80 @@ +package ocm + +import ( + "strings" + "testing" + + F "github.com/sagernet/sing/common/format" +) + +func TestParseRequestLogMetadata(t *testing.T) { + t.Parallel() + + metadata := parseRequestLogMetadata([]byte(`{ + "model":"gpt-5.4", + "service_tier":"priority", + "reasoning":{"effort":"xhigh"} + }`)) + + if metadata.Model != "gpt-5.4" { + t.Fatalf("expected model gpt-5.4, got %q", metadata.Model) + } + if metadata.ServiceTier != "priority" { + t.Fatalf("expected priority service tier, got %q", metadata.ServiceTier) + } + if metadata.ReasoningEffort != "xhigh" { + t.Fatalf("expected xhigh reasoning effort, got %q", metadata.ReasoningEffort) + } +} + +func TestParseRequestLogMetadataFallsBackToTopLevelReasoningEffort(t *testing.T) { + t.Parallel() + + metadata := parseRequestLogMetadata([]byte(`{ + "model":"gpt-5.4", + "reasoning_effort":"high" + }`)) + + if metadata.ReasoningEffort != "high" { + t.Fatalf("expected high reasoning effort, got %q", metadata.ReasoningEffort) + } +} + +func TestBuildAssignedCredentialLogPartsIncludesThinkLevel(t *testing.T) { + t.Parallel() + + message := F.ToString(buildAssignedCredentialLogParts("a", "session-1", "alice", requestLogMetadata{ + Model: "gpt-5.4", + ServiceTier: "priority", + ReasoningEffort: "xhigh", + })...) + + for _, fragment := range []string{ + "assigned credential a", + "for session session-1", + "by user alice", + "model=gpt-5.4", + "think=xhigh", + "fast", + } { + if !strings.Contains(message, fragment) { + t.Fatalf("expected %q in %q", fragment, message) + } + } +} + +func TestParseWebSocketResponseCreateRequestIncludesThinkLevel(t *testing.T) { + t.Parallel() + + request, ok := parseWebSocketResponseCreateRequest([]byte(`{ + "type":"response.create", + "model":"gpt-5.4", + "reasoning":{"effort":"xhigh"} + }`)) + if !ok { + t.Fatal("expected websocket response.create request to parse") + } + if request.metadata().ReasoningEffort != "xhigh" { + t.Fatalf("expected xhigh reasoning effort, got %q", request.metadata().ReasoningEffort) + } +} diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index cfb34e15b0..3c10d01f99 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -170,9 +170,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Read body for model extraction and retry buffer when JSON replay is useful. var bodyBytes []byte + var requestMetadata requestLogMetadata var requestModel string - var requestServiceTier string - if r.Body != nil && (shouldTrackUsage || canRetryRequest) { + if r.Body != nil && (isNew || shouldTrackUsage || canRetryRequest) { mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type")) isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")) if isJSONRequest { @@ -182,33 +182,14 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") return } - var request struct { - Model string `json:"model"` - ServiceTier string `json:"service_tier"` - } - if json.Unmarshal(bodyBytes, &request) == nil { - requestModel = request.Model - requestServiceTier = request.ServiceTier - } + requestMetadata = parseRequestLogMetadata(bodyBytes) + requestModel = requestMetadata.Model r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } } if isNew { - logParts := []any{"assigned credential ", selectedCredential.tagName()} - if sessionID != "" { - logParts = append(logParts, " for session ", sessionID) - } - if username != "" { - logParts = append(logParts, " by user ", username) - } - if requestModel != "" { - logParts = append(logParts, ", model=", requestModel) - } - if requestServiceTier == "priority" { - logParts = append(logParts, ", fast") - } - s.logger.DebugContext(ctx, logParts...) + s.logger.DebugContext(ctx, buildAssignedCredentialLogParts(selectedCredential.tagName(), sessionID, username, requestMetadata)...) } requestContext := selectedCredential.wrapRequestContext(ctx) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 6691a49af6..cdcdb22b89 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -50,10 +50,9 @@ func (s *webSocketSession) Close() { } type webSocketResponseCreateRequest struct { - Type string `json:"type"` - Model string `json:"model"` - ServiceTier string `json:"service_tier"` - Generate *bool `json:"generate"` + requestLogPayload + Type string `json:"type"` + Generate *bool `json:"generate"` } func parseWebSocketResponseCreateRequest(data []byte) (webSocketResponseCreateRequest, bool) { @@ -364,18 +363,7 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn isWarmup := request.isWarmup() if !isWarmup && isNew && !logged { logged = true - logParts := []any{"assigned credential ", selectedCredential.tagName()} - if sessionID != "" { - logParts = append(logParts, " for session ", sessionID) - } - if username != "" { - logParts = append(logParts, " by user ", username) - } - logParts = append(logParts, ", model=", request.Model) - if request.ServiceTier == "priority" { - logParts = append(logParts, ", fast") - } - s.logger.DebugContext(ctx, logParts...) + s.logger.DebugContext(ctx, buildAssignedCredentialLogParts(selectedCredential.tagName(), sessionID, username, request.metadata())...) } if !isWarmup && selectedCredential.usageTrackerOrNil() != nil { select { From cf11e0e74a530ebc4c1e0106e05210803c59d70f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 28 Mar 2026 11:20:24 +0800 Subject: [PATCH 92/96] Reuse SDK JSON types in ccm and ocm --- service/ccm/service.go | 19 ++--- service/ccm/service_handler.go | 73 +++++++++++------- service/ccm/service_json_test.go | 115 +++++++++++++++++++++++++++++ service/ocm/request_log.go | 62 +++++++++++----- service/ocm/request_log_test.go | 50 ++++++++++++- service/ocm/service.go | 12 +-- service/ocm/service_handler.go | 2 +- service/ocm/service_json_test.go | 78 +++++++++++++++++++ service/ocm/service_status_test.go | 44 ++++++----- service/ocm/service_websocket.go | 43 +++++++++-- 10 files changed, 404 insertions(+), 94 deletions(-) create mode 100644 service/ccm/service_json_test.go create mode 100644 service/ocm/service_json_test.go diff --git a/service/ccm/service.go b/service/ccm/service.go index a95e3175a2..281e2ffb7a 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -21,6 +21,8 @@ import ( "github.com/sagernet/sing/common/observable" aTLS "github.com/sagernet/sing/common/tls" + "github.com/anthropics/anthropic-sdk-go" + anthropicconstant "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/go-chi/chi/v5" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -32,23 +34,12 @@ func RegisterService(registry *boxService.Registry) { boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService) } -type errorResponse struct { - Type string `json:"type"` - Error errorDetails `json:"error"` - RequestID string `json:"request_id,omitempty"` -} - -type errorDetails struct { - Type string `json:"type"` - Message string `json:"message"` -} - func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) - json.NewEncoder(w).Encode(errorResponse{ - Type: "error", - Error: errorDetails{ + json.NewEncoder(w).Encode(anthropic.ErrorResponse{ + Type: anthropicconstant.Error("").Default(), + Error: anthropic.ErrorObjectUnion{ Type: errorType, Message: message, }, diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index dd8c71d2d5..aca9dc647e 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -30,6 +30,12 @@ const ( weeklyWindowMinutes = weeklyWindowSeconds / 60 ) +type ccmRequestMetadata struct { + Model string + MessagesCount int + SessionID string +} + func isExtendedContextRequest(betaHeader string) bool { for _, feature := range strings.Split(betaHeader, ",") { if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { @@ -69,7 +75,7 @@ func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { } } -// extractCCMSessionID extracts the session ID from the request body's metadata.user_id field. +// extractCCMSessionID extracts the session ID from the metadata.user_id field. // // Claude Code >= 2.1.78 (@anthropic-ai/claude-code) encodes user_id as: // @@ -83,22 +89,12 @@ func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { // // ref: cli.js qs() — old metadata constructor // -// Returns ("", nil) when body has no metadata.user_id (non-message endpoints). +// Returns ("", nil) when userID is empty. // Returns error when user_id is present but in an unrecognized format. -func extractCCMSessionID(bodyBytes []byte) (string, error) { - var body struct { - Metadata *struct { - UserID string `json:"user_id"` - } `json:"metadata"` - } - err := json.Unmarshal(bodyBytes, &body) - if err != nil { - return "", nil - } - if body.Metadata == nil || body.Metadata.UserID == "" { +func extractCCMSessionID(userID string) (string, error) { + if userID == "" { return "", nil } - userID := body.Metadata.UserID // v2.1.78+ JSON object format var userIDObject struct { @@ -117,6 +113,40 @@ func extractCCMSessionID(bodyBytes []byte) (string, error) { return "", E.New("unrecognized metadata.user_id format: ", userID) } +func extractCCMRequestMetadata(path string, bodyBytes []byte) (ccmRequestMetadata, error) { + switch path { + case "/v1/messages": + var request anthropic.MessageNewParams + if json.Unmarshal(bodyBytes, &request) != nil { + return ccmRequestMetadata{}, nil + } + + metadata := ccmRequestMetadata{ + Model: string(request.Model), + MessagesCount: len(request.Messages), + } + if request.Metadata.UserID.Valid() { + sessionID, err := extractCCMSessionID(request.Metadata.UserID.Value) + if err != nil { + return ccmRequestMetadata{}, err + } + metadata.SessionID = sessionID + } + return metadata, nil + case "/v1/messages/count_tokens": + var request anthropic.MessageCountTokensParams + if json.Unmarshal(bodyBytes, &request) != nil { + return ccmRequestMetadata{}, nil + } + return ccmRequestMetadata{ + Model: string(request.Model), + MessagesCount: len(request.Messages), + }, nil + default: + return ccmRequestMetadata{}, nil + } +} + func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := log.ContextWithNewID(r.Context()) if r.URL.Path == "/ccm/v1/status" { @@ -178,22 +208,15 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - var request struct { - Model string `json:"model"` - Messages []anthropic.MessageParam `json:"messages"` - } - err = json.Unmarshal(bodyBytes, &request) - if err == nil { - requestModel = request.Model - messagesCount = len(request.Messages) - } - - sessionID, err = extractCCMSessionID(bodyBytes) + requestMetadata, err := extractCCMRequestMetadata(r.URL.Path, bodyBytes) if err != nil { s.logger.ErrorContext(ctx, "invalid metadata format: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "invalid metadata format") return } + requestModel = requestMetadata.Model + messagesCount = requestMetadata.MessagesCount + sessionID = requestMetadata.SessionID r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } diff --git a/service/ccm/service_json_test.go b/service/ccm/service_json_test.go new file mode 100644 index 0000000000..36cccd3412 --- /dev/null +++ b/service/ccm/service_json_test.go @@ -0,0 +1,115 @@ +package ccm + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/anthropics/anthropic-sdk-go" +) + +func TestWriteJSONErrorUsesAnthropicShape(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/v1/messages", nil) + request.Header.Set("Request-Id", "req_123") + + writeJSONError(recorder, request, http.StatusBadRequest, "invalid_request_error", "broken") + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", recorder.Code) + } + + var body anthropic.ErrorResponse + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatal(err) + } + + if string(body.Type) != "error" { + t.Fatalf("expected error type, got %q", body.Type) + } + if body.RequestID != "req_123" { + t.Fatalf("expected req_123 request ID, got %q", body.RequestID) + } + if body.Error.Type != "invalid_request_error" { + t.Fatalf("expected invalid_request_error, got %q", body.Error.Type) + } + if body.Error.Message != "broken" { + t.Fatalf("expected broken message, got %q", body.Error.Message) + } +} + +func TestExtractCCMRequestMetadataFromMessagesJSONSession(t *testing.T) { + t.Parallel() + + metadata, err := extractCCMRequestMetadata("/v1/messages", []byte(`{ + "model":"claude-sonnet-4-5", + "max_tokens":1, + "messages":[{"role":"user","content":"hello"}], + "metadata":{"user_id":"{\"session_id\":\"session-1\"}"} + }`)) + if err != nil { + t.Fatal(err) + } + if metadata.Model != "claude-sonnet-4-5" { + t.Fatalf("expected model, got %#v", metadata) + } + if metadata.MessagesCount != 1 { + t.Fatalf("expected one message, got %#v", metadata) + } + if metadata.SessionID != "session-1" { + t.Fatalf("expected session-1, got %#v", metadata) + } +} + +func TestExtractCCMRequestMetadataFromMessagesLegacySession(t *testing.T) { + t.Parallel() + + metadata, err := extractCCMRequestMetadata("/v1/messages", []byte(`{ + "model":"claude-sonnet-4-5", + "max_tokens":1, + "messages":[{"role":"user","content":"hello"}], + "metadata":{"user_id":"user_device_account_account_session_session-legacy"} + }`)) + if err != nil { + t.Fatal(err) + } + if metadata.SessionID != "session-legacy" { + t.Fatalf("expected session-legacy, got %#v", metadata) + } +} + +func TestExtractCCMRequestMetadataFromCountTokens(t *testing.T) { + t.Parallel() + + metadata, err := extractCCMRequestMetadata("/v1/messages/count_tokens", []byte(`{ + "model":"claude-sonnet-4-5", + "messages":[{"role":"user","content":"hello"}] + }`)) + if err != nil { + t.Fatal(err) + } + if metadata.Model != "claude-sonnet-4-5" { + t.Fatalf("expected model, got %#v", metadata) + } + if metadata.MessagesCount != 1 { + t.Fatalf("expected one message, got %#v", metadata) + } + if metadata.SessionID != "" { + t.Fatalf("expected empty session ID, got %#v", metadata) + } +} + +func TestExtractCCMRequestMetadataIgnoresUnsupportedPath(t *testing.T) { + t.Parallel() + + metadata, err := extractCCMRequestMetadata("/v1/models", []byte(`{"model":"claude"}`)) + if err != nil { + t.Fatal(err) + } + if metadata != (ccmRequestMetadata{}) { + t.Fatalf("expected zero metadata, got %#v", metadata) + } +} diff --git a/service/ocm/request_log.go b/service/ocm/request_log.go index 2459381acc..5422a27bb3 100644 --- a/service/ocm/request_log.go +++ b/service/ocm/request_log.go @@ -1,6 +1,12 @@ package ocm -import "encoding/json" +import ( + "encoding/json" + "strings" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/responses" +) type requestLogMetadata struct { Model string @@ -8,37 +14,57 @@ type requestLogMetadata struct { ReasoningEffort string } -type requestLogReasoning struct { - Effort string `json:"effort"` +type legacyReasoningEffortPayload struct { + ReasoningEffort string `json:"reasoning_effort"` } -type requestLogPayload struct { - Model string `json:"model"` - ServiceTier string `json:"service_tier"` - Reasoning *requestLogReasoning `json:"reasoning"` - ReasoningEffort string `json:"reasoning_effort"` +func requestLogMetadataFromChatCompletionRequest(request openai.ChatCompletionNewParams) requestLogMetadata { + return requestLogMetadata{ + Model: string(request.Model), + ServiceTier: string(request.ServiceTier), + ReasoningEffort: string(request.ReasoningEffort), + } } -func (p requestLogPayload) metadata() requestLogMetadata { +func requestLogMetadataFromResponsesRequest(request responses.ResponseNewParams, legacyReasoningEffort string) requestLogMetadata { metadata := requestLogMetadata{ - Model: p.Model, - ServiceTier: p.ServiceTier, + Model: string(request.Model), + ServiceTier: string(request.ServiceTier), } - if p.Reasoning != nil { - metadata.ReasoningEffort = p.Reasoning.Effort + if request.Reasoning.Effort != "" { + metadata.ReasoningEffort = string(request.Reasoning.Effort) } if metadata.ReasoningEffort == "" { - metadata.ReasoningEffort = p.ReasoningEffort + metadata.ReasoningEffort = legacyReasoningEffort } return metadata } -func parseRequestLogMetadata(data []byte) requestLogMetadata { - var payload requestLogPayload - if json.Unmarshal(data, &payload) != nil { +func parseLegacyReasoningEffort(data []byte) string { + var legacy legacyReasoningEffortPayload + if json.Unmarshal(data, &legacy) != nil { + return "" + } + return legacy.ReasoningEffort +} + +func parseRequestLogMetadata(path string, data []byte) requestLogMetadata { + switch { + case path == "/v1/chat/completions": + var request openai.ChatCompletionNewParams + if json.Unmarshal(data, &request) != nil { + return requestLogMetadata{} + } + return requestLogMetadataFromChatCompletionRequest(request) + case strings.HasPrefix(path, "/v1/responses"): + var request responses.ResponseNewParams + if json.Unmarshal(data, &request) != nil { + return requestLogMetadata{} + } + return requestLogMetadataFromResponsesRequest(request, parseLegacyReasoningEffort(data)) + default: return requestLogMetadata{} } - return payload.metadata() } func buildAssignedCredentialLogParts(credentialTag string, sessionID string, username string, metadata requestLogMetadata) []any { diff --git a/service/ocm/request_log_test.go b/service/ocm/request_log_test.go index 27a2a287e0..79ad562321 100644 --- a/service/ocm/request_log_test.go +++ b/service/ocm/request_log_test.go @@ -10,7 +10,7 @@ import ( func TestParseRequestLogMetadata(t *testing.T) { t.Parallel() - metadata := parseRequestLogMetadata([]byte(`{ + metadata := parseRequestLogMetadata("/v1/responses", []byte(`{ "model":"gpt-5.4", "service_tier":"priority", "reasoning":{"effort":"xhigh"} @@ -30,7 +30,7 @@ func TestParseRequestLogMetadata(t *testing.T) { func TestParseRequestLogMetadataFallsBackToTopLevelReasoningEffort(t *testing.T) { t.Parallel() - metadata := parseRequestLogMetadata([]byte(`{ + metadata := parseRequestLogMetadata("/v1/responses", []byte(`{ "model":"gpt-5.4", "reasoning_effort":"high" }`)) @@ -40,6 +40,36 @@ func TestParseRequestLogMetadataFallsBackToTopLevelReasoningEffort(t *testing.T) } } +func TestParseRequestLogMetadataFromChatCompletions(t *testing.T) { + t.Parallel() + + metadata := parseRequestLogMetadata("/v1/chat/completions", []byte(`{ + "model":"gpt-5.4", + "service_tier":"priority", + "reasoning_effort":"xhigh", + "messages":[{"role":"user","content":"hi"}] + }`)) + + if metadata.Model != "gpt-5.4" { + t.Fatalf("expected model gpt-5.4, got %q", metadata.Model) + } + if metadata.ServiceTier != "priority" { + t.Fatalf("expected priority service tier, got %q", metadata.ServiceTier) + } + if metadata.ReasoningEffort != "xhigh" { + t.Fatalf("expected xhigh reasoning effort, got %q", metadata.ReasoningEffort) + } +} + +func TestParseRequestLogMetadataIgnoresUnsupportedPath(t *testing.T) { + t.Parallel() + + metadata := parseRequestLogMetadata("/v1/files", []byte(`{"model":"gpt-5.4"}`)) + if metadata != (requestLogMetadata{}) { + t.Fatalf("expected zero metadata, got %#v", metadata) + } +} + func TestBuildAssignedCredentialLogPartsIncludesThinkLevel(t *testing.T) { t.Parallel() @@ -78,3 +108,19 @@ func TestParseWebSocketResponseCreateRequestIncludesThinkLevel(t *testing.T) { t.Fatalf("expected xhigh reasoning effort, got %q", request.metadata().ReasoningEffort) } } + +func TestParseWebSocketResponseCreateRequestFallsBackToLegacyReasoningEffort(t *testing.T) { + t.Parallel() + + request, ok := parseWebSocketResponseCreateRequest([]byte(`{ + "type":"response.create", + "model":"gpt-5.4", + "reasoning_effort":"high" + }`)) + if !ok { + t.Fatal("expected websocket response.create request to parse") + } + if request.metadata().ReasoningEffort != "high" { + t.Fatalf("expected high reasoning effort, got %q", request.metadata().ReasoningEffort) + } +} diff --git a/service/ocm/service.go b/service/ocm/service.go index 76a63f94bd..abce7b1aac 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -22,6 +22,7 @@ import ( aTLS "github.com/sagernet/sing/common/tls" "github.com/go-chi/chi/v5" + openaishared "github.com/openai/openai-go/v3/shared" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) @@ -31,13 +32,7 @@ func RegisterService(registry *boxService.Registry) { } type errorResponse struct { - Error errorDetails `json:"error"` -} - -type errorDetails struct { - Type string `json:"type"` - Code string `json:"code,omitempty"` - Message string `json:"message"` + Error openaishared.ErrorObject `json:"error"` } func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) { @@ -49,10 +44,11 @@ func writeJSONErrorWithCode(w http.ResponseWriter, r *http.Request, statusCode i w.WriteHeader(statusCode) json.NewEncoder(w).Encode(errorResponse{ - Error: errorDetails{ + Error: openaishared.ErrorObject{ Type: errorType, Code: errorCode, Message: message, + Param: "", }, }) } diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index 3c10d01f99..f5d6e97f0b 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -182,7 +182,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") return } - requestMetadata = parseRequestLogMetadata(bodyBytes) + requestMetadata = parseRequestLogMetadata(path, bodyBytes) requestModel = requestMetadata.Model r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } diff --git a/service/ocm/service_json_test.go b/service/ocm/service_json_test.go new file mode 100644 index 0000000000..930002bab3 --- /dev/null +++ b/service/ocm/service_json_test.go @@ -0,0 +1,78 @@ +package ocm + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" +) + +func TestWriteJSONErrorIncludesSDKErrorFields(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/v1/responses", nil) + + writeJSONErrorWithCode(recorder, request, http.StatusBadRequest, "invalid_request_error", "bad_thing", "broken") + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", recorder.Code) + } + + var body struct { + Error map[string]any `json:"error"` + } + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatal(err) + } + + for _, key := range []string{"type", "message", "code", "param"} { + if _, exists := body.Error[key]; !exists { + t.Fatalf("expected error.%s to be present, got %#v", key, body.Error) + } + } + if body.Error["type"] != "invalid_request_error" { + t.Fatalf("expected invalid_request_error type, got %#v", body.Error["type"]) + } + if body.Error["message"] != "broken" { + t.Fatalf("expected broken message, got %#v", body.Error["message"]) + } + if body.Error["code"] != "bad_thing" { + t.Fatalf("expected bad_thing code, got %#v", body.Error["code"]) + } + if body.Error["param"] != "" { + t.Fatalf("expected empty param, got %#v", body.Error["param"]) + } +} + +func TestHandleWebSocketErrorEventRateLimitTracksHeadersAndReset(t *testing.T) { + t.Parallel() + + credential := &testCredential{availability: availabilityStatus{State: availabilityStateUsable}} + service := &Service{} + resetAt := time.Now().Add(time.Minute).Unix() + + service.handleWebSocketErrorEvent([]byte(`{ + "type":"error", + "status_code":429, + "headers":{ + "x-codex-active-limit":"codex", + "x-codex-primary-reset-at":"`+strconv.FormatInt(resetAt, 10)+`" + }, + "error":{ + "type":"rate_limit_error", + "code":"rate_limited", + "message":"limit hit", + "param":"" + } + }`), credential) + + if credential.lastHeaders.Get("x-codex-active-limit") != "codex" { + t.Fatalf("expected headers to be forwarded, got %#v", credential.lastHeaders) + } + if credential.rateLimitedAt.Unix() != resetAt { + t.Fatalf("expected reset %d, got %d", resetAt, credential.rateLimitedAt.Unix()) + } +} diff --git a/service/ocm/service_status_test.go b/service/ocm/service_status_test.go index c3187d5394..6edd51e289 100644 --- a/service/ocm/service_status_test.go +++ b/service/ocm/service_status_test.go @@ -13,19 +13,21 @@ import ( ) type testCredential struct { - tag string - external bool - available bool - usable bool - hasData bool - fiveHour float64 - weekly float64 - fiveHourCapV float64 - weeklyCapV float64 - weight float64 - fiveReset time.Time - weeklyReset time.Time - availability availabilityStatus + tag string + external bool + available bool + usable bool + hasData bool + fiveHour float64 + weekly float64 + fiveHourCapV float64 + weeklyCapV float64 + weight float64 + fiveReset time.Time + weeklyReset time.Time + availability availabilityStatus + lastHeaders http.Header + rateLimitedAt time.Time } func (c *testCredential) tagName() string { return c.tag } @@ -40,19 +42,23 @@ func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV } func (c *testCredential) planWeight() float64 { return c.weight } func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset } func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset } -func (c *testCredential) markRateLimited(time.Time) {} -func (c *testCredential) markUpstreamRejected() {} +func (c *testCredential) markRateLimited(resetAt time.Time) { + c.rateLimitedAt = resetAt +} +func (c *testCredential) markUpstreamRejected() {} func (c *testCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) { c.availability = availabilityStatus{State: availabilityStateTemporarilyBlocked, Reason: reason, ResetAt: resetAt} } func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability } -func (c *testCredential) earliestReset() time.Time { return c.fiveReset } -func (c *testCredential) unavailableError() error { return nil } -func (c *testCredential) getAccessToken() (string, error) { return "", nil } +func (c *testCredential) earliestReset() time.Time { return c.fiveReset } +func (c *testCredential) unavailableError() error { return nil } +func (c *testCredential) getAccessToken() (string, error) { return "", nil } func (c *testCredential) buildProxyRequest(context.Context, *http.Request, []byte, http.Header) (*http.Request, error) { return nil, nil } -func (c *testCredential) updateStateFromHeaders(http.Header) {} +func (c *testCredential) updateStateFromHeaders(headers http.Header) { + c.lastHeaders = headers.Clone() +} func (c *testCredential) wrapRequestContext(context.Context) *credentialRequestContext { return nil } func (c *testCredential) interruptConnections() {} func (c *testCredential) setOnBecameUnusable(func()) {} diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index cdcdb22b89..3d762dca38 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -23,6 +23,8 @@ import ( "github.com/sagernet/ws/wsutil" "github.com/openai/openai-go/v3/responses" + openaishared "github.com/openai/openai-go/v3/shared" + openaiconstant "github.com/openai/openai-go/v3/shared/constant" ) type webSocketSession struct { @@ -50,17 +52,46 @@ func (s *webSocketSession) Close() { } type webSocketResponseCreateRequest struct { - requestLogPayload + responses.ResponseNewParams + legacyReasoningEffortPayload Type string `json:"type"` Generate *bool `json:"generate"` } +func (r *webSocketResponseCreateRequest) UnmarshalJSON(data []byte) error { + type requestEnvelope struct { + Type string `json:"type"` + Generate *bool `json:"generate"` + legacyReasoningEffortPayload + } + + var envelope requestEnvelope + if err := json.Unmarshal(data, &envelope); err != nil { + return err + } + + var params responses.ResponseNewParams + if err := json.Unmarshal(data, ¶ms); err != nil { + return err + } + + r.ResponseNewParams = params + r.legacyReasoningEffortPayload = envelope.legacyReasoningEffortPayload + r.Type = envelope.Type + r.Generate = envelope.Generate + return nil +} + +func (r webSocketResponseCreateRequest) metadata() requestLogMetadata { + return requestLogMetadataFromResponsesRequest(r.ResponseNewParams, r.ReasoningEffort) +} + func parseWebSocketResponseCreateRequest(data []byte) (webSocketResponseCreateRequest, bool) { var request webSocketResponseCreateRequest if json.Unmarshal(data, &request) != nil { return webSocketResponseCreateRequest{}, false } - if request.Type != "response.create" || request.Model == "" { + if request.Type != string(openaiconstant.ResponseCreate("").Default()) || request.Model == "" { return webSocketResponseCreateRequest{}, false } return request, true @@ -509,11 +540,9 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential func (s *Service) handleWebSocketErrorEvent(data []byte, selectedCredential Credential) { var errorEvent struct { - StatusCode int `json:"status_code"` - Headers map[string]string `json:"headers"` - Error struct { - Code string `json:"code"` - } `json:"error"` + StatusCode int `json:"status_code"` + Headers map[string]string `json:"headers"` + Error openaishared.ErrorObject `json:"error"` } err := json.Unmarshal(data, &errorEvent) if err != nil { From e7478ce9476711657313947dee15f13cc7240e64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 28 Mar 2026 16:58:52 +0800 Subject: [PATCH 93/96] fix(ccm): mark credential unavailable on refresh failure, handle poll 401 tryRefreshCredentials now returns error and calls markCredentialsUnavailable when lock acquisition or file write permission fails. getAccessToken propagates the error instead of silently returning the expired token. pollUsage handles 401 by attempting auth recovery and marking unavailable on failure. All credential error paths now use Error log level instead of Debug. Startup checks expired tokens eagerly via tryRefreshCredentials. --- service/ccm/credential_default.go | 72 ++++++++++++++++++-------- service/ccm/credential_default_test.go | 44 +++++++++++++--- service/ccm/credential_external.go | 10 ++-- service/ccm/credential_file.go | 2 +- 4 files changed, 94 insertions(+), 34 deletions(-) diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index bf8404f880..4f5ecefd1f 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -158,11 +158,15 @@ func (c *defaultCredential) start() error { } err = c.ensureCredentialWatcher() if err != nil { - c.logger.Debug("start credential watcher for ", c.tag, ": ", err) + c.logger.Error("start credential watcher for ", c.tag, ": ", err) } err = c.reloadCredentials(true) if err != nil { - c.logger.Warn("initial credential load for ", c.tag, ": ", err) + c.logger.Error("initial credential load for ", c.tag, ": ", err) + } + if c.credentials != nil && c.credentials.needsRefresh() && + slices.Contains(c.credentials.Scopes, "user:inference") { + c.tryRefreshCredentials(false) } if c.usageTracker != nil { err = c.usageTracker.Load() @@ -240,7 +244,10 @@ func (c *defaultCredential) getAccessToken() (string, error) { if !currentCredentials.needsRefresh() || !slices.Contains(currentCredentials.Scopes, "user:inference") { return currentCredentials.AccessToken, nil } - c.tryRefreshCredentials(false) + refreshErr := c.tryRefreshCredentials(false) + if refreshErr != nil { + return "", refreshErr + } c.access.RLock() defer c.access.RUnlock() if c.credentials != nil && c.credentials.AccessToken != "" { @@ -354,14 +361,14 @@ func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials, return credentials.needsRefresh() } -func (c *defaultCredential) tryRefreshCredentials(force bool) bool { +func (c *defaultCredential) tryRefreshCredentials(force bool) error { latestCredentials, err := platformReadCredentials(c.credentialPath) if err == nil && latestCredentials != nil { c.absorbCredentials(latestCredentials) } currentCredentials := c.currentCredentials() if !c.shouldAttemptRefresh(currentCredentials, force) { - return false + return nil } acquireLock := c.acquireLock if acquireLock == nil { @@ -369,8 +376,10 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool { } release, err := acquireLock(c.configDir) if err != nil { - c.logger.Debug("acquire credential lock for ", c.tag, ": ", err) - return false + lockErr := E.Cause(err, "acquire credential lock for ", c.tag) + c.logger.Error(lockErr) + c.markCredentialsUnavailable(lockErr) + return lockErr } defer release() @@ -382,30 +391,35 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool { currentCredentials = c.currentCredentials() } if !c.shouldAttemptRefresh(currentCredentials, force) { - return false + return nil } - if err := platformCanWriteCredentials(c.credentialPath); err != nil { - c.logger.Debug("credential file not writable for ", c.tag, ": ", err) - return false + err = platformCanWriteCredentials(c.credentialPath) + if err != nil { + writeErr := E.Cause(err, "credential file not writable for ", c.tag) + c.logger.Error(writeErr) + c.markCredentialsUnavailable(writeErr) + return writeErr } baseCredentials := cloneCredentials(currentCredentials) refreshResult, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, currentCredentials) if err != nil { if retryDelay != 0 { - c.logger.Debug("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err) + c.logger.Error("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err) } else { - c.logger.Debug("refresh token for ", c.tag, ": ", err) + c.logger.Error("refresh token for ", c.tag, ": ", err) } latestCredentials, readErr := platformReadCredentials(c.credentialPath) if readErr == nil && latestCredentials != nil { c.absorbCredentials(latestCredentials) - return latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh()) + if latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh()) { + return nil + } } - return false + return E.Cause(err, "refresh token for ", c.tag) } if refreshResult == nil || refreshResult.Credentials == nil { - return false + return E.New("refresh token for ", c.tag, ": empty result") } refreshedCredentials := cloneCredentials(refreshResult.Credentials) @@ -419,7 +433,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool { if c.needsProfileHydration() { profileSnapshot, profileErr := c.fetchProfileSnapshot(c.forwardHTTPClient, refreshedCredentials.AccessToken) if profileErr != nil { - c.logger.Debug("fetch profile for ", c.tag, ": ", profileErr) + c.logger.Error("fetch profile for ", c.tag, ": ", profileErr) } else if profileSnapshot != nil { credentialsChanged := c.applyProfileSnapshot(profileSnapshot) c.persistOAuthAccount() @@ -428,7 +442,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool { } } } - return true + return nil } func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool { @@ -439,7 +453,10 @@ func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool { return true } } - c.tryRefreshCredentials(true) + err = c.tryRefreshCredentials(true) + if err != nil { + return false + } currentCredentials := c.currentCredentials() return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken } @@ -924,7 +941,16 @@ func (c *defaultCredential) pollUsage() { return } body, _ := io.ReadAll(response.Body) - c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + if response.StatusCode == http.StatusUnauthorized { + c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + if !c.recoverAuthFailure(accessToken) { + c.markCredentialsUnavailable(E.New("poll usage unauthorized for ", c.tag)) + } + return + } + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + } c.incrementPollFailures() return } @@ -941,7 +967,9 @@ func (c *defaultCredential) pollUsage() { } err = json.NewDecoder(response.Body).Decode(&usageResponse) if err != nil { - c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": decode: ", err) + } c.incrementPollFailures() return } @@ -982,7 +1010,7 @@ func (c *defaultCredential) pollUsage() { if needsProfileFetch { profileSnapshot, err := c.fetchProfileSnapshot(httpClient, accessToken) if err != nil { - c.logger.Debug("fetch profile for ", c.tag, ": ", err) + c.logger.Error("fetch profile for ", c.tag, ": ", err) return } if profileSnapshot != nil { diff --git a/service/ccm/credential_default_test.go b/service/ccm/credential_default_test.go index 9435f3db7f..8da97dbeac 100644 --- a/service/ccm/credential_default_test.go +++ b/service/ccm/credential_default_test.go @@ -9,7 +9,7 @@ import ( "time" ) -func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) { +func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) { t.Parallel() directory := t.TempDir() @@ -32,15 +32,47 @@ func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) { } credential.acquireLock = func(string) (func(), error) { - return nil, errors.New("locked") + return nil, errors.New("permission denied") } - token, err := credential.getAccessToken() - if err != nil { + _, err := credential.getAccessToken() + if err == nil { + t.Fatal("expected error when lock acquisition fails, got nil") + } + if credential.isUsable() { + t.Fatal("credential should be marked unavailable after lock failure") + } +} + +func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) { + t.Parallel() + + directory := t.TempDir() + credentialPath := filepath.Join(directory, ".credentials.json") + writeTestCredentials(t, credentialPath, &oauthCredentials{ + AccessToken: "old-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(), + Scopes: []string{"user:profile", "user:inference"}, + }) + + credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) { + t.Fatal("refresh should not be attempted when file is not writable") + return nil, nil + })) + if err := credential.reloadCredentials(true); err != nil { t.Fatal(err) } - if token != "old-token" { - t.Fatalf("expected old token, got %q", token) + + os.Chmod(credentialPath, 0o444) + t.Cleanup(func() { os.Chmod(credentialPath, 0o644) }) + + _, err := credential.getAccessToken() + if err == nil { + t.Fatal("expected error when credential file is not writable, got nil") + } + if credential.isUsable() { + t.Fatal("credential should be marked unavailable after write permission failure") } } diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index f57ead1581..c0714641a0 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -598,7 +598,7 @@ func (c *externalCredential) pollUsage() { ctx := c.getReverseContext() response, err := c.doPollUsageRequest(ctx) if err != nil { - c.logger.Debug("poll usage for ", c.tag, ": ", err) + c.logger.Error("poll usage for ", c.tag, ": ", err) c.incrementPollFailures() return } @@ -606,21 +606,21 @@ func (c *externalCredential) pollUsage() { if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) - c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) c.incrementPollFailures() return } body, err := io.ReadAll(response.Body) if err != nil { - c.logger.Debug("poll usage for ", c.tag, ": read body: ", err) + c.logger.Error("poll usage for ", c.tag, ": read body: ", err) c.incrementPollFailures() return } var rawFields map[string]json.RawMessage err = json.Unmarshal(body, &rawFields) if err != nil { - c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + c.logger.Error("poll usage for ", c.tag, ": decode: ", err) c.incrementPollFailures() return } @@ -634,7 +634,7 @@ func (c *externalCredential) pollUsage() { var statusResponse statusPayload err = json.Unmarshal(body, &statusResponse) if err != nil { - c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + c.logger.Error("poll usage for ", c.tag, ": decode: ", err) c.incrementPollFailures() return } diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index cf2af4d2d0..7258dd4e0c 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -75,7 +75,7 @@ func (c *defaultCredential) retryCredentialReloadIfNeeded() { err := c.ensureCredentialWatcher() if err != nil { - c.logger.Debug("start credential watcher for ", c.tag, ": ", err) + c.logger.Error("start credential watcher for ", c.tag, ": ", err) } _ = c.reloadCredentials(false) } From 471c9c3b470ce492b0496e0e7b6895f8c3f81117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 28 Mar 2026 18:07:31 +0800 Subject: [PATCH 94/96] fix(ccm): make refresh failure fail fast --- service/ccm/credential_default.go | 175 ++++++++++++++++++------- service/ccm/credential_default_test.go | 24 +++- service/ccm/credential_file.go | 2 +- service/ccm/service_handler.go | 12 +- 4 files changed, 158 insertions(+), 55 deletions(-) diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 4f5ecefd1f..86c86879c1 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -164,10 +164,6 @@ func (c *defaultCredential) start() error { if err != nil { c.logger.Error("initial credential load for ", c.tag, ": ", err) } - if c.credentials != nil && c.credentials.needsRefresh() && - slices.Contains(c.credentials.Scopes, "user:inference") { - c.tryRefreshCredentials(false) - } if c.usageTracker != nil { err = c.usageTracker.Load() if err != nil { @@ -216,6 +212,31 @@ type statusSnapshot struct { weight float64 } +type refreshFailureError struct { + err error + hard bool +} + +func (e *refreshFailureError) Error() string { + return e.err.Error() +} + +func (e *refreshFailureError) Unwrap() error { + return e.err +} + +func newRefreshFailure(err error, hard bool) error { + if err == nil { + return nil + } + return &refreshFailureError{err: err, hard: hard} +} + +func isHardRefreshFailure(err error) bool { + refreshErr, ok := err.(*refreshFailureError) + return ok && refreshErr.hard +} + func (c *defaultCredential) statusSnapshotLocked() statusSnapshot { if c.state.unavailable { return statusSnapshot{} @@ -339,13 +360,11 @@ func (c *defaultCredential) currentCredentials() *oauthCredentials { return cloneCredentials(c.credentials) } -func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) { +func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) error { if credentials == nil { - return - } - if err := platformWriteCredentials(credentials, c.credentialPath); err != nil { - c.logger.Error("persist refreshed token for ", c.tag, ": ", err) + return nil } + return platformWriteCredentials(credentials, c.credentialPath) } func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials, force bool) bool { @@ -361,6 +380,18 @@ func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials, return credentials.needsRefresh() } +func (c *defaultCredential) markRefreshUnavailable(err error) error { + return newRefreshFailure(c.markCredentialsUnavailable(err), true) +} + +func (c *defaultCredential) refreshCredentialsIfNeeded(force bool) error { + currentCredentials := c.currentCredentials() + if !c.shouldAttemptRefresh(currentCredentials, force) { + return nil + } + return c.tryRefreshCredentials(force) +} + func (c *defaultCredential) tryRefreshCredentials(force bool) error { latestCredentials, err := platformReadCredentials(c.credentialPath) if err == nil && latestCredentials != nil { @@ -378,8 +409,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error { if err != nil { lockErr := E.Cause(err, "acquire credential lock for ", c.tag) c.logger.Error(lockErr) - c.markCredentialsUnavailable(lockErr) - return lockErr + return c.markRefreshUnavailable(lockErr) } defer release() @@ -397,8 +427,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error { if err != nil { writeErr := E.Cause(err, "credential file not writable for ", c.tag) c.logger.Error(writeErr) - c.markCredentialsUnavailable(writeErr) - return writeErr + return c.markRefreshUnavailable(writeErr) } baseCredentials := cloneCredentials(currentCredentials) @@ -416,15 +445,20 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error { return nil } } - return E.Cause(err, "refresh token for ", c.tag) + return newRefreshFailure(E.Cause(err, "refresh token for ", c.tag), false) } if refreshResult == nil || refreshResult.Credentials == nil { - return E.New("refresh token for ", c.tag, ": empty result") + return newRefreshFailure(E.New("refresh token for ", c.tag, ": empty result"), false) } refreshedCredentials := cloneCredentials(refreshResult.Credentials) + err = c.persistCredentials(refreshedCredentials) + if err != nil { + persistErr := E.Cause(err, "persist refreshed token for ", c.tag) + c.logger.Error(persistErr) + return c.markRefreshUnavailable(persistErr) + } c.absorbCredentials(refreshedCredentials) - c.persistCredentials(refreshedCredentials) if refreshResult.TokenAccount != nil { c.absorbOAuthAccount(refreshResult.TokenAccount) @@ -438,27 +472,30 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error { credentialsChanged := c.applyProfileSnapshot(profileSnapshot) c.persistOAuthAccount() if credentialsChanged { - c.persistCredentials(c.currentCredentials()) + err = c.persistCredentials(c.currentCredentials()) + if err != nil { + c.logger.Error("persist credential metadata for ", c.tag, ": ", err) + } } } } return nil } -func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool { +func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) (bool, error) { latestCredentials, err := platformReadCredentials(c.credentialPath) if err == nil && latestCredentials != nil { c.absorbCredentials(latestCredentials) if latestCredentials.AccessToken != "" && latestCredentials.AccessToken != failedAccessToken { - return true + return true, nil } } err = c.tryRefreshCredentials(true) if err != nil { - return false + return false, err } currentCredentials := c.currentCredentials() - return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken + return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken, nil } func (c *defaultCredential) applyProfileSnapshot(snapshot *claudeProfileSnapshot) bool { @@ -895,7 +932,9 @@ func (c *defaultCredential) pollUsage() { if !c.isPollBackoffAtCap() { c.logger.Error("poll usage for ", c.tag, ": get token: ", err) } - c.incrementPollFailures() + if !isHardRefreshFailure(err) { + c.incrementPollFailures() + } return } @@ -905,55 +944,97 @@ func (c *defaultCredential) pollUsage() { Timeout: 5 * time.Second, } - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil) + doUsageRequest := func(token string) (*http.Response, error) { + return doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+token) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + request.Header.Set("anthropic-beta", anthropicBetaOAuthValue) + return request, nil + }) + } + + var response *http.Response + attemptedAuthRecovery := false + for { + response, err = doUsageRequest(accessToken) if err != nil { - return nil, err + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": ", err) + } + c.incrementPollFailures() + return } - request.Header.Set("Authorization", "Bearer "+accessToken) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - request.Header.Set("anthropic-beta", anthropicBetaOAuthValue) - return request, nil - }) - if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": ", err) + if response.StatusCode == http.StatusOK { + break } - c.incrementPollFailures() - return - } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { if response.StatusCode == http.StatusTooManyRequests { retryDelay := time.Minute if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" { - seconds, err := strconv.ParseInt(retryAfter, 10, 64) - if err == nil && seconds > 0 { + seconds, parseErr := strconv.ParseInt(retryAfter, 10, 64) + if parseErr == nil && seconds > 0 { retryDelay = time.Duration(seconds) * time.Second } } + response.Body.Close() c.logger.Warn("poll usage for ", c.tag, ": usage API rate limited, retry in ", log.FormatDuration(retryDelay)) c.stateAccess.Lock() c.state.usageAPIRetryDelay = retryDelay c.stateAccess.Unlock() return } + body, _ := io.ReadAll(response.Body) - if response.StatusCode == http.StatusUnauthorized { - c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - if !c.recoverAuthFailure(accessToken) { - c.markCredentialsUnavailable(E.New("poll usage unauthorized for ", c.tag)) + response.Body.Close() + recoverableAuthFailure := !attemptedAuthRecovery && + (response.StatusCode == http.StatusUnauthorized || + (response.StatusCode == http.StatusForbidden && bytes.Contains(body, []byte("OAuth token has been revoked")))) + if recoverableAuthFailure { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) } - return + attemptedAuthRecovery = true + recovered, recoverErr := c.recoverAuthFailure(accessToken) + if recoverErr != nil { + if !isHardRefreshFailure(recoverErr) { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": auth recovery: ", recoverErr) + } + c.incrementPollFailures() + } + return + } + if !recovered { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": auth recovery did not produce a new token") + } + c.incrementPollFailures() + return + } + accessToken, err = c.getAccessToken() + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": get token after auth recovery: ", err) + } + if !isHardRefreshFailure(err) { + c.incrementPollFailures() + } + return + } + continue } + if !c.isPollBackoffAtCap() { c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) } c.incrementPollFailures() return } + defer response.Body.Close() var usageResponse struct { FiveHour struct { diff --git a/service/ccm/credential_default_test.go b/service/ccm/credential_default_test.go index 8da97dbeac..90158afe02 100644 --- a/service/ccm/credential_default_test.go +++ b/service/ccm/credential_default_test.go @@ -14,14 +14,15 @@ func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) { directory := t.TempDir() credentialPath := filepath.Join(directory, ".credentials.json") - writeTestCredentials(t, credentialPath, &oauthCredentials{ + credentials := &oauthCredentials{ AccessToken: "old-token", RefreshToken: "refresh-token", - ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(), + ExpiresAt: time.Now().Add(time.Hour).UnixMilli(), Scopes: []string{"user:profile", "user:inference"}, SubscriptionType: optionalStringPointer("max"), RateLimitTier: optionalStringPointer("default_claude_max_20x"), - }) + } + writeTestCredentials(t, credentialPath, credentials) credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) { t.Fatal("refresh should not be attempted when lock acquisition fails") @@ -31,6 +32,11 @@ func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) { t.Fatal(err) } + expiredCredentials := cloneCredentials(credentials) + expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli() + writeTestCredentials(t, credentialPath, expiredCredentials) + credential.absorbCredentials(expiredCredentials) + credential.acquireLock = func(string) (func(), error) { return nil, errors.New("permission denied") } @@ -49,12 +55,13 @@ func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) { directory := t.TempDir() credentialPath := filepath.Join(directory, ".credentials.json") - writeTestCredentials(t, credentialPath, &oauthCredentials{ + credentials := &oauthCredentials{ AccessToken: "old-token", RefreshToken: "refresh-token", - ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(), + ExpiresAt: time.Now().Add(time.Hour).UnixMilli(), Scopes: []string{"user:profile", "user:inference"}, - }) + } + writeTestCredentials(t, credentialPath, credentials) credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) { t.Fatal("refresh should not be attempted when file is not writable") @@ -64,6 +71,11 @@ func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) { t.Fatal(err) } + expiredCredentials := cloneCredentials(credentials) + expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli() + writeTestCredentials(t, credentialPath, expiredCredentials) + credential.absorbCredentials(expiredCredentials) + os.Chmod(credentialPath, 0o444) t.Cleanup(func() { os.Chmod(credentialPath, 0o644) }) diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index 7258dd4e0c..afff53d15c 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -107,7 +107,7 @@ func (c *defaultCredential) reloadCredentials(force bool) error { } c.absorbCredentials(credentials) - return nil + return c.refreshCredentialsIfNeeded(false) } func (c *defaultCredential) markCredentialsUnavailable(err error) error { diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index aca9dc647e..796d38a063 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -422,6 +422,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if shouldRetry { recovered := false + var recoverErr error if defaultCred, ok := selectedCredential.(*defaultCredential); ok { failedAccessToken := "" currentCredentials := defaultCred.currentCredentials() @@ -429,7 +430,16 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { failedAccessToken = currentCredentials.AccessToken } s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying") - recovered = defaultCred.recoverAuthFailure(failedAccessToken) + recovered, recoverErr = defaultCred.recoverAuthFailure(failedAccessToken) + } + if recoverErr != nil { + response.Body.Close() + if isHardRefreshFailure(recoverErr) || selectedCredential.unavailableError() != nil { + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable during auth recovery") + return + } + writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(recoverErr, "auth recovery").Error()) + return } if recovered { response.Body.Close() From bf9e390cf4ce59cc6e31b117993f36b264c41ac6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 28 Mar 2026 18:46:10 +0800 Subject: [PATCH 95/96] fix(ccm,ocm): allow reverse connector credentials to serve local requests Connector mode credentials were unconditionally blocked from local use by unavailableError(), despite having a working forwardHTTPClient. Also set credentialDialer in OCM connector mode to prevent nil panic in WebSocket handler. --- service/ccm/credential_external.go | 3 --- service/ocm/credential_external.go | 4 +--- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index c0714641a0..deb8436f44 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -372,9 +372,6 @@ func (c *externalCredential) earliestReset() time.Time { } func (c *externalCredential) unavailableError() error { - if c.reverse && c.connectorURL != nil { - return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests") - } if c.baseURL == reverseProxyBaseURL { session := c.getReverseSession() if session == nil || session.IsClosed() { diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 222a22e94d..29a80a909d 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -197,6 +197,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx if options.Reverse { // Connector mode: we dial out to serve, not to proxy + credential.credentialDialer = credentialDialer credential.connectorDialer = credentialDialer if options.Server != "" { credential.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) @@ -407,9 +408,6 @@ func (c *externalCredential) earliestReset() time.Time { } func (c *externalCredential) unavailableError() error { - if c.reverse && c.connectorURL != nil { - return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests") - } if c.baseURL == reverseProxyBaseURL { session := c.getReverseSession() if session == nil || session.IsClosed() { From f9743212d5fa97db605c9f2dd32e37584afb6c19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 29 Mar 2026 21:04:47 +0800 Subject: [PATCH 96/96] ccm weekly burn factor --- service/ccm/credential.go | 3 +- service/ccm/credential_default.go | 20 ++++ service/ccm/credential_external.go | 20 +++- service/ccm/credential_provider.go | 13 +-- service/ccm/credential_provider_test.go | 121 +++++++++++++++++++++ service/ccm/service_status.go | 21 ++++ service/ccm/service_status_test.go | 64 +++++++++-- service/ccm/weekly_burn.go | 138 ++++++++++++++++++++++++ service/ccm/weekly_burn_test.go | 122 +++++++++++++++++++++ 9 files changed, 498 insertions(+), 24 deletions(-) create mode 100644 service/ccm/credential_provider_test.go create mode 100644 service/ccm/weekly_burn.go create mode 100644 service/ccm/weekly_burn_test.go diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 48b97b95b7..99ebeeb32c 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -68,6 +68,7 @@ type credentialState struct { rateLimitTier string oauthAccount *claudeOAuthAccount remotePlanWeight float64 + remoteWeeklyBurnFactor float64 lastUpdated time.Time consecutivePollFailures int usageAPIRetryDelay time.Duration @@ -113,6 +114,7 @@ type Credential interface { fiveHourCap() float64 weeklyCap() float64 planWeight() float64 + weeklyBurnFactor() float64 fiveHourResetTime() time.Time weeklyResetTime() time.Time markRateLimited(resetAt time.Time) @@ -244,7 +246,6 @@ func (s credentialState) currentAvailability() availabilityStatus { } } - func parseRateLimitResetFromHeaders(headers http.Header) time.Time { claim := headers.Get("anthropic-ratelimit-unified-representative-claim") switch claim { diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 86c86879c1..548ad9699b 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -813,6 +813,26 @@ func (c *defaultCredential) planWeight() float64 { return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) } +func (c *defaultCredential) weeklyBurnFactor() float64 { + c.stateAccess.RLock() + fiveHourUtilization := c.state.fiveHourUtilization + weeklyUtilization := c.state.weeklyUtilization + fiveHourReset := c.state.fiveHourReset + weeklyReset := c.state.weeklyReset + rateLimitTier := c.state.rateLimitTier + c.stateAccess.RUnlock() + return computeCredentialWeeklyBurnFactor( + time.Now(), + fiveHourReset, + weeklyReset, + fiveHourUtilization, + weeklyUtilization, + c.cap5h, + c.capWeekly, + ccmPlanWeight5h(rateLimitTier), + ) +} + func (c *defaultCredential) fiveHourResetTime() time.Time { c.stateAccess.RLock() defer c.stateAccess.RUnlock() diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index deb8436f44..a658f59b8d 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -319,6 +319,15 @@ func (c *externalCredential) planWeight() float64 { return 10 } +func (c *externalCredential) weeklyBurnFactor() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if c.state.remoteWeeklyBurnFactor > 0 { + return c.state.remoteWeeklyBurnFactor + } + return ccmWeeklyBurnFactorMin +} + func (c *externalCredential) fiveHourResetTime() time.Time { c.stateAccess.RLock() defer c.stateAccess.RUnlock() @@ -647,6 +656,11 @@ func (c *externalCredential) pollUsage() { if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } + if statusResponse.WeeklyBurnFactor > 0 { + c.state.remoteWeeklyBurnFactor = statusResponse.WeeklyBurnFactor + } else { + c.state.remoteWeeklyBurnFactor = ccmWeeklyBurnFactorMin + } if statusResponse.FiveHourReset > 0 { c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) } @@ -756,6 +770,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } + if statusResponse.WeeklyBurnFactor > 0 { + c.state.remoteWeeklyBurnFactor = statusResponse.WeeklyBurnFactor + } else { + c.state.remoteWeeklyBurnFactor = ccmWeeklyBurnFactorMin + } if statusResponse.FiveHourReset > 0 { c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) } @@ -845,7 +864,6 @@ func (c *externalCredential) availabilityStatus() availabilityStatus { return c.state.currentAvailability() } - func (c *externalCredential) markUsageStreamUpdated() { c.stateAccess.Lock() defer c.stateAccess.Unlock() diff --git a/service/ccm/credential_provider.go b/service/ccm/credential_provider.go index 4f5f2ad32b..bfd5f93ab7 100644 --- a/service/ccm/credential_provider.go +++ b/service/ccm/credential_provider.go @@ -297,12 +297,9 @@ func (p *balancerProvider) pickFallback(filter func(Credential) bool) Credential return nil } -const weeklyWindowHours = 7 * 24 - func (p *balancerProvider) pickLeastUsed(filter func(Credential) bool) Credential { var best Credential bestScore := float64(-1) - now := time.Now() for _, credential := range p.credentials { if filter != nil && !filter(credential) { continue @@ -311,15 +308,7 @@ func (p *balancerProvider) pickLeastUsed(filter func(Credential) bool) Credentia continue } remaining := credential.weeklyCap() - credential.weeklyUtilization() - score := remaining * credential.planWeight() - resetTime := credential.weeklyResetTime() - if !resetTime.IsZero() { - timeUntilReset := resetTime.Sub(now) - if timeUntilReset < time.Hour { - timeUntilReset = time.Hour - } - score *= weeklyWindowHours / timeUntilReset.Hours() - } + score := remaining * credential.planWeight() * credential.weeklyBurnFactor() if score > bestScore { bestScore = score best = credential diff --git a/service/ccm/credential_provider_test.go b/service/ccm/credential_provider_test.go new file mode 100644 index 0000000000..e76224214d --- /dev/null +++ b/service/ccm/credential_provider_test.go @@ -0,0 +1,121 @@ +package ccm + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/sagernet/sing-box/log" +) + +func TestBalancerPickLeastUsedDoesNotBoostEarlierResetByDefault(t *testing.T) { + t.Parallel() + + now := time.Now() + provider := newBalancerProvider([]Credential{ + &testCredential{ + tag: "later", + available: true, + usable: true, + hasData: true, + weekly: 50, + weeklyCapV: 100, + weight: 1, + burnFactor: 1, + weeklyReset: now.Add(6 * 24 * time.Hour), + availability: availabilityStatus{State: availabilityStateUsable}, + }, + &testCredential{ + tag: "earlier", + available: true, + usable: true, + hasData: true, + weekly: 50, + weeklyCapV: 100, + weight: 1, + burnFactor: 1, + weeklyReset: now.Add(24 * time.Hour), + availability: availabilityStatus{State: availabilityStateUsable}, + }, + }, "", 0, log.NewNOPFactory().Logger()) + + best := provider.pickLeastUsed(nil) + if best == nil || best.tagName() != "later" { + t.Fatalf("expected later reset credential, got %#v", best) + } +} + +func TestBalancerPickLeastUsedUsesWeeklyBurnFactor(t *testing.T) { + t.Parallel() + + now := time.Now() + provider := newBalancerProvider([]Credential{ + &testCredential{ + tag: "calm", + available: true, + usable: true, + hasData: true, + weekly: 50, + weeklyCapV: 100, + weight: 1, + burnFactor: 1, + weeklyReset: now.Add(6 * 24 * time.Hour), + availability: availabilityStatus{State: availabilityStateUsable}, + }, + &testCredential{ + tag: "urgent", + available: true, + usable: true, + hasData: true, + weekly: 50, + weeklyCapV: 100, + weight: 1, + burnFactor: 1.5, + weeklyReset: now.Add(6 * 24 * time.Hour), + availability: availabilityStatus{State: availabilityStateUsable}, + }, + }, "", 0, log.NewNOPFactory().Logger()) + + best := provider.pickLeastUsed(nil) + if best == nil || best.tagName() != "urgent" { + t.Fatalf("expected urgent credential, got %#v", best) + } +} + +func TestExternalCredentialPollUsageDefaultsMissingWeeklyBurnFactor(t *testing.T) { + t.Parallel() + + requestContext, cancelRequests := context.WithCancel(context.Background()) + defer cancelRequests() + reverseContext, reverseCancel := context.WithCancel(context.Background()) + defer reverseCancel() + + credential := &externalCredential{ + tag: "remote", + baseURL: "http://remote", + token: "token", + forwardHTTPClient: &http.Client{ + Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { + return newJSONResponse(http.StatusOK, `{ + "five_hour_utilization": 10, + "five_hour_reset": 1893456000, + "weekly_utilization": 20, + "weekly_reset": 1893801600, + "plan_weight": 5 + }`), nil + }), + }, + logger: log.NewNOPFactory().Logger(), + requestContext: requestContext, + cancelRequests: cancelRequests, + reverseContext: reverseContext, + reverseCancel: reverseCancel, + } + + credential.pollUsage() + + if factor := credential.weeklyBurnFactor(); factor != ccmWeeklyBurnFactorMin { + t.Fatalf("expected default weekly burn factor %v, got %v", ccmWeeklyBurnFactorMin, factor) + } +} diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index f5bd9bc63c..4eb11d09c6 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -18,6 +18,7 @@ type statusPayload struct { WeeklyUtilization float64 `json:"weekly_utilization"` WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` + WeeklyBurnFactor float64 `json:"weekly_burn_factor"` } type aggregatedStatus struct { @@ -26,6 +27,7 @@ type aggregatedStatus struct { totalWeight float64 fiveHourReset time.Time weeklyReset time.Time + weeklyBurnFactor float64 availability availabilityStatus } @@ -41,12 +43,17 @@ func (s aggregatedStatus) equal(other aggregatedStatus) bool { } func (s aggregatedStatus) toPayload() statusPayload { + weeklyBurnFactor := s.weeklyBurnFactor + if weeklyBurnFactor <= 0 { + weeklyBurnFactor = ccmWeeklyBurnFactorMin + } return statusPayload{ FiveHourUtilization: s.fiveHourUtilization, FiveHourReset: resetToEpoch(s.fiveHourReset), WeeklyUtilization: s.weeklyUtilization, WeeklyReset: resetToEpoch(s.weeklyReset), PlanWeight: s.totalWeight, + WeeklyBurnFactor: weeklyBurnFactor, } } @@ -273,6 +280,7 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) aggregatedStatus { visibleInputs := make([]aggregateInput, 0, len(provider.allCredentials())) var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 + var totalBurnBase, totalWeightedBurnFactor float64 now := time.Now() var totalWeightedHoursUntil5hReset, total5hResetWeight float64 var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64 @@ -303,6 +311,15 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user totalWeightedRemaining5h += remaining5h * weight totalWeightedRemainingWeekly += remainingWeekly * weight totalWeight += weight + burnBase := remainingWeekly * weight + totalBurnBase += burnBase + weeklyBurnFactor := credential.weeklyBurnFactor() + if weeklyBurnFactor < ccmWeeklyBurnFactorMin { + weeklyBurnFactor = ccmWeeklyBurnFactorMin + } else if weeklyBurnFactor > ccmWeeklyBurnFactorMax { + weeklyBurnFactor = ccmWeeklyBurnFactorMax + } + totalWeightedBurnFactor += burnBase * weeklyBurnFactor fiveHourReset := credential.fiveHourResetTime() if !fiveHourReset.IsZero() { @@ -334,8 +351,12 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight, weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight, totalWeight: totalWeight, + weeklyBurnFactor: ccmWeeklyBurnFactorMin, availability: availability, } + if totalBurnBase > 0 { + result.weeklyBurnFactor = totalWeightedBurnFactor / totalBurnBase + } if total5hResetWeight > 0 { avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour))) diff --git a/service/ccm/service_status_test.go b/service/ccm/service_status_test.go index c2dbea1a21..ff8b407354 100644 --- a/service/ccm/service_status_test.go +++ b/service/ccm/service_status_test.go @@ -21,21 +21,28 @@ type testCredential struct { fiveHourCapV float64 weeklyCapV float64 weight float64 + burnFactor float64 fiveReset time.Time weeklyReset time.Time availability availabilityStatus } -func (c *testCredential) tagName() string { return c.tag } -func (c *testCredential) isAvailable() bool { return c.available } -func (c *testCredential) isUsable() bool { return c.usable } -func (c *testCredential) isExternal() bool { return c.external } -func (c *testCredential) hasSnapshotData() bool { return c.hasData } -func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour } -func (c *testCredential) weeklyUtilization() float64 { return c.weekly } -func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV } -func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV } -func (c *testCredential) planWeight() float64 { return c.weight } +func (c *testCredential) tagName() string { return c.tag } +func (c *testCredential) isAvailable() bool { return c.available } +func (c *testCredential) isUsable() bool { return c.usable } +func (c *testCredential) isExternal() bool { return c.external } +func (c *testCredential) hasSnapshotData() bool { return c.hasData } +func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour } +func (c *testCredential) weeklyUtilization() float64 { return c.weekly } +func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV } +func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV } +func (c *testCredential) planWeight() float64 { return c.weight } +func (c *testCredential) weeklyBurnFactor() float64 { + if c.burnFactor > 0 { + return c.burnFactor + } + return ccmWeeklyBurnFactorMin +} func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset } func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset } func (c *testCredential) markRateLimited(time.Time) {} @@ -66,9 +73,11 @@ type testProvider struct { func (p *testProvider) selectCredential(string, credentialSelection) (Credential, bool, error) { return nil, false, nil } + func (p *testProvider) onRateLimited(string, Credential, time.Time, credentialSelection) Credential { return nil } + func (p *testProvider) linkProviderInterrupt(Credential, credentialSelection, func()) func() bool { return func() bool { return true } } @@ -208,6 +217,41 @@ func TestRewriteResponseHeadersRejectedOnHardRateLimit(t *testing.T) { } } +func TestComputeAggregatedUtilizationAggregatesWeeklyBurnFactor(t *testing.T) { + t.Parallel() + + service := &Service{} + status := service.computeAggregatedUtilization(&testProvider{credentials: []Credential{ + &testCredential{ + tag: "a", + available: true, + usable: true, + hasData: true, + weekly: 80, + weeklyCapV: 100, + weight: 1, + burnFactor: 1.2, + availability: availabilityStatus{State: availabilityStateUsable}, + }, + &testCredential{ + tag: "b", + available: true, + usable: true, + hasData: true, + weekly: 40, + weeklyCapV: 100, + weight: 2, + burnFactor: 1.8, + availability: availabilityStatus{State: availabilityStateUsable}, + }, + }}, nil) + + expected := (20*1*1.2 + 60*2*1.8) / (20*1 + 60*2) + if diff := status.weeklyBurnFactor - expected; diff < -0.000001 || diff > 0.000001 { + t.Fatalf("expected weekly burn factor %v, got %v", expected, status.weeklyBurnFactor) + } +} + func TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials(t *testing.T) { t.Parallel() diff --git a/service/ccm/weekly_burn.go b/service/ccm/weekly_burn.go new file mode 100644 index 0000000000..81c8e49bde --- /dev/null +++ b/service/ccm/weekly_burn.go @@ -0,0 +1,138 @@ +package ccm + +import "time" + +const ( + ccmFiveHourWindowDuration = 5 * time.Hour + ccmWeeklyBurnFactorMin = 1.0 + ccmWeeklyBurnFactorMax = 2.0 +) + +type burnWindow struct { + end time.Time + capacity float64 +} + +func ccmPlanWeight5h(rateLimitTier string) float64 { + switch rateLimitTier { + case "default_claude_max_20x": + return 20 + case "default_claude_max_5x": + return 5 + default: + return 1 + } +} + +func ccmWeeklyBurnCapacity(limitPercent float64, planWeight5h float64) float64 { + if limitPercent <= 0 || planWeight5h <= 0 { + return 0 + } + return limitPercent * planWeight5h / 75 +} + +func computeWeeklyBurnDeadline( + now time.Time, + fiveHourReset time.Time, + weeklyReset time.Time, + fiveHourUtilization float64, + weeklyUtilization float64, + fiveHourCap float64, + weeklyCap float64, + planWeight5h float64, +) time.Time { + if weeklyCap <= 0 || planWeight5h <= 0 || weeklyReset.IsZero() { + return time.Time{} + } + remainingWeekly := weeklyCap - weeklyUtilization + if remainingWeekly <= 0 { + return weeklyReset + } + if !weeklyReset.After(now) { + return now + } + if fiveHourCap <= 0 || fiveHourReset.IsZero() || !fiveHourReset.After(now) { + return time.Time{} + } + + currentWindowEnd := fiveHourReset + if weeklyReset.Before(currentWindowEnd) { + currentWindowEnd = weeklyReset + } + + windows := []burnWindow{{ + end: currentWindowEnd, + capacity: ccmWeeklyBurnCapacity(fiveHourCap-fiveHourUtilization, planWeight5h), + }} + + if currentWindowEnd.Equal(fiveHourReset) { + fullWindowBurn := ccmWeeklyBurnCapacity(fiveHourCap, planWeight5h) + if fullWindowBurn > 0 { + for windowEnd := fiveHourReset.Add(ccmFiveHourWindowDuration); !windowEnd.After(weeklyReset); windowEnd = windowEnd.Add(ccmFiveHourWindowDuration) { + windows = append(windows, burnWindow{ + end: windowEnd, + capacity: fullWindowBurn, + }) + } + } + } + + remaining := remainingWeekly + for i := len(windows) - 1; i >= 0; i-- { + remaining -= windows[i].capacity + if remaining <= 0 { + return windows[i].end + } + } + + return now +} + +func computeWeeklyBurnFactor(now time.Time, burnDeadline time.Time, weeklyReset time.Time) float64 { + if weeklyReset.IsZero() || burnDeadline.IsZero() { + return ccmWeeklyBurnFactorMin + } + if !weeklyReset.After(now) || !burnDeadline.After(now) { + return ccmWeeklyBurnFactorMax + } + + timeLeft := weeklyReset.Sub(now) + if timeLeft <= 0 { + return ccmWeeklyBurnFactorMax + } + requiredSpan := weeklyReset.Sub(burnDeadline) + if requiredSpan <= 0 { + return ccmWeeklyBurnFactorMin + } + + pressure := requiredSpan.Seconds() / timeLeft.Seconds() + if pressure < 0 { + pressure = 0 + } else if pressure > 1 { + pressure = 1 + } + return ccmWeeklyBurnFactorMin + pressure*pressure +} + +func computeCredentialWeeklyBurnFactor( + now time.Time, + fiveHourReset time.Time, + weeklyReset time.Time, + fiveHourUtilization float64, + weeklyUtilization float64, + fiveHourCap float64, + weeklyCap float64, + planWeight5h float64, +) float64 { + burnDeadline := computeWeeklyBurnDeadline( + now, + fiveHourReset, + weeklyReset, + fiveHourUtilization, + weeklyUtilization, + fiveHourCap, + weeklyCap, + planWeight5h, + ) + return computeWeeklyBurnFactor(now, burnDeadline, weeklyReset) +} diff --git a/service/ccm/weekly_burn_test.go b/service/ccm/weekly_burn_test.go new file mode 100644 index 0000000000..72f882a20d --- /dev/null +++ b/service/ccm/weekly_burn_test.go @@ -0,0 +1,122 @@ +package ccm + +import ( + "testing" + "time" +) + +func TestCCMPlanWeight5h(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + rateLimitTier string + expected float64 + }{ + {name: "20x", rateLimitTier: "default_claude_max_20x", expected: 20}, + {name: "5x", rateLimitTier: "default_claude_max_5x", expected: 5}, + {name: "default", rateLimitTier: "unknown", expected: 1}, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + if actual := ccmPlanWeight5h(test.rateLimitTier); actual != test.expected { + t.Fatalf("expected %v, got %v", test.expected, actual) + } + }) + } +} + +func TestComputeWeeklyBurnDeadlineUsesLatestPossibleWindow(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + deadline := computeWeeklyBurnDeadline( + now, + now.Add(5*time.Hour), + now.Add(20*time.Hour), + 100, + 96, + 100, + 100, + 5, + ) + + expected := now.Add(20 * time.Hour) + if !deadline.Equal(expected) { + t.Fatalf("expected deadline %v, got %v", expected, deadline) + } +} + +func TestComputeWeeklyBurnDeadlineNeedsMultipleWindows(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + deadline := computeWeeklyBurnDeadline( + now, + now.Add(5*time.Hour), + now.Add(20*time.Hour), + 100, + 90, + 100, + 100, + 5, + ) + + expected := now.Add(15 * time.Hour) + if !deadline.Equal(expected) { + t.Fatalf("expected deadline %v, got %v", expected, deadline) + } +} + +func TestComputeWeeklyBurnDeadlineReturnsNowWhenAlreadyImpossible(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + deadline := computeWeeklyBurnDeadline( + now, + now.Add(5*time.Hour), + now.Add(20*time.Hour), + 100, + 75, + 100, + 100, + 5, + ) + + if !deadline.Equal(now) { + t.Fatalf("expected deadline %v, got %v", now, deadline) + } +} + +func TestComputeWeeklyBurnFactorStaysNearOneBeforeDeadline(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + factor := computeWeeklyBurnFactor(now, now.Add(90*time.Hour), now.Add(100*time.Hour)) + if factor < 1 || factor > 1.05 { + t.Fatalf("expected factor close to 1, got %v", factor) + } +} + +func TestComputeWeeklyBurnFactorRisesNearDeadline(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + factor := computeWeeklyBurnFactor(now, now.Add(time.Hour), now.Add(100*time.Hour)) + if factor <= 1.9 || factor > ccmWeeklyBurnFactorMax { + t.Fatalf("expected factor near 2, got %v", factor) + } +} + +func TestComputeWeeklyBurnFactorCapsAfterDeadline(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + factor := computeWeeklyBurnFactor(now, now.Add(-time.Minute), now.Add(100*time.Hour)) + if factor != ccmWeeklyBurnFactorMax { + t.Fatalf("expected factor %v, got %v", ccmWeeklyBurnFactorMax, factor) + } +}