11from infinilm .lib import _infinilm
2+ import infinicore
3+ from ..modeling_utils import parse_dtype
24
35
46class CacheConfig (_infinilm .CacheConfig ):
@@ -9,22 +11,45 @@ def __init__(self):
911
1012
class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig):
    """Python-side wrapper around ``_infinilm.StaticKVCacheConfig``."""

    def __init__(
        self,
        max_batch_size: int = 1,
        max_cache_len: int = 0,
        kv_cache_dtype=None,
    ):
        """Initialise the native static KV-cache configuration.

        Args:
            max_batch_size: Forwarded unchanged to the native constructor.
            max_cache_len: Forwarded unchanged to the native constructor.
            kv_cache_dtype: A ``str`` is converted with ``parse_dtype``; an
                ``infinicore.dtype.dtype`` is used directly. In both cases
                the object's ``_underlying`` attribute is what gets
                forwarded. Any other value, including ``None``, is
                forwarded as-is.
        """
        # Normalise the dtype once so the native constructor is invoked from
        # a single place instead of once per accepted input form.
        if isinstance(kv_cache_dtype, str):
            native_dtype = parse_dtype(kv_cache_dtype)._underlying
        elif isinstance(kv_cache_dtype, infinicore.dtype.dtype):
            native_dtype = kv_cache_dtype._underlying
        else:
            native_dtype = kv_cache_dtype

        # Only the native base initialiser is called here, with all three
        # arguments passed positionally.
        _infinilm.StaticKVCacheConfig.__init__(
            self, max_batch_size, max_cache_len, native_dtype
        )
35+
1736
class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig):
    """Python-side wrapper around ``_infinilm.PagedKVCacheConfig``."""

    def __init__(
        self,
        num_blocks: int,
        block_size: int = 256,
        kv_cache_dtype=None,
    ):
        """Initialise the native paged KV-cache configuration.

        Args:
            num_blocks: Forwarded unchanged to the native constructor.
            block_size: Forwarded unchanged to the native constructor.
            kv_cache_dtype: A ``str`` is converted with ``parse_dtype``; an
                ``infinicore.dtype.dtype`` is used directly. In both cases
                the object's ``_underlying`` attribute is what gets
                forwarded. Any other value, including ``None``, is
                forwarded as-is.
        """
        # Normalise the dtype once so the native constructor is invoked from
        # a single place instead of once per accepted input form.
        if isinstance(kv_cache_dtype, str):
            native_dtype = parse_dtype(kv_cache_dtype)._underlying
        elif isinstance(kv_cache_dtype, infinicore.dtype.dtype):
            native_dtype = kv_cache_dtype._underlying
        else:
            native_dtype = kv_cache_dtype

        # Positional order is (num_blocks, block_size, dtype); only the
        # native base initialiser is called here.
        _infinilm.PagedKVCacheConfig.__init__(
            self, num_blocks, block_size, native_dtype
        )
0 commit comments