Skip to content

Commit acb33c9

Browse files
committed
refactor: Improve assertion formatting and training client args handling in TinkerBackend
- Reformatted the assertion for clarity when checking TINKER_API_KEY. - Enhanced the configuration of tinker_args to ensure training_client_args are properly set up, maintaining backward compatibility.
1 parent 66c4a3e commit acb33c9

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

src/art/tinker/backend.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
3636
config=model._internal_config,
3737
)
3838
config["tinker_args"] = config.get("tinker_args") or TinkerArgs(
39-
renderer_name=get_renderer_name(model.base_model),
40-
training_client_args=TinkerTrainingClientArgs(
41-
rank=8,
42-
),
39+
renderer_name=get_renderer_name(model.base_model)
40+
)
41+
config["tinker_args"]["training_client_args"] = config["tinker_args"].get(
42+
"training_client_args"
43+
) or TinkerTrainingClientArgs(
44+
rank=8,
4345
)
4446
self._services[model.name] = TinkerService(
4547
model_name=model.name,

0 commit comments

Comments
 (0)