-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcodex_switcher.py
More file actions
executable file
·2257 lines (1948 loc) · 79.1 KB
/
codex_switcher.py
File metadata and controls
executable file
·2257 lines (1948 loc) · 79.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Codex Switcher - 跨平台 Codex 账号管理工具
功能:
(1) 启动后直接进入账号余量列表
(2) 按编号切换已存档账号
(3) 在余量页内直接调用官方 codex login 添加账号
(4) 自动存档当前登录账号并刷新使用量
(0) 退出
支持平台: macOS, Linux, Windows
"""
import os
import sys
import argparse
import json
import shlex
import re
import shutil
import subprocess
import base64
import platform
import time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
import unicodedata
from typing import Optional, Dict, List, Tuple, Any
# OAuth endpoint that refresh_tokens_via_oauth() POSTs refresh-token grants to.
REFRESH_TOKEN_URL = "https://auth.openai.com/oauth/token"
# client_id sent with every refresh-token request.
REFRESH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
# An access token whose `exp` falls within this many seconds is refreshed early.
REFRESH_LOOKAHEAD_SECONDS = 300
# Tokens are also refreshed when `last_refresh` is older than this many days.
TOKEN_REFRESH_INTERVAL_DAYS = 8
# Usage endpoint queried by request_usage_payload() (rate-limit windows, plan type).
USAGE_API_URL = "https://chatgpt.com/backend-api/wham/usage"
# When this environment variable equals "1", schedule_codex_restart() writes the
# restart script but does not execute it.
RESTART_DRY_RUN_ENV = "CODEX_SWITCHER_DRY_RUN_RESTART"
# NOTE(review): not referenced in this part of the file; presumably the worker cap
# for a ThreadPoolExecutor that refreshes accounts in parallel — confirm.
MAX_REFRESH_WORKERS = 6
# NOTE(review): not referenced in this part of the file; looks like a Codex config
# override selecting file-based credential storage for `codex login` — confirm.
LOGIN_FILE_CREDENTIALS_CONFIG = 'cli_auth_credentials_store="file"'
# Matches ANSI SGR escape sequences (colors/styles) so display_width() can ignore them.
ANSI_ESCAPE_RE = re.compile(r'\x1b\[[0-9;]*m')
# ============== 跨平台路径配置 ==============
def get_home_dir() -> Path:
    """Return the current user's home directory as a ``Path``."""
    home = Path.home()
    return home
def get_codex_config_dir() -> Path:
    """Return the Codex configuration directory, ``<home>/.codex``.

    The location is the same on every supported platform, so no
    platform-specific branching is needed.
    """
    return get_home_dir() / ".codex"
def get_switcher_dir() -> Path:
    """Return the Codex Switcher data directory, ``<home>/codex-switcher``.

    The location is the same on every supported platform, so no
    platform-specific branching is needed.
    """
    return get_home_dir() / "codex-switcher"
def get_accounts_dir() -> Path:
    """Return the directory that holds archived account snapshots."""
    base = get_switcher_dir()
    return base.joinpath("accounts")
def get_usage_cache_dir() -> Path:
    """Return the directory where per-account usage snapshots are cached."""
    base = get_switcher_dir()
    return base.joinpath("usage_cache")
def get_auth_file() -> Path:
    """Return the path of the currently active Codex ``auth.json``."""
    config_dir = get_codex_config_dir()
    return config_dir.joinpath("auth.json")
# ============== 颜色配置 ==============
class Colors:
    """ANSI terminal style codes.

    Every code is the empty string when color output is not supported, so
    callers can interpolate them unconditionally.
    """

    # Color is used only when stdout is an interactive terminal. On Windows it
    # additionally requires a known ANSI-capable host (ANSICON, Windows
    # Terminal via WT_SESSION, or TERM=xterm).
    SUPPORTS_COLOR = (
        hasattr(sys.stdout, 'isatty')
        and sys.stdout.isatty()
        and (
            platform.system() != 'Windows'
            or 'ANSICON' in os.environ
            or 'WT_SESSION' in os.environ
            or os.environ.get('TERM') == 'xterm'
        )
    )

    HEADER, BLUE, CYAN, GREEN, YELLOW, RED, ENDC, BOLD, DIM, UNDERLINE = (
        (
            '\033[95m', '\033[94m', '\033[96m', '\033[92m', '\033[93m',
            '\033[91m', '\033[0m', '\033[1m', '\033[2m', '\033[4m',
        )
        if SUPPORTS_COLOR
        else ('',) * 10
    )
# ============== 工具函数 ==============
def clear_screen():
    """Clear the terminal with the platform's native command (``cls`` or ``clear``)."""
    command = 'cls' if platform.system() == 'Windows' else 'clear'
    os.system(command)
def decode_jwt_payload(token: str) -> Optional[dict]:
    """Decode the payload segment of a JWT without verifying its signature.

    Only the middle segment is base64url-decoded and parsed, so claims such as
    ``exp`` or ``email`` can be read.

    Args:
        token: Compact-serialized JWT (``header.payload.signature``).

    Returns:
        The payload as a dict, or ``None`` if the token is malformed, cannot be
        decoded, or its payload is not a JSON object.
    """
    try:
        parts = token.split('.')
        if len(parts) < 2:
            return None
        payload = parts[1]
        # JWT segments are unpadded base64url; restore padding to a multiple of 4.
        payload += '=' * (-len(payload) % 4)
        claims = json.loads(base64.urlsafe_b64decode(payload))
    except Exception:
        return None
    # Callers immediately call .get() on the result, so a non-object payload
    # (list, number, string) is reported as undecodable instead of returned.
    return claims if isinstance(claims, dict) else None
def format_datetime(dt_str: str) -> str:
    """Format an ISO-8601 timestamp as ``MM-DD HH:MM``.

    Args:
        dt_str: ISO-8601 string; a ``Z`` suffix is accepted as UTC.

    Returns:
        ``'N/A'`` for empty input, the formatted time on success, or the first
        16 characters of the raw value when it cannot be parsed.
    """
    if not dt_str:
        return 'N/A'
    try:
        dt = datetime.fromisoformat(dt_str.replace('Z', '+00:00'))
    except (AttributeError, TypeError, ValueError):
        # Unparseable: show a shortened raw value instead of failing. Narrow
        # exceptions so Ctrl-C / SystemExit are no longer swallowed.
        return dt_str[:16]
    return dt.strftime('%m-%d %H:%M')
