Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions airflow-core/src/airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from airflow.models import DagModel, DagRun, TaskInstance
from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import clear_task_instances
from airflow.timetables.base import TimeRestriction
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
Expand All @@ -58,7 +59,7 @@
from airflow.utils.platform import getuser
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunType

if TYPE_CHECKING:
Expand All @@ -75,6 +76,9 @@

log = logging.getLogger(__name__)

# Chunk size for bulk delete.
_RUN_CHUNK_SIZE = 500


@cli_utils.action_cli
@providers_configuration_loaded
Expand Down Expand Up @@ -195,15 +199,47 @@ def dag_clear(args, *, session: Session = NEW_SESSION) -> None:
print("Cancelled, nothing was cleared.")
return

cleared = _bulk_clear_runs(
args.dag_id,
run_ids,
only_failed=args.only_failed,
only_running=args.only_running,
session=session,
)
print(f"Cleared {cleared} task instance(s) across {len(run_ids)} Dag run(s).")


def _bulk_clear_runs(
dag_id: str,
run_ids: list[str],
only_failed: bool,
only_running: bool,
session: Session,
) -> int:
"""Clear task instances for the given run_ids in chunks instead of one transaction per run."""
state_filter: list = []
if only_failed:
state_filter += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED]
if only_running:
state_filter += [TaskInstanceState.RUNNING]

cleared = 0
for run_id in run_ids:
cleared += dag.clear(
run_id=run_id,
only_failed=args.only_failed,
only_running=args.only_running,
session=session,
for chunk_start in range(0, len(run_ids), _RUN_CHUNK_SIZE):
chunk_run_ids = run_ids[chunk_start : chunk_start + _RUN_CHUNK_SIZE]
ti_query = select(TaskInstance).where(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id.in_(chunk_run_ids),
)
print(f"Cleared {cleared} task instance(s) across {len(run_ids)} Dag run(s).")
if state_filter:
ti_query = ti_query.where(TaskInstance.state.in_(state_filter))
tis = session.scalars(ti_query).all()
if not tis:
continue
clear_task_instances(list(tis), session=session)
session.flush()
cleared += len(tis)

return cleared


@cli_utils.action_cli
Expand Down
144 changes: 144 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,13 @@ def _get_run_states(self):
for row in session.scalars(select(DagRun).where(DagRun.dag_id == self.DAG_ID)).all()
}

def _get_run_clear_numbers(self):
with create_session() as session:
return {
row.run_id: row.clear_number
for row in session.scalars(select(DagRun).where(DagRun.dag_id == self.DAG_ID)).all()
}

def test_requires_a_selector(self, parser):
args = parser.parse_args(["dags", "clear", self.DAG_ID, "--yes"])
with pytest.raises(SystemExit, match="One of --run-id, --partition-key"):
Expand Down Expand Up @@ -1964,6 +1971,143 @@ def test_asset_timetable_upper_bound_over_cap(self, parser):
assert states["asset_2026_04_15"] == DagRunState.SUCCESS
assert states["asset_non_part"] == DagRunState.SUCCESS

@pytest.mark.usefixtures("seeded_partitioned_runs")
def test_clears_multiple_runs_in_one_batch(self, parser):
"""3 runs fit in one chunk, so clear_task_instances is called once (not N times)."""
from airflow.models.taskinstance import clear_task_instances

call_count = 0

def counting_clear(tis, session, **kwargs):
nonlocal call_count
call_count += 1
return clear_task_instances(tis, session, **kwargs)

args = parser.parse_args(
[
"dags",
"clear",
self.DAG_ID,
"--partition-date-start",
"2026-03-08T00:00:00",
"--partition-date-end",
"2026-03-14T00:00:00",
"--yes",
]
)
with mock.patch(
"airflow.cli.commands.dag_command.clear_task_instances",
side_effect=counting_clear,
):
dag_command.dag_clear(args)

