-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathairflow.py
More file actions
39 lines (30 loc) · 1.37 KB
/
airflow.py
File metadata and controls
39 lines (30 loc) · 1.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import copy
import logging

from sync.databricks.integrations._run_submit_runner import (
    apply_sync_gradient_cluster_recommendation,
)
# Module-scoped logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)
def airflow_gradient_pre_execute_hook(context: dict) -> None:
    """Apply a Gradient cluster recommendation to a Databricks run-submit task.

    Reads the Gradient settings from ``context["params"]`` and passes a deep
    copy of ``context["task"].json`` to
    ``apply_sync_gradient_cluster_recommendation``. On success, it replaces
    ``context["task"].json`` with the returned configuration.

    The hook is best-effort. Any exception (including missing context keys)
    is logged with its traceback and swallowed, so the hook never raises to
    its caller. On failure, ``context["task"].json`` is left unmodified.

    Args:
        context: Airflow task context. It must contain ``"task"`` (an object
            exposing ``task_id`` and a ``json`` attribute; presumably a
            Databricks run-submit operator, so confirm against callers) and
            ``"params"`` with the keys ``gradient_app_id``,
            ``gradient_auto_apply``, ``cluster_log_url`` and
            ``databricks_workspace_id``.
    """
    try:
        logger.info("Running airflow gradient pre-execute hook!")
        # Lazy %-formatting: the context can be large, so it is only rendered
        # when debug logging is actually enabled.
        logger.debug("Airflow operator context - context:%s", context)

        task = context["task"]
        params = context["params"]

        gradient_app_id = params["gradient_app_id"]
        auto_apply = params["gradient_auto_apply"]
        cluster_log_url = params["cluster_log_url"]
        workspace_id = params["databricks_workspace_id"]

        # Deep copy the run-submit json. A shallow copy would still share
        # nested dicts/lists with the operator's original json, so any
        # mutation by the recommendation call would leak into the task even
        # when that call fails and the error is swallowed below.
        run_submit_task = copy.deepcopy(task.json)

        updated_task_configuration = apply_sync_gradient_cluster_recommendation(
            run_submit_task=run_submit_task,
            gradient_app_id=build_app_id(task.task_id, gradient_app_id),
            auto_apply=auto_apply,
            cluster_log_url=cluster_log_url,
            workspace_id=workspace_id,
        )
        # Only swap in the new configuration once everything above succeeded.
        task.json = updated_task_configuration
    except Exception:
        # Deliberately best-effort: a Gradient failure must not break the task.
        # logger.exception records the message and the traceback together.
        logger.exception(
            "Unable to apply gradient configuration to Databricks run submit tasks"
        )
def build_app_id(task_id: str, app_id: str) -> str:
    """Return the task-scoped Gradient app id, formatted as ``"<task_id>-<app_id>"``."""
    template = "{}-{}"
    return template.format(task_id, app_id)