def time_until_reset(sub_until: str) -> str:
    """Describe how long remains until an ISO-8601 reset time.

    Args:
        sub_until: ISO-8601 timestamp. ``Z`` and timezone-less values are both
            interpreted as UTC.

    Returns:
        ``'<d>d<h>h'``, ``'<h>h<m>m'`` or ``'<m>m'`` for a future time,
        ``'已过期'`` once it has passed, and ``'N/A'`` for empty or
        unparseable input.
    """
    if not sub_until:
        return 'N/A'
    try:
        dt = datetime.fromisoformat(sub_until.replace('Z', '+00:00'))
    except (AttributeError, TypeError, ValueError):
        return 'N/A'
    if dt.tzinfo is None:
        # A naive datetime cannot be subtracted from an aware "now" (TypeError);
        # assume UTC instead of reporting N/A.
        dt = dt.replace(tzinfo=timezone.utc)
    delta = dt - datetime.now(timezone.utc)
    if delta.total_seconds() < 0:
        return "已过期"
    days = delta.days
    hours, rest = divmod(delta.seconds, 3600)
    minutes = rest // 60
    if days > 0:
        return f"{days}d{hours}h"
    if hours > 0:
        return f"{hours}h{minutes}m"
    return f"{minutes}m"
def get_token_status(exp: int) -> Tuple[str, str]:
    """Classify a token by its ``exp`` epoch timestamp.

    Returns a ``(status, color)`` pair: ``unknown`` when no expiry is known,
    ``expired`` once it has passed, ``expiring`` within the next hour and
    ``valid`` otherwise.
    """
    if not exp:
        return 'unknown', Colors.YELLOW
    seconds_left = exp - time.time()
    if seconds_left <= 0:
        return 'expired', Colors.RED
    if seconds_left < 3600:
        return 'expiring', Colors.YELLOW
    return 'valid', Colors.GREEN
def sanitize_key(value: str) -> str:
    """Map an arbitrary string to a filesystem-safe name.

    Alphanumeric characters (Unicode-aware ``str.isalnum``) and ``-``, ``_``,
    ``.`` are kept. Every other character becomes ``_``.
    """
    allowed_punct = '-_.'
    safe_chars = []
    for ch in value:
        keep = ch.isalnum() or ch in allowed_punct
        safe_chars.append(ch if keep else '_')
    return ''.join(safe_chars)
def char_display_width(ch: str) -> int:
    """Return the number of terminal columns one character occupies.

    Combining marks (Unicode categories ``Mn``/``Me``, e.g. U+0301) take zero
    columns because they render on top of the preceding character. Wide and
    fullwidth East Asian characters take two columns. Everything else takes one.
    """
    if unicodedata.category(ch) in ('Mn', 'Me'):
        return 0
    if unicodedata.east_asian_width(ch) in ('W', 'F'):
        return 2
    return 1
def display_width(text: str) -> int:
    """Return the terminal column width of ``text``, ignoring ANSI color codes."""
    visible = ANSI_ESCAPE_RE.sub('', text)
    total = 0
    for ch in visible:
        total += char_display_width(ch)
    return total
def truncate_display_text(text: str, max_width: int, suffix: str = '..') -> str:
    """Shorten ``text`` to fit ``max_width`` terminal columns.

    Text that already fits is returned unchanged. Otherwise the longest prefix
    that still fits together with ``suffix`` is kept, and ``suffix`` is appended.
    """
    if display_width(text) <= max_width:
        return text
    budget = max_width - display_width(suffix)
    used = 0
    cut = 0
    for ch in text:
        w = char_display_width(ch)
        if used + w > budget:
            break
        used += w
        cut += 1
    return text[:cut] + suffix
def pad_display(text: str, width: int) -> str:
    """Right-pad ``text`` with spaces so it spans ``width`` terminal columns."""
    missing = width - display_width(text)
    if missing <= 0:
        return text
    return text + ' ' * missing
def parse_iso_datetime(dt_str: str) -> Optional[datetime]:
    """Parse an ISO-8601 string (``Z`` accepted); ``None`` if empty or invalid."""
    if not dt_str:
        return None
    try:
        parsed = datetime.fromisoformat(dt_str.replace('Z', '+00:00'))
    except Exception:
        parsed = None
    return parsed
def iso_utc_now() -> str:
    """Return the current UTC time as ISO-8601 text with a ``Z`` suffix."""
    stamp = datetime.now(timezone.utc).isoformat()
    return stamp.replace('+00:00', 'Z')
def extract_claims_from_id_token(id_token: str) -> Optional[dict]:
    """Extract account-identifying claims from an id_token.

    Args:
        id_token: JWT whose payload is decoded (signature not verified) with
            ``decode_jwt_payload``.

    Returns:
        ``None`` if the token cannot be decoded. Otherwise a dict with:
        ``payload`` (raw claims), ``auth_info`` (the
        ``https://api.openai.com/auth`` claim, ``{}`` if missing or not an
        object), ``email`` (``'Unknown'`` if absent), ``chatgpt_user_id``
        (falls back to ``user_id``), ``chatgpt_account_id``, and
        ``record_key`` (``"<user_id>::<account_id>"`` when both are set,
        else ``''``).
    """
    payload = decode_jwt_payload(id_token)
    if not payload or not isinstance(payload, dict):
        return None
    auth_info = payload.get('https://api.openai.com/auth')
    # A null or non-object claim would make the .get() calls below raise
    # AttributeError; treat it as "no auth info".
    if not isinstance(auth_info, dict):
        auth_info = {}
    email = payload.get('email', 'Unknown')
    chatgpt_user_id = auth_info.get('chatgpt_user_id', '') or auth_info.get('user_id', '')
    account_id = auth_info.get('chatgpt_account_id', '')
    record_key = f"{chatgpt_user_id}::{account_id}" if chatgpt_user_id and account_id else ''
    return {
        'payload': payload,
        'auth_info': auth_info,
        'email': email,
        'chatgpt_user_id': chatgpt_user_id,
        'chatgpt_account_id': account_id,
        'record_key': record_key,
    }
