Skip to content

Commit ad645b0

Browse files
committed
Update Mlflow
1 parent f7ed6c1 commit ad645b0

6 files changed

Lines changed: 20 additions & 4 deletions

File tree

bats_ai/core/management/commands/examplelog.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ def command(experiment_name):
1414
if experiment:
1515
click.echo(f'Creating a log for experiment {experiment_name}')
1616
example_train.delay(experiment_name)
17-
# train_body(experiment_name)
1817
else:
1918
click.echo(
2019
f'Could not find experiment {experiment_name}.'

bats_ai/core/management/commands/makeexperiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020
@click.command()
2121
def command(username, name, description: str | None = None):
22+
user = None
2223
if username:
2324
user = User.objects.get(username=username)
2425
else:

bats_ai/core/management/commands/registeronnxmodel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from django.conf import settings
44
import djclick as click
55
import mlflow
6-
import mlflow.onnx
6+
import mlflow.onnx as mlflow_onnx
77
import onnx
88

99

@@ -19,7 +19,7 @@ def command():
1919
with mlflow.start_run() as run:
2020
run_id = run.info.run_id
2121
click.echo(f'Run ID: {run_id}')
22-
mlflow.onnx.log_model(
22+
mlflow_onnx.log_model(
2323
onnx_model=onnx_model,
2424
artifact_path='onnx_model',
2525
# save_as_external_data=True,

bats_ai/core/management/commands/setupmlflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
@click.command()
99
def setupmlflow():
10-
db_name = settings.MLFLOW_PG_DB if settings.MLFLOW_PG_DB else 'mlflow'
10+
db_name = settings.MLFLOW_DB if settings.MLFLOW_DB else 'mlflow'
1111
bucket_name = settings.MLFLOW_BUCKET if settings.MLFLOW_BUCKET else 'mlflow'
1212

1313
click.echo(f'Creating database {db_name} for mlflow')

bats_ai/settings/development.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@
1515
staticfiles_index = INSTALLED_APPS.index('django.contrib.staticfiles')
1616
INSTALLED_APPS.insert(staticfiles_index, 'whitenoise.runserver_nostatic')
1717

18+
DATABASES = {
19+
**DATABASES,
20+
'mlflow': {
21+
**env.db_url('DJANGO_MLFLOW_DB_URL', engine='django.contrib.gis.db.backends.postgis'),
22+
'CONN_MAX_AGE': timedelta(minutes=10).total_seconds(),
23+
}
24+
}
25+
26+
MLFLOW_BUCKET: str = env.str('DJANGO_MLFLOW_BUCKET')
27+
MLFLOW_DB: str = env.str('DJANGO_MLFLOW_DB')
28+
MLFLOW_ENDPOINT: str = env.str('DJANGO_MLFLOW_ENDPOINT')
29+
1830
# Include Debug Toolbar middleware as early as possible in the list.
1931
# However, it must come after any other middleware that encodes the response's content,
2032
# such as GZipMiddleware.

bats_ai/tasks/tasks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,10 @@ def train_body(experiment_name: str):
264264
mlflow.log_metric('accuracy', accuracy)
265265
mlflow.set_tag('Training Info', 'Basic LR model for iris data')
266266

267+
print("ENV AWS_ACCESS_KEY_ID =", os.getenv("AWS_ACCESS_KEY_ID"))
268+
print("ENV AWS_SECRET_ACCESS_KEY =", os.getenv("AWS_SECRET_ACCESS_KEY"))
269+
print("ENV MLFLOW_S3_ENDPOINT_URL =", os.getenv("MLFLOW_S3_ENDPOINT_URL"))
270+
267271
signature = infer_signature(X_train, lr.predict(X_train))
268272
_ = mlflow.sklearn.log_model(
269273
sk_model=lr,

0 commit comments

Comments
 (0)