Skip to content

Commit 3d0f7c1

Browse files
authored
Bump versions, restrict transformers (#461)
1 parent a2454a2 commit 3d0f7c1

17 files changed

Lines changed: 83 additions & 103 deletions

File tree

.isort.cfg

Lines changed: 0 additions & 10 deletions
This file was deleted.

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ repos:
1111
- --unsafe
1212
- id: check-added-large-files
1313
- repo: https://github.com/asottile/pyupgrade
14-
rev: v3.19.1
14+
rev: v3.21.2
1515
hooks:
1616
- id: pyupgrade
1717
args:
@@ -31,7 +31,7 @@ repos:
3131
- app/scripts/utility/shell.py
3232
- --remove-duplicate-keys
3333
- repo: https://github.com/pycqa/isort
34-
rev: 5.13.2
34+
rev: 7.0.0
3535
hooks:
3636
- id: isort
3737
name: isort (python)
@@ -42,14 +42,14 @@ repos:
4242
name: isort (pyi)
4343
types: [pyi]
4444
- repo: https://github.com/psf/black
45-
rev: 24.10.0
45+
rev: 26.1.0
4646
hooks:
4747
- id: black
4848
args:
4949
- "--config"
5050
- "./pyproject.toml"
5151
- repo: https://github.com/DavidAnson/markdownlint-cli2
52-
rev: v0.16.0
52+
rev: v0.20.0
5353
hooks:
5454
- id: markdownlint-cli2
5555
name: markdownlint

docs/recipes/generate.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Fast-LLM models support `generate` and `forward` operations through Hugging Face
1212

1313
---
1414

15-
### 🔧 Generating Text from a Fast-LLM Model
15+
## 🔧 Generating Text from a Fast-LLM Model
1616

1717
Below is a step-by-step example of how to generate text using a Fast-LLM model checkpoint from Hugging Face Hub.
1818

fast_llm/config.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,9 @@ def _process_config_class(cls: type["Config"]):
243243
return cls
244244

245245

246-
def config_class[
247-
T: Config
248-
](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
246+
def config_class[T: Config](
247+
registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None
248+
) -> typing.Callable[[type[T]], type[T]]:
249249
"""
250250
Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
251251
"""
@@ -715,9 +715,7 @@ def to_copy[
715715
def __repr__(self):
716716
return self.to_logs(log_fn=str)
717717

718-
def to_logs[
719-
T
720-
](
718+
def to_logs[T](
721719
self,
722720
verbose: int | None = FieldVerboseLevel.core,
723721
log_fn: typing.Callable[[str], T] = logger.info,
@@ -1048,9 +1046,7 @@ def config(self) -> ConfigType:
10481046
return self._config
10491047

10501048

1051-
def set_nested_dict_value[
1052-
KeyType, ValueType
1053-
](
1049+
def set_nested_dict_value[KeyType, ValueType](
10541050
d: dict[KeyType, ValueType],
10551051
keys: KeyType | tuple[KeyType, ...],
10561052
value: ValueType,
@@ -1094,9 +1090,9 @@ def set_nested_dict_value[
10941090
raise NotImplementedError(update_type)
10951091

10961092

1097-
def get_nested_dict_value[
1098-
KeyType, ValueType
1099-
](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType:
1093+
def get_nested_dict_value[KeyType, ValueType](
1094+
d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]
1095+
) -> ValueType:
11001096
if isinstance(keys, tuple):
11011097
for key in keys:
11021098
d = d[key]
@@ -1105,9 +1101,9 @@ def get_nested_dict_value[
11051101
return d[keys]
11061102

11071103

1108-
def pop_nested_dict_value[
1109-
KeyType, ValueType
1110-
](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType:
1104+
def pop_nested_dict_value[KeyType, ValueType](
1105+
d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]
1106+
) -> ValueType:
11111107
if isinstance(keys, tuple):
11121108
for key in keys[:-1]:
11131109
d = d[key]

fast_llm/engine/config_utils/run.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,9 @@ def is_main_rank() -> bool:
240240
return DistributedConfig.default_rank == _MAIN_RANK
241241

242242

243-
def log_main_rank[
244-
T
245-
](*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", ") -> T:
243+
def log_main_rank[T](
244+
*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", "
245+
) -> T:
246246
if is_main_rank():
247247
return log(*message, log_fn=log_fn, join=join)
248248

@@ -251,9 +251,9 @@ def is_model_parallel_main_rank() -> bool:
251251
return is_main_rank() if _run is None else _run._is_model_parallel_main_rank # Noqa
252252

253253

254-
def log_model_parallel_main_rank[
255-
T
256-
](*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info) -> T:
254+
def log_model_parallel_main_rank[T](
255+
*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info
256+
) -> T:
257257
if is_model_parallel_main_rank():
258258
return log(*message, log_fn=log_fn)
259259

@@ -262,8 +262,8 @@ def is_pipeline_parallel_main_rank() -> bool:
262262
return is_main_rank() if _run is None else _run._is_pipeline_parallel_main_rank # Noqa
263263

264264

265-
def log_pipeline_parallel_main_rank[
266-
T
267-
](*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info) -> T:
265+
def log_pipeline_parallel_main_rank[T](
266+
*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info
267+
) -> T:
268268
if is_pipeline_parallel_main_rank():
269269
return log(*message, log_fn=log_fn)

fast_llm/engine/config_utils/runnable.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,7 @@ def _get_runnable(self) -> typing.Callable[[], None]:
108108
def run(self) -> None:
109109
self._get_runnable()()
110110

111-
def _show[
112-
T
113-
](
111+
def _show[T](
114112
self,
115113
verbose: int = FieldVerboseLevel.core,
116114
*,

fast_llm/engine/distributed/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,9 @@ def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
410410
def get_distributed_dim(self, name: str) -> DistributedDim:
411411
return self.distributed_dims[name]
412412

413-
def _log_on_rank[
414-
T
415-
](self, *message, rank: int | None = None, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info):
413+
def _log_on_rank[T](
414+
self, *message, rank: int | None = None, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info
415+
):
416416
if rank is None or self.rank == rank:
417417
return log(*message, log_fn=log_fn)
418418

fast_llm/functional/autograd.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010

1111

1212
# TODO: Improve type hint (use protocol?)
13-
def wrap_forward_backward[
14-
OutputType, ContextType
15-
](
13+
def wrap_forward_backward[OutputType, ContextType](
1614
forward: typing.Callable[..., tuple[OutputType, ContextType]],
1715
backward: typing.Callable[[OutputType, ContextType], typing.Any],
1816
) -> typing.Callable[..., OutputType]:

fast_llm/logging.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,15 @@ def format_metrics(
123123

124124

125125
@torch._dynamo.disable # noqa
126-
def log_tensor[
127-
T
128-
](
126+
def log_tensor[T](
129127
name: str,
130128
tensor: torch.Tensor,
131129
*,
132130
scale: float = 1.0,
133131
level: int = 2,
134132
storage: bool = False,
135133
log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info,
136-
) -> (T | None):
134+
) -> T | None:
137135
if level < 1:
138136
return
139137
tensor = tensor.detach()
@@ -219,9 +217,7 @@ def log_tensor[
219217

220218

221219
@torch._dynamo.disable # noqa
222-
def log_grad[
223-
T
224-
](
220+
def log_grad[T](
225221
name: str,
226222
tensor: torch.Tensor,
227223
*,
@@ -244,9 +240,7 @@ def log_grad[
244240

245241

246242
@torch._dynamo.disable # noqa
247-
def log_distributed_tensor[
248-
T
249-
](
243+
def log_distributed_tensor[T](
250244
name: str,
251245
tensor: torch.Tensor,
252246
*,
@@ -257,7 +251,7 @@ def log_distributed_tensor[
257251
global_: bool = True,
258252
log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info,
259253
meta: TensorMeta,
260-
) -> (T | None):
254+
) -> T | None:
261255
if level <= 0:
262256
return
263257
if global_:
@@ -278,9 +272,7 @@ def log_distributed_tensor[
278272

279273

280274
@torch._dynamo.disable # noqa
281-
def log_distributed_grad[
282-
T
283-
](
275+
def log_distributed_grad[T](
284276
name: str,
285277
tensor: torch.Tensor,
286278
*,
@@ -292,7 +284,7 @@ def log_distributed_grad[
292284
global_: bool = True,
293285
log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info,
294286
meta: TensorMeta,
295-
) -> (T | None):
287+
) -> T | None:
296288
if level <= 0:
297289
return
298290
tensor.register_hook(
@@ -311,9 +303,7 @@ def log_distributed_grad[
311303

312304

313305
@torch._dynamo.disable # noqa
314-
def log_generator[
315-
T
316-
](
306+
def log_generator[T](
317307
name,
318308
generator: torch.Tensor | torch.Generator | None = None,
319309
log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info,
@@ -328,9 +318,7 @@ def log_generator[
328318
_global_max_reserved = 0
329319

330320

331-
def log_memory_usage[
332-
T
333-
](
321+
def log_memory_usage[T](
334322
header: str | None = None,
335323
log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info,
336324
reset_stats: bool = True,

fast_llm/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,9 @@ def __getitem__(self, key: KeyType) -> ValueType:
261261
return super().__getitem__(key)()
262262

263263

264-
def log[
265-
T
266-
](*message: typing.Any, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", ") -> T:
264+
def log[T](
265+
*message: typing.Any, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", "
266+
) -> T:
267267
message = join.join([str(m() if callable(m) else m) for m in message])
268268
logged = log_fn(message)
269269
if isinstance(logged, BaseException):

0 commit comments

Comments
 (0)