def normalize_organizations(auth_info: dict) -> List[dict]:
    """Normalize the ``organizations`` claim into uniform workspace dicts.

    A claim that is not a list yields ``[]``, and entries that are not dicts
    are skipped. Each result has stripped string ``id``/``title``/``role`` and
    a boolean ``is_default``.
    """
    raw_orgs = auth_info.get('organizations', [])
    if not isinstance(raw_orgs, list):
        return []

    def _text(org: dict, key: str) -> str:
        return str(org.get(key, '') or '').strip()

    return [
        {
            'id': _text(org, 'id'),
            'title': _text(org, 'title'),
            'role': _text(org, 'role'),
            'is_default': bool(org.get('is_default')),
        }
        for org in raw_orgs
        if isinstance(org, dict)
    ]
def get_primary_workspace(organizations: List[dict]) -> dict:
    """Pick the default workspace, else the first one, else an empty dict."""
    default_org = next((org for org in organizations if org.get('is_default')), None)
    if default_org is not None:
        return default_org
    if organizations:
        return organizations[0]
    return {}
def format_workspace_display(organizations: List[dict], primary: dict) -> str:
    """Render the SPACE column text.

    Shows the primary workspace's title (or id, or ``-``), followed by
    ``(+N)`` when there are N additional workspaces.
    """
    if not primary:
        return '-'
    label = primary.get('title') or primary.get('id') or '-'
    others = len(organizations) - 1
    return f"{label} (+{others})" if others > 0 else label
def get_usage_cache_key(email: str, account_id: str = '', record_key: str = '') -> str:
    """Build a filesystem-safe usage-cache key.

    Preference order: ``record_key``, then ``"<email>__<account_id>"``, then
    whichever of ``email``/``account_id`` is set. Falls back to ``'unknown'``.
    """
    if record_key:
        raw = record_key
    elif email and account_id:
        raw = f"{email}__{account_id}"
    elif email or account_id:
        raw = email or account_id
    else:
        return 'unknown'
    return sanitize_key(raw)
def list_processes() -> List[Tuple[int, str]]:
    """Return ``(pid, command)`` for every process ``ps -ax`` reports.

    Returns ``[]`` when ``ps`` cannot be run. Malformed lines are skipped.
    """
    try:
        completed = subprocess.run(
            ['ps', '-axo', 'pid=,command='],
            capture_output=True,
            text=True,
            check=False,
        )
    except Exception:
        return []
    entries = []
    for line in (raw.strip() for raw in completed.stdout.splitlines()):
        fields = line.split(None, 1)
        if len(fields) != 2:
            continue
        try:
            entries.append((int(fields[0]), fields[1]))
        except ValueError:
            continue
    return entries
def list_windows_processes() -> List[dict]:
    """Return Windows process records collected through PowerShell/CIM.

    Each record has ``pid``, ``ppid``, ``name``, ``exe_path`` and
    ``command_line``. Returns ``[]`` off Windows, or when PowerShell fails or
    its JSON output cannot be parsed. Entries with an invalid or non-positive
    ``ProcessId`` are dropped.
    """
    if platform.system() != 'Windows':
        return []
    script = (
        "Get-CimInstance Win32_Process | "
        "Select-Object ProcessId,ParentProcessId,Name,ExecutablePath,CommandLine | "
        "ConvertTo-Json -Compress"
    )
    try:
        completed = subprocess.run(
            ['powershell', '-NoProfile', '-Command', script],
            capture_output=True,
            text=True,
            check=False,
        )
    except Exception:
        return []
    output = completed.stdout.strip()
    if completed.returncode != 0 or not output:
        return []
    try:
        data = json.loads(output)
    except Exception:
        return []
    # A single object (rather than an array) is wrapped into a one-item list.
    records = [data] if isinstance(data, dict) else data
    if not isinstance(records, list):
        return []
    processes = []
    for item in records:
        if not isinstance(item, dict):
            continue
        try:
            pid = int(item.get('ProcessId') or 0)
            ppid = int(item.get('ParentProcessId') or 0)
        except Exception:
            continue
        if pid > 0:
            processes.append({
                'pid': pid,
                'ppid': ppid,
                'name': str(item.get('Name') or ''),
                'exe_path': str(item.get('ExecutablePath') or ''),
                'command_line': str(item.get('CommandLine') or ''),
            })
    return processes
def list_process_tree() -> Dict[int, Tuple[int, str]]:
    """Map every pid reported by ``ps -ax`` to ``(ppid, command)``.

    Returns an empty dict when ``ps`` cannot be run. Malformed lines are skipped.
    """
    try:
        completed = subprocess.run(
            ['ps', '-axo', 'pid=,ppid=,command='],
            capture_output=True,
            text=True,
            check=False,
        )
    except Exception:
        return {}
    tree: Dict[int, Tuple[int, str]] = {}
    for line in (raw.strip() for raw in completed.stdout.splitlines()):
        fields = line.split(None, 2)
        if len(fields) != 3:
            continue
        try:
            tree[int(fields[0])] = (int(fields[1]), fields[2])
        except ValueError:
            continue
    return tree
def get_process_cwd(pid: int) -> str:
    """Return the working directory of ``pid`` as reported by ``lsof``.

    Falls back to the user's home directory when ``lsof`` cannot be run or
    reports no cwd entry.
    """
    try:
        completed = subprocess.run(
            ['lsof', '-a', '-p', str(pid), '-d', 'cwd', '-Fn'],
            capture_output=True,
            text=True,
            check=False,
        )
    except Exception:
        completed = None
    if completed is not None:
        # In -Fn output the path is on the field line prefixed with 'n'.
        cwd = next(
            (line[1:] for line in completed.stdout.splitlines()
             if line.startswith('n') and len(line) > 1),
            None,
        )
        if cwd is not None:
            return cwd
    return str(get_home_dir())
