-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathusage_counter.py
More file actions
133 lines (108 loc) · 4.58 KB
/
usage_counter.py
File metadata and controls
133 lines (108 loc) · 4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import Dict, Callable, Optional
import threading
import logging
logger = logging.getLogger(__name__)
class ModelUsageCounter:
    """Token and time counter for counting and estimating usage.

    Accumulates token and time cost over completed iteration units and
    extrapolates the per-unit average to the units that remain.
    """

    def __init__(self,
                 total: int = 0,
                 name: str = "Model",
                 parallel: bool = False,
                 on_update: Optional[Callable[["ModelUsageCounter"], None]] = None
                 ) -> None:
        """
        Initialize token and time counter

        Args:
            total: Anticipated number of each iteration unit.
            name: The module name used for identification when using logging.info:
                [{name} Usage]: completed=XXXX, token=XXXX, time=XXXX || remain=XXXX, remain_token_anticipation=XXXX, remain_time_anticipation=XXXX
            parallel: Whether to be used in parallel processing. If parallel, time needs to be counted additionally.
            on_update: Optional callback called after estimate_usage with self as argument.
        """
        self.name = name
        self.total = total
        self.token: int = 0
        self.time: float = 0  # measured in seconds
        self.completed: int = 0  # number of completed iteration units
        self.remain: int = self.total - self.completed  # number of remaining iteration units

        # parallel setting
        self._is_parallel = parallel
        self._lock = threading.RLock()

        # callback for progress updates
        self._on_update = on_update

    def load_from_dict(self, usage: Dict) -> None:
        """Restore counter state from a dict.

        Reads the keys "total", "name", "token", "time" and "completed";
        a missing key falls back to 0 (or to the current name for "name").

        Args:
            usage: mapping holding previously recorded usage values.
        """
        self.total = usage.get("total", 0)
        self.name = usage.get("name", self.name)
        self.token = usage.get("token", 0)
        self.time = usage.get("time", 0.0)
        self.completed = usage.get("completed", 0)
        self.remain = self.total - self.completed
        if self.remain < 0:
            # More units completed than anticipated: keep remain non-negative
            # and raise total so that total == completed + remain still holds.
            self.remain = 0
            self.total = self.completed

    def add_usage(self,
                  n_token: int,
                  time: float,
                  ) -> None:
        """
        add token and time usage to this counter.

        Args:
            n_token: token cost
            time: time cost

        if add_usage in parallel processing, time will not be counted.
        Please use set_parallel_time(time) in parallel processing.
        """
        with self._lock:
            self.token += n_token
            if not self._is_parallel:
                self.time += time

    def set_sequential(self) -> None:
        """Switch to sequential mode: add_usage accumulates time again."""
        self._is_parallel = False

    def set_parallel(self) -> None:
        """Switch to parallel mode: add_usage stops accumulating time."""
        self._is_parallel = True

    def set_parallel_time(self, time: float) -> None:
        """Overwrite the recorded time; has no effect unless in parallel mode.

        Args:
            time: the time value, in seconds, to store as the counter's time.
        """
        if self._is_parallel:
            self.time = time

    def set_on_update(self, callback: Optional[Callable[["ModelUsageCounter"], None]]) -> None:
        """Set callback to be called after estimate_usage."""
        self._on_update = callback

    def resize_total(self, total: int) -> None:
        """Resize total and recalculate remain accordingly.

        remain is clamped at 0 when the new total is below completed.
        """
        self.total = total
        self.remain = self.total - self.completed
        if self.remain < 0:
            self.remain = 0

    def estimate_usage(self, n: int) -> None:
        """
        Mark n iteration units as completed and log the estimated token and
        time usage of the remaining units.

        Args:
            n: the number of iteration units just completed; capped at the
                number of remaining units.
        """
        with self._lock:
            n = min(self.remain, n)
            self.remain -= n
            self.completed += n
            # The properties return 0 while nothing is completed, so this no
            # longer raises ZeroDivisionError (e.g. total == 0 or n == 0).
            remain_token_anticipation = self.estimated_remaining_tokens
            remain_time_anticipation = self.estimated_remaining_time
            logger.info(f"[{self.name} Usage]: completed={self.completed}, token={self.token}, time={self.time:.2f} || remain={self.remain}, remain_token_anticipation={remain_token_anticipation}, remain_time_anticipation={remain_time_anticipation:.2f}")

        # call callback outside of lock to avoid deadlock
        if self._on_update:
            self._on_update(self)

    @property
    def estimated_remaining_tokens(self) -> int:
        """Estimate remaining tokens based on average usage (0 if nothing completed)."""
        if self.completed == 0:
            return 0
        avg_token = self.token / self.completed
        return int(avg_token * self.remain)

    @property
    def estimated_remaining_time(self) -> float:
        """Estimate remaining time based on average usage (0.0 if nothing completed)."""
        if self.completed == 0:
            return 0.0
        avg_time = self.time / self.completed
        return avg_time * self.remain

    def __str__(self) -> str:
        """Return a one-line summary of completed/remaining units, tokens and time."""
        return f"[{self.name} Usage]: completed={self.completed}, remain={self.remain}, token={self.token}, time={self.time:.2f}"