# 3 partitioned runs all fit in a single run-id chunk.
assert call_count == 1

states = self._get_run_states()
assert states["part_2026_03_08"] == DagRunState.QUEUED
assert states["part_2026_03_10"] == DagRunState.QUEUED
assert states["part_2026_03_14"] == DagRunState.QUEUED
assert states["non_partitioned"] == DagRunState.SUCCESS

@pytest.mark.usefixtures("seeded_partitioned_runs")
def test_chunks_on_run_boundaries_clears_each_run_once(self, parser):
"""Across multiple chunks, each run is cleared once"""
from airflow.models.taskinstance import clear_task_instances

call_count = 0

def counting_clear(tis, session, **kwargs):
nonlocal call_count
call_count += 1
return clear_task_instances(tis, session, **kwargs)

args = parser.parse_args(
[
"dags",
"clear",
self.DAG_ID,
"--partition-date-start",
"2026-03-08T00:00:00",
"--partition-date-end",
"2026-03-14T00:00:00",
"--yes",
]
)
with (
mock.patch.object(dag_command, "_RUN_CHUNK_SIZE", 2),
mock.patch(
"airflow.cli.commands.dag_command.clear_task_instances",
side_effect=counting_clear,
),
):
dag_command.dag_clear(args)

# 3 runs with chunk size 2 → 2 calls.
assert call_count == 2

clear_numbers = self._get_run_clear_numbers()
assert clear_numbers["part_2026_03_08"] == 1
assert clear_numbers["part_2026_03_10"] == 1
assert clear_numbers["part_2026_03_14"] == 1
assert clear_numbers["non_partitioned"] == 0

@pytest.mark.usefixtures("seeded_partitioned_runs")
def test_does_not_clear_runs_of_other_dags(self, parser, dag_maker):
"""A run_id collision across DAGs must not clear the other DAG's task instances."""
other_dag_id = "test_dags_clear_other_dag"
with dag_maker(
other_dag_id,
schedule=CronPartitionTimetable("0 0 * * *", timezone=pendulum.UTC),
start_date=datetime(2026, 3, 1, tzinfo=pendulum.UTC),
catchup=True,
serialized=True,
):
EmptyOperator(task_id="t1")
# Same run_id and partition_date as a run cleared below, but a different DAG.
dag_maker.create_dagrun(
run_id="part_2026_03_08",
state=DagRunState.SUCCESS,
logical_date=None,
partition_date=datetime(2026, 3, 8, tzinfo=pendulum.UTC),
partition_key="2026-03-08T00:00:00",
)
dag_maker.sync_dagbag_to_db()
# If dag_id is not filtered, clearing the other DAG would reset this TI to None.
with create_session() as session:
session.execute(
TaskInstance.__table__.update()
.where(TaskInstance.dag_id == other_dag_id)
.values(state=TaskInstanceState.SUCCESS)
)

args = parser.parse_args(
[
"dags",
"clear",
self.DAG_ID,
"--partition-date-start",
"2026-03-08T00:00:00",
"--partition-date-end",
"2026-03-14T00:00:00",
"--yes",
]
)
dag_command.dag_clear(args)

# The target DAG's same-named run must be cleared.
assert self._get_run_states()["part_2026_03_08"] == DagRunState.QUEUED

# The other DAG's same-named run must be left untouched.
with create_session() as session:
other_run = session.scalars(
select(DagRun).where(DagRun.dag_id == other_dag_id, DagRun.run_id == "part_2026_03_08")
).one()
assert other_run.state == DagRunState.SUCCESS
assert other_run.clear_number == 0
other_ti = session.scalars(select(TaskInstance).where(TaskInstance.dag_id == other_dag_id)).one()
assert other_ti.state == TaskInstanceState.SUCCESS


class TestDagDetailsIsBackfillable:
"""Tests for the is_backfillable computation in _get_dagbag_dag_details."""
Expand Down
Loading