def detect_codex_desktop_instances() -> List[dict]:
    """Find running Codex Desktop main processes.

    Returns:
        A list of ``{'pid': int, 'app_path': str}`` dicts.

        On Windows, ``app_path`` is the ``...\\app\\codex.exe`` executable.
        The bundled ``resources\\codex.exe`` and helper processes launched with
        a ``--type=`` switch are skipped.

        Elsewhere, ``app_path`` is the ``.app`` bundle path preceding
        ``/Contents/MacOS/Codex`` in the ``ps`` command line. ``Codex Helper``
        processes and this script's own process are skipped.
    """
    if platform.system() == 'Windows':
        instances = []
        for proc in list_windows_processes():
            exe_path = (proc.get('exe_path') or '').replace('/', '\\')
            lowered = exe_path.lower()
            if not lowered.endswith('\\app\\codex.exe'):
                continue
            if '\\resources\\codex.exe' in lowered:
                continue
            if '--type=' in proc.get('command_line', ''):
                continue
            instances.append({'pid': proc['pid'], 'app_path': exe_path})
        return instances

    marker = '/Contents/MacOS/Codex'
    # This script ships as codex_switcher.py; match both spellings so its own
    # process is never mistaken for the desktop app.
    self_script_names = ('codex-switcher.py', 'codex_switcher.py')
    instances = []
    for pid, command in list_processes():
        if marker not in command or 'Codex Helper' in command:
            continue
        if any(name in command for name in self_script_names):
            continue
        app_path = command.split(marker, 1)[0]
        if app_path.endswith('.app'):
            instances.append({'pid': pid, 'app_path': app_path})
    return instances
def detect_codex_cli_instances() -> List[dict]:
    """Find standalone ``codex`` CLI processes (non-Windows only).

    Skipped: processes whose executable basename is not ``codex``, the
    ``codex`` binary bundled in ``/Applications/Codex.app``, the current
    process, and anything descending from Codex Desktop.

    Returns:
        Dicts with ``pid``, ``command`` and ``cwd``. Always ``[]`` on Windows.
    """
    if platform.system() == 'Windows':
        return []
    tree = list_process_tree()
    own_pid = os.getpid()
    bundled_cli = '/Applications/Codex.app/Contents/Resources/codex'
    found = []
    for pid, (_parent, command) in tree.items():
        executable = command.split(' ', 1)[0]
        if os.path.basename(executable) != 'codex':
            continue
        if bundled_cli in command or pid == own_pid:
            continue
        if process_is_managed_by_codex_desktop(pid, tree):
            continue
        found.append({'pid': pid, 'command': command, 'cwd': get_process_cwd(pid)})
    return found
def process_is_managed_by_codex_desktop(
    pid: int,
    tree: Dict[int, Tuple[int, str]],
) -> bool:
    """Return True if ``pid`` or any ancestor runs from ``/Applications/Codex.app/``.

    Follows parent links in ``tree`` (``pid -> (ppid, command)``). The walk
    stops with False at an unknown pid, at pid 0, or on a cycle.
    """
    def _ancestor_commands():
        seen = set()
        current = pid
        while current and current not in seen:
            seen.add(current)
            node = tree.get(current)
            if not node:
                return
            yield node[1]
            current = node[0]

    return any('/Applications/Codex.app/' in cmd for cmd in _ancestor_commands())
def escape_applescript_string(value: str) -> str:
    """Escape backslashes and double quotes for an AppleScript string literal."""
    table = str.maketrans({'\\': '\\\\', '"': '\\"'})
    return value.translate(table)
def escape_powershell_string(value: str) -> str:
    """Escape a value for a single-quoted PowerShell string by doubling each ``'``."""
    return "''".join(value.split("'"))
def collect_windows_restart_targets(desktop_instances: List[dict]) -> List[str]:
    """Collect the executable paths whose processes must be stopped on Windows.

    For each desktop instance this yields its ``app_path`` plus the sibling
    ``resources\\codex.exe`` next to that executable's directory. Paths are
    de-duplicated case-insensitively, keeping the first spelling seen.
    """
    targets: List[str] = []
    seen = set()
    for instance in desktop_instances:
        exe = str(instance.get('app_path') or '')
        if not exe:
            continue
        resources_exe = str(Path(exe).parent / 'resources' / 'codex.exe')
        for candidate in (exe, resources_exe):
            key = candidate.lower()
            if key not in seen:
                seen.add(key)
                targets.append(candidate)
    return targets
def build_restart_script(
    script_path: Path,
    desktop_instances: List[dict],
    cli_instances: List[dict],
) -> str:
    """Build the zsh script that restarts Codex on macOS.

    The script sleeps 1 s, sends SIGTERM to the desktop PIDs and then to the
    CLI PIDs, sleeps again, reopens each desktop ``.app`` with ``open -na``,
    and finally deletes itself. Every command ignores failures (``|| true``).
    """
    def _kill_line(instances: List[dict]) -> str:
        pids = ' '.join(str(item['pid']) for item in instances)
        return f"kill -TERM {pids} >/dev/null 2>&1 || true"

    body = ['#!/bin/zsh', 'sleep 1']
    if desktop_instances:
        body.append(_kill_line(desktop_instances))
    if cli_instances:
        body.append(_kill_line(cli_instances))
    body.append('sleep 1')
    body.extend(
        f"open -na {shlex.quote(item['app_path'])} >/dev/null 2>&1 || true"
        for item in desktop_instances
    )
    body.append(f"rm -f {shlex.quote(str(script_path))}")
    return '\n'.join(body) + '\n'
