forked from ogx-ai/ogx-client-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpost_training.py
More file actions
118 lines (95 loc) · 3.74 KB
/
post_training.py
File metadata and controls
118 lines (95 loc) · 3.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
import click
from rich.console import Console
from llama_stack_client.types.alpha.post_training_supervised_fine_tune_params import (
AlgorithmConfigParam,
TrainingConfig,
)
from ..common.utils import handle_client_errors
@click.group()
@click.help_option("-h", "--help")
def post_training():
"""Post-training."""
@click.command("supervised_fine_tune")
@click.help_option("-h", "--help")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.option("--model", required=True, help="Model ID")
@click.option("--algorithm-config", required=True, help="Algorithm Config")
@click.option("--training-config", required=True, help="Training Config")
@click.option("--checkpoint-dir", required=False, help="Checkpoint Config", default=None)
@click.pass_context
@handle_client_errors("post_training supervised_fine_tune")
def supervised_fine_tune(
ctx,
job_uuid: str,
model: str,
algorithm_config: AlgorithmConfigParam,
training_config: TrainingConfig,
checkpoint_dir: Optional[str],
):
"""Kick off a supervised fine tune job"""
client = ctx.obj["client"]
console = Console()
post_training_job = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
model=model,
algorithm_config=algorithm_config,
training_config=training_config,
checkpoint_dir=checkpoint_dir,
# logger_config and hyperparam_search_config haven't been used yet
logger_config={},
hyperparam_search_config={},
)
console.print(post_training_job.job_uuid)
@click.command("list")
@click.help_option("-h", "--help")
@click.pass_context
@handle_client_errors("post_training get_training_jobs")
def get_training_jobs(ctx):
"""Show the list of available post training jobs"""
client = ctx.obj["client"]
console = Console()
post_training_jobs = client.post_training.job.list()
console.print([post_training_job.job_uuid for post_training_job in post_training_jobs])
@click.command("status")
@click.help_option("-h", "--help")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training get_training_job_status")
def get_training_job_status(ctx, job_uuid: str):
"""Show the status of a specific post training job"""
client = ctx.obj["client"]
console = Console()
job_status_reponse = client.post_training.job.status(job_uuid=job_uuid)
console.print(job_status_reponse)
@click.command("artifacts")
@click.help_option("-h", "--help")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training get_training_job_artifacts")
def get_training_job_artifacts(ctx, job_uuid: str):
"""Get the training artifacts of a specific post training job"""
client = ctx.obj["client"]
console = Console()
job_artifacts = client.post_training.job.artifacts(job_uuid=job_uuid)
console.print(job_artifacts)
@click.command("cancel")
@click.help_option("-h", "--help")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training cancel_training_job")
def cancel_training_job(ctx, job_uuid: str):
"""Cancel the training job"""
client = ctx.obj["client"]
client.post_training.job.cancel(job_uuid=job_uuid)
# Register subcommands
post_training.add_command(supervised_fine_tune)
post_training.add_command(get_training_jobs)
post_training.add_command(get_training_job_status)
post_training.add_command(get_training_job_artifacts)
post_training.add_command(cancel_training_job)