def build_windows_restart_script(script_path: Path, desktop_instances: List[dict]) -> str:
    """Build the PowerShell script that restarts Codex Desktop on Windows.

    The script:
    1. waits 1 s;
    2. force-stops the recorded desktop PIDs, then every process whose
       executable matches a restart target (see
       ``collect_windows_restart_targets``);
    3. waits again;
    4. relaunches each distinct desktop executable that still exists;
    5. deletes itself.

    Args:
        script_path: Where the script will be written. It is embedded so the
            script can remove itself.
        desktop_instances: Dicts with ``pid`` and ``app_path`` keys.

    Returns:
        Script text with CRLF line endings.
    """
    desktop_pids = [str(int(item['pid'])) for item in desktop_instances if item.get('pid')]

    # Relaunch each distinct executable once; compare paths case-insensitively.
    restart_paths = []
    seen_paths = set()
    for item in desktop_instances:
        app_path = str(item.get('app_path') or '')
        if not app_path or app_path.lower() in seen_paths:
            continue
        seen_paths.add(app_path.lower())
        restart_paths.append(app_path)

    def _ps_string_items(paths: List[str]) -> str:
        # Comma-separated single-quoted PowerShell literals. An empty list becomes
        # a single empty string, which the script's `if ($path)` guards skip.
        return ', '.join(f"'{escape_powershell_string(p)}'" for p in paths) or "''"

    escaped_target_paths = _ps_string_items(collect_windows_restart_targets(desktop_instances))
    escaped_restart_paths = _ps_string_items(restart_paths)
    escaped_script_path = escape_powershell_string(str(script_path))
    lines = [
        "$ErrorActionPreference = 'SilentlyContinue'",
        "Start-Sleep -Seconds 1",
        f"$desktopPids = @({', '.join(desktop_pids)})" if desktop_pids else "$desktopPids = @()",
        f"$targetPaths = @({escaped_target_paths})",
        f"$restartPaths = @({escaped_restart_paths})",
        # $PID is a read-only automatic variable, and PowerShell variable names are
        # case-insensitive. Using $pid as the loop variable fails, so the desktop
        # PIDs were never stopped here; use a distinct name.
        "foreach ($procId in $desktopPids) { Stop-Process -Id $procId -Force -ErrorAction SilentlyContinue }",
        "$lookup = @{}",
        "foreach ($path in $targetPaths) { if ($path) { $lookup[$path.ToLowerInvariant()] = $true } }",
        "Get-CimInstance Win32_Process | Where-Object { $_.ExecutablePath -and $lookup.ContainsKey($_.ExecutablePath.ToLowerInvariant()) } | ForEach-Object { Stop-Process -Id $_.ProcessId -Force -ErrorAction SilentlyContinue }",
        "Start-Sleep -Seconds 1",
        "foreach ($path in $restartPaths) { if ($path -and (Test-Path -LiteralPath $path)) { Start-Process -FilePath $path | Out-Null } }",
        f"Remove-Item -LiteralPath '{escaped_script_path}' -Force -ErrorAction SilentlyContinue",
    ]
    return '\r\n'.join(lines) + '\r\n'
def schedule_codex_restart() -> Tuple[bool, bool, str]:
    """Schedule a background restart of running Codex Desktop and CLI processes.

    Only macOS ("Darwin") and Windows are supported. Running instances are
    detected, and a one-shot restart script is written to
    ``get_switcher_dir() / 'runtime'`` (PowerShell on Windows, zsh on macOS).
    The script is then launched in the background: a new session on macOS,
    no console window on Windows. It sleeps briefly, stops the processes,
    relaunches the desktop app(s) and deletes itself.

    Returns:
        ``(scheduled, dry_run, message)``. ``scheduled`` is False, with an
        explanatory message, on unsupported platforms or when nothing is
        running. When the ``RESTART_DRY_RUN_ENV`` environment variable equals
        ``'1'``, the script is written but not executed and ``dry_run`` is True.
    """
    current_platform = platform.system()
    if current_platform not in {'Darwin', 'Windows'}:
        return False, False, '自动重启目前仅支持 macOS 和 Windows'
    desktop_instances = detect_codex_desktop_instances()
    # On Windows this is always [] (detect_codex_cli_instances returns early there).
    cli_instances = detect_codex_cli_instances()
    if not desktop_instances and not cli_instances:
        return False, False, '未检测到运行中的 Codex 客户端或 Codex CLI'
    runtime_dir = get_switcher_dir() / 'runtime'
    runtime_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): file names have one-second resolution, so two calls within
    # the same second would overwrite each other's script.
    if current_platform == 'Windows':
        # The PowerShell script only targets desktop processes.
        script_path = runtime_dir / f"restart_codex_{int(time.time())}.ps1"
        script_path.write_text(
            build_windows_restart_script(script_path, desktop_instances),
            encoding='utf-8',
        )
    else:
        script_path = runtime_dir / f"restart_codex_{int(time.time())}.sh"
        script_path.write_text(
            build_restart_script(script_path, desktop_instances, cli_instances),
            encoding='utf-8',
        )
    # Owner-only permission bits. The script is run via an explicit interpreter
    # below, so the execute bit is not strictly required.
    script_path.chmod(0o700)
    dry_run = os.environ.get(RESTART_DRY_RUN_ENV) == '1'
    if dry_run:
        # The script stays on disk so its path can be reported for inspection.
        message = (
            f"[dry-run] 将关闭 Codex 客户端及相关 CLI 进程,"
            f"并重启 {len(desktop_instances)} 个 Codex 客户端,会话脚本: {script_path}"
        )
        return True, True, message
    if current_platform == 'Windows':
        subprocess.Popen(
            ['powershell', '-NoProfile', '-ExecutionPolicy', 'Bypass', '-File', str(script_path)],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            creationflags=getattr(subprocess, 'CREATE_NO_WINDOW', 0),
        )
    else:
        # start_new_session detaches the script from this process's session.
        subprocess.Popen(
            ['/bin/zsh', str(script_path)],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,
        )
    message = (
        f"已安排关闭 Codex 客户端及相关 CLI 进程,并重启 {len(desktop_instances)} 个 Codex 客户端"
    )
    return True, False, message
def finish_switch_with_restart(quiet: bool = False) -> dict:
    """Schedule a Codex restart after an account switch and report the outcome.

    Args:
        quiet: When True, nothing is printed.

    Returns:
        ``{'scheduled': bool, 'dry_run': bool, 'message': str}``.
    """
    scheduled, dry_run, message = schedule_codex_restart()
    if not scheduled:
        if not quiet:
            print(f"{Colors.YELLOW} {message}{Colors.ENDC}")
            print(f"{Colors.DIM} 提示: 请重启 Codex 或新开终端使登录生效{Colors.ENDC}")
        return {'scheduled': False, 'dry_run': False, 'message': message}
    if not quiet:
        print(f"{Colors.DIM} {message}{Colors.ENDC}")
        if not dry_run:
            print(f"{Colors.DIM} 正在关闭并重启 Codex 客户端,请稍候...{Colors.ENDC}")
    return {'scheduled': True, 'dry_run': dry_run, 'message': message}
# ============== 使用量缓存 ==============
def get_usage_cache_file(email: str, account_id: str = '', record_key: str = '') -> Path:
    """Return the cache file path ``usage_<key>.json`` for an account."""
    file_name = "usage_" + get_usage_cache_key(email, account_id, record_key) + ".json"
    return get_usage_cache_dir() / file_name
def load_usage_cache(
    email: str,
    account_id: str = '',
    record_key: str = '',
    max_age: float = 300,
) -> Optional[dict]:
    """Load a fresh usage-cache entry for an account.

    The account-specific cache file is tried first, then the legacy
    email-only file (``usage_<email>.json``). A candidate is skipped if it is
    missing, unreadable, malformed, or older than ``max_age`` seconds.

    Args:
        email, account_id, record_key: Identify the account (see
            ``get_usage_cache_key``).
        max_age: Maximum accepted age in seconds, measured from the entry's
            stored ``timestamp``. Defaults to 300 (5 minutes).

    Returns:
        The cached dict, or ``None`` if no fresh entry exists.
    """
    candidates = [get_usage_cache_file(email, account_id, record_key)]
    if email:
        legacy_cache = get_usage_cache_dir() / f"usage_{sanitize_key(email)}.json"
        if legacy_cache not in candidates:
            candidates.append(legacy_cache)
    for cache_file in candidates:
        if not cache_file.exists():
            continue
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            age = time.time() - data.get('timestamp', 0)
        except Exception:
            # Unreadable, malformed or non-object cache: try the next candidate.
            continue
        if age > max_age:
            continue
        return data
    return None
def save_usage_cache(email: str, usage_data: dict, account_id: str = '', record_key: str = ''):
    """Persist usage data for an account as a best-effort JSON cache entry.

    The stored object is ``usage_data`` plus ``timestamp`` (``time.time()``),
    ``email``, ``account_id`` and ``record_key``. All filesystem and
    serialization errors are ignored.
    """
    payload = dict(usage_data)
    payload['timestamp'] = time.time()
    payload['email'] = email
    payload['account_id'] = account_id
    payload['record_key'] = record_key
    try:
        # Directory creation is inside the try so its failures are handled like
        # write failures, consistent with the best-effort contract.
        get_usage_cache_dir().mkdir(parents=True, exist_ok=True)
        cache_file = get_usage_cache_file(email, account_id, record_key)
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2)
    except Exception:
        pass
# ============== Auth 快照读写 ==============
def load_auth_data_from_path(auth_path: Path) -> Optional[dict]:
    """Read and parse a JSON auth file. Returns ``None`` if it is missing or unreadable."""
    if auth_path.exists():
        try:
            return json.loads(auth_path.read_text(encoding='utf-8'))
        except Exception:
            pass
    return None
def save_auth_data_to_path(auth_path: Path, auth_data: dict) -> bool:
    """Write ``auth_data`` to ``auth_path`` as indented JSON with a trailing newline.

    The file holds OAuth tokens, so it is written atomically and privately.
    The JSON goes to a temporary file in the same directory, which
    ``tempfile.mkstemp`` creates readable/writable by the owner only. That
    file then replaces ``auth_path`` via ``os.replace``. A failed or
    interrupted write therefore never leaves a truncated auth file behind.

    Returns:
        True on success. False on any error, in which case the existing file
        is left untouched and the temporary file is removed.
    """
    import tempfile

    tmp_path = None
    try:
        auth_path.parent.mkdir(parents=True, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(
            prefix=f".{auth_path.name}.", suffix='.tmp', dir=str(auth_path.parent)
        )
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(auth_data, f, indent=2)
            f.write('\n')
        os.replace(tmp_path, auth_path)
        tmp_path = None  # ownership moved to auth_path; nothing to clean up
        return True
    except Exception:
        return False
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
def merge_refreshed_tokens(auth_data: dict, refreshed_tokens: dict) -> dict:
    """Return a copy of ``auth_data`` with refreshed tokens merged in.

    Non-empty ``id_token``/``access_token``/``refresh_token`` values from
    ``refreshed_tokens`` overwrite the stored ones; empty values keep the old
    token. ``account_id`` comes from ``refreshed_tokens`` or, failing that,
    from the ``chatgpt_account_id`` claim of the new (or existing) id_token.
    ``last_refresh`` is set to the current UTC time. ``auth_data`` itself is
    not mutated: its top level and its ``tokens`` dict are copied.
    """
    # Null/missing inputs are treated as empty instead of raising on dict(None)
    # or None.get().
    refreshed_tokens = refreshed_tokens or {}
    merged = dict(auth_data)
    tokens = dict(merged.get('tokens') or {})
    claims = extract_claims_from_id_token(
        refreshed_tokens.get('id_token', '') or tokens.get('id_token', '')
    )
    for key in ('id_token', 'access_token', 'refresh_token'):
        value = refreshed_tokens.get(key)
        if value:
            tokens[key] = value
    account_id = refreshed_tokens.get('account_id')
    if not account_id and claims:
        account_id = claims.get('chatgpt_account_id', '')
    if account_id:
        tokens['account_id'] = account_id
    merged['tokens'] = tokens
    merged['last_refresh'] = iso_utc_now()
    return merged
def mirror_auth_tokens_to_path(auth_path: Path, refreshed_auth_data: dict) -> bool:
    """Copy refreshed tokens into another auth file for the same account.

    The target keeps its own non-token fields. ``last_refresh`` is taken from
    ``refreshed_auth_data`` when that has one.

    Returns:
        False when ``auth_path`` is missing, unreadable or empty, or when the
        write fails; otherwise True.
    """
    current = load_auth_data_from_path(auth_path)
    if not current:
        return False
    updated = merge_refreshed_tokens(current, refreshed_auth_data.get('tokens', {}))
    source_refresh = refreshed_auth_data.get('last_refresh')
    if source_refresh:
        updated['last_refresh'] = source_refresh
    return save_auth_data_to_path(auth_path, updated)
def parse_refresh_error(body: str) -> Tuple[str, str]:
    """Classify an OAuth token-refresh error response body.

    Args:
        body: Raw HTTP response body. Expected to be JSON, but may be empty or
            invalid.

    Returns:
        ``('reauth', reason)`` when the error code (``error.code`` or a string
        ``error``, case-insensitive) means the refresh token is expired,
        reused or invalidated. Otherwise ``('error', message)``, where
        ``message`` is the first non-empty of ``error_description``,
        ``message``, ``detail``, a nested ``error.message``, or a generic
        fallback.
    """
    fallback = 'token 刷新失败'
    reauth_reasons = {
        'refresh_token_expired': 'refresh token 已过期',
        'refresh_token_reused': 'refresh token 已被轮换',
        'refresh_token_invalidated': 'refresh token 已失效',
    }
    try:
        payload = json.loads(body) if body else {}
    except Exception:
        payload = {}
    if not isinstance(payload, dict):
        # e.g. a JSON list or string: nothing structured to inspect.
        return 'error', fallback

    error = payload.get('error')
    nested_message = ''
    if isinstance(error, dict):
        backend_code = str(error.get('code', '') or '')
        nested_message = error.get('message') or ''
    elif isinstance(error, str):
        backend_code = error
    else:
        backend_code = ''

    reason = reauth_reasons.get(backend_code.lower())
    if reason:
        return 'reauth', reason
    message = (
        payload.get('error_description')
        or payload.get('message')
        or payload.get('detail')
        or nested_message
        or fallback
    )
    return 'error', str(message)
def refresh_tokens_via_oauth(refresh_token: str) -> Tuple[Optional[dict], str]:
    """Exchange a refresh token for new tokens at ``REFRESH_TOKEN_URL``.

    Sends a JSON ``refresh_token`` grant with ``REFRESH_CLIENT_ID`` (15 s timeout).

    Args:
        refresh_token: The stored OAuth refresh token.

    Returns:
        On success, ``(tokens, '')``, where ``tokens`` has ``id_token``,
        ``access_token`` and ``refresh_token`` (empty string when absent).

        On failure, ``(None, reason)``:
        - HTTP 401: the status from ``parse_refresh_error`` (``'reauth'`` or
          ``'error'``);
        - other HTTP errors: the parsed message, or ``'HTTP <code>'``;
        - ``'empty_refresh_response'`` when no token came back;
        - otherwise, the exception text.
    """
    payload = json.dumps({
        'client_id': REFRESH_CLIENT_ID,
        'grant_type': 'refresh_token',
        'refresh_token': refresh_token,
    }).encode('utf-8')
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'User-Agent': 'codex-switcher',
    }
    req = urllib.request.Request(REFRESH_TOKEN_URL, data=payload, headers=headers, method='POST')
    try:
        with urllib.request.urlopen(req, timeout=15) as resp:
            data = json.loads(resp.read().decode())
        refreshed = {
            'id_token': data.get('id_token', ''),
            'access_token': data.get('access_token', ''),
            'refresh_token': data.get('refresh_token', ''),
        }
        if not any(refreshed.values()):
            return None, 'empty_refresh_response'
        return refreshed, ''
    except urllib.error.HTTPError as e:
        # An exception raised inside this handler would NOT be caught by the
        # `except Exception` clause below. Read the body defensively: replace
        # non-UTF-8 bytes and tolerate read failures.
        try:
            body = e.read().decode('utf-8', errors='replace')[:1000] if e.fp else ''
        except Exception:
            body = ''
        status, message = parse_refresh_error(body)
        if e.code == 401:
            return None, status
        return None, message or f'HTTP {e.code}'
    except Exception as e:
        return None, str(e)
# ============== API 获取使用量 ==============
import urllib.request
import urllib.error
def token_expired_or_expiring(access_token: str, last_refresh: str = '') -> bool:
    """Decide whether an access token should be refreshed before use.

    Returns True when any of these holds:
    - the token is missing;
    - its numeric ``exp`` claim is within ``REFRESH_LOOKAHEAD_SECONDS`` of now
      (or already past);
    - ``last_refresh`` is older than ``TOKEN_REFRESH_INTERVAL_DAYS``.

    A token with no usable ``exp`` and no parseable ``last_refresh`` is
    treated as still valid. A timezone-less ``last_refresh`` is interpreted
    as UTC.
    """
    if not access_token:
        return True
    payload = decode_jwt_payload(access_token)
    if not isinstance(payload, dict):
        payload = {}
    exp = payload.get('exp', 0)
    # Ignore a non-numeric exp instead of raising TypeError on the comparison.
    if isinstance(exp, (int, float)) and exp and exp <= time.time() + REFRESH_LOOKAHEAD_SECONDS:
        return True
    last_refresh_dt = parse_iso_datetime(last_refresh)
    if last_refresh_dt is None:
        return False
    if last_refresh_dt.tzinfo is None:
        # Comparing naive with aware datetimes raises TypeError; assume UTC.
        last_refresh_dt = last_refresh_dt.replace(tzinfo=timezone.utc)
    cutoff = datetime.now(timezone.utc) - timedelta(days=TOKEN_REFRESH_INTERVAL_DAYS)
    return last_refresh_dt < cutoff
def request_usage_payload(access_token: str, account_id: str) -> Tuple[Optional[dict], Optional[int], str]:
    """GET the usage API for one account.

    Args:
        access_token: Bearer token sent in ``Authorization``.
        account_id: Sent as the ``ChatGPT-Account-Id`` header.

    Returns:
        ``(json_body, status_code, '')`` on success;
        ``(None, http_status, error_body[:1000])`` on an HTTP error;
        ``(None, None, exception_text)`` on any other failure (network,
        timeout, invalid JSON).
    """
    headers = {
        'Authorization': f'Bearer {access_token}',
        'ChatGPT-Account-Id': account_id,
        'User-Agent': 'codex-auth',
        'Accept': 'application/json',
        'Accept-Encoding': 'identity',
    }
    req = urllib.request.Request(USAGE_API_URL, headers=headers)
    try:
        with urllib.request.urlopen(req, timeout=15) as resp:
            return json.loads(resp.read().decode()), resp.getcode(), ''
    except urllib.error.HTTPError as e:
        # Exceptions raised in this handler would escape the sibling
        # `except Exception`. Decode leniently and tolerate read failures.
        try:
            error_body = e.read().decode('utf-8', errors='replace')[:1000] if e.fp else ''
        except Exception:
            error_body = ''
        return None, e.code, error_body
    except Exception as e:
        return None, None, str(e)
def build_usage_data(data: dict) -> Optional[dict]:
"""解析 usage API 响应"""
rate_limit = data.get('rate_limit', {})
primary = rate_limit.get('primary_window', {})
secondary = rate_limit.get('secondary_window', {})
plan_type = str(data.get('plan_type', 'unknown') or 'unknown')
if not primary and not secondary:
return None
is_team_family = plan_type in {'team', 'business', 'enterprise', 'edu'}
hourly_max = 80 if is_team_family else 50
weekly_max = 500 if is_team_family else 100
usage_data = {
'hourly_used': 'N/A',
'hourly_limit': str(hourly_max),
'hourly_remaining': 'N/A',
'hourly_percent': 0,
'weekly_used': 'N/A',
'weekly_limit': str(weekly_max),
'weekly_remaining': 'N/A',
'weekly_percent': 0,
'next_reset': 'N/A',
'next_reset_weekly': 'N/A',
'reset_at_hourly': 0,
'reset_at_weekly': 0,
'plan_type': plan_type,
}
# 5小时限额
if primary:
used_percent = primary.get('used_percent', 0)
used = int(hourly_max * used_percent / 100)
remaining = hourly_max - used
reset_at = int(primary.get('reset_at', 0) or 0)
usage_data['hourly_used'] = str(used)
usage_data['hourly_remaining'] = str(remaining)
usage_data['hourly_percent'] = used_percent
usage_data['reset_at_hourly'] = reset_at
usage_data['next_reset'] = format_reset_time(reset_at)
# 每周限额
if secondary:
used_percent = secondary.get('used_percent', 0)
used = int(weekly_max * used_percent / 100)
remaining = weekly_max - used
reset_at = int(secondary.get('reset_at', 0) or 0)
usage_data['weekly_used'] = str(used)
usage_data['weekly_remaining'] = str(remaining)
usage_data['weekly_percent'] = used_percent
usage_data['reset_at_weekly'] = reset_at
usage_data['next_reset_weekly'] = format_reset_time(reset_at)
return usage_data
def attempt_token_refresh(auth_path: Path, auth_data: dict) -> Tuple[Optional[dict], str]:
    """Refresh the tokens of one auth snapshot and write the result back to disk.

    Returns:
        ``(updated_auth, 'refreshed')`` on success. Otherwise ``(None, reason)``,
        where ``reason`` is ``'missing_refresh_token'``, the status reported
        by ``refresh_tokens_via_oauth`` (or ``'refresh_failed'`` if it gave
        none), or ``'save_failed'``.
    """
    refresh_token = auth_data.get('tokens', {}).get('refresh_token', '')
    if not refresh_token:
        return None, 'missing_refresh_token'
    new_tokens, oauth_status = refresh_tokens_via_oauth(refresh_token)
    if not new_tokens:
        return None, oauth_status or 'refresh_failed'
    merged = merge_refreshed_tokens(auth_data, new_tokens)
    if save_auth_data_to_path(auth_path, merged):
        return merged, 'refreshed'
    return None, 'save_failed'
def refresh_usage_for_auth_path(auth_path: Path) -> Optional[dict]:
"""为指定 auth 文件刷新 usage,并在必要时刷新 token"""
auth_data = load_auth_data_from_path(auth_path)
if not auth_data:
return None
info = get_account_info(auth_data, str(auth_path))
if not info:
return None
if token_expired_or_expiring(info.get('access_token', ''), info.get('last_refresh', '')):
updated_auth, refresh_status = attempt_token_refresh(auth_path, auth_data)
if updated_auth:
auth_data = updated_auth
info = get_account_info(auth_data, str(auth_path))
elif refresh_status == 'reauth':
info['refresh_status'] = 'reauth'
info['refresh_status_text'] = '需重登'
return info
access_token = info.get('access_token', '')
account_id = info.get('account_id', '')
payload, status_code, _ = request_usage_payload(access_token, account_id)
if payload:
usage_data = build_usage_data(payload)
if usage_data:
save_usage_cache(
info.get('email', ''),
usage_data,
info.get('account_id', ''),
info.get('record_key', ''),
)
info = get_account_info(load_auth_data_from_path(auth_path) or auth_data, str(auth_path))
info['refresh_status'] = 'fresh'
info['refresh_status_text'] = '已刷新'
return info
if status_code in (401, 403):
updated_auth, refresh_status = attempt_token_refresh(auth_path, auth_data)
if updated_auth:
auth_data = updated_auth
info = get_account_info(auth_data, str(auth_path))
payload, _, _ = request_usage_payload(info.get('access_token', ''), info.get('account_id', ''))
if payload:
usage_data = build_usage_data(payload)
if usage_data:
save_usage_cache(
info.get('email', ''),
usage_data,
info.get('account_id', ''),
info.get('record_key', ''),
)
info = get_account_info(auth_data, str(auth_path))
info['refresh_status'] = 'fresh'
info['refresh_status_text'] = '已刷新'
return info
if refresh_status == 'reauth':
info['refresh_status'] = 'reauth'
info['refresh_status_text'] = '需重登'
else:
info['refresh_status'] = 'cached'
info['refresh_status_text'] = '缓存'
return info
cached_info = get_account_info(auth_data, str(auth_path)) or info
cached_info['refresh_status'] = 'cached'
cached_info['refresh_status_text'] = '缓存'