diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py index d9d8388055af9..497461512b69b 100644 --- a/airflow-core/src/airflow/cli/commands/dag_command.py +++ b/airflow-core/src/airflow/cli/commands/dag_command.py @@ -45,6 +45,7 @@ from airflow.models import DagModel, DagRun, TaskInstance from airflow.models.errors import ParseImportError from airflow.models.serialized_dag import SerializedDagModel +from airflow.models.taskinstance import clear_task_instances from airflow.timetables.base import TimeRestriction from airflow.utils import cli as cli_utils from airflow.utils.cli import ( @@ -58,7 +59,7 @@ from airflow.utils.platform import getuser from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.state import DagRunState +from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunType if TYPE_CHECKING: @@ -75,6 +76,9 @@ log = logging.getLogger(__name__) +# Maximum number of run_ids handled per chunk when bulk-clearing task instances. 
+_RUN_CHUNK_SIZE = 500 + @cli_utils.action_cli @providers_configuration_loaded @@ -195,15 +199,47 @@ def dag_clear(args, *, session: Session = NEW_SESSION) -> None: print("Cancelled, nothing was cleared.") return + cleared = _bulk_clear_runs( + args.dag_id, + run_ids, + only_failed=args.only_failed, + only_running=args.only_running, + session=session, + ) + print(f"Cleared {cleared} task instance(s) across {len(run_ids)} Dag run(s).") + + +def _bulk_clear_runs( + dag_id: str, + run_ids: list[str], + only_failed: bool, + only_running: bool, + session: Session, +) -> int: + """Clear the task instances of the given run_ids chunk by chunk and return how many were cleared.""" + state_filter: list = [] + if only_failed: + state_filter += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED] + if only_running: + state_filter += [TaskInstanceState.RUNNING] + cleared = 0 - for run_id in run_ids: - cleared += dag.clear( - run_id=run_id, - only_failed=args.only_failed, - only_running=args.only_running, - session=session, + for chunk_start in range(0, len(run_ids), _RUN_CHUNK_SIZE): + chunk_run_ids = run_ids[chunk_start : chunk_start + _RUN_CHUNK_SIZE] + ti_query = select(TaskInstance).where( + TaskInstance.dag_id == dag_id, + TaskInstance.run_id.in_(chunk_run_ids), ) - print(f"Cleared {cleared} task instance(s) across {len(run_ids)} Dag run(s).") + if state_filter: + ti_query = ti_query.where(TaskInstance.state.in_(state_filter)) + tis = session.scalars(ti_query).all() + if not tis: + continue + clear_task_instances(list(tis), session=session) + session.flush() + cleared += len(tis) + + return cleared @cli_utils.action_cli diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py index 6f76034a4218d..d243547ec988a 100644 --- a/airflow-core/tests/unit/cli/commands/test_dag_command.py +++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py @@ -1267,6 +1267,13 @@ def _get_run_states(self): for row in 
session.scalars(select(DagRun).where(DagRun.dag_id == self.DAG_ID)).all() } + def _get_run_clear_numbers(self): + with create_session() as session: + return { + row.run_id: row.clear_number + for row in session.scalars(select(DagRun).where(DagRun.dag_id == self.DAG_ID)).all() + } + def test_requires_a_selector(self, parser): args = parser.parse_args(["dags", "clear", self.DAG_ID, "--yes"]) with pytest.raises(SystemExit, match="One of --run-id, --partition-key"): @@ -1959,6 +1966,143 @@ def test_asset_timetable_upper_bound_over_cap(self, parser): assert states["asset_2026_04_15"] == DagRunState.SUCCESS assert states["asset_non_part"] == DagRunState.SUCCESS + @pytest.mark.usefixtures("seeded_partitioned_runs") + def test_clears_multiple_runs_in_one_batch(self, parser): + """3 runs fit in one chunk, so clear_task_instances is called once (not N times).""" + from airflow.models.taskinstance import clear_task_instances + + call_count = 0 + + def counting_clear(tis, session, **kwargs): + nonlocal call_count + call_count += 1 + return clear_task_instances(tis, session, **kwargs) + + args = parser.parse_args( + [ + "dags", + "clear", + self.DAG_ID, + "--partition-date-start", + "2026-03-08T00:00:00", + "--partition-date-end", + "2026-03-14T00:00:00", + "--yes", + ] + ) + with mock.patch( + "airflow.cli.commands.dag_command.clear_task_instances", + side_effect=counting_clear, + ): + dag_command.dag_clear(args) + + # 3 partitioned runs all fit in a single run-id chunk. 
+ assert call_count == 1 + + states = self._get_run_states() + assert states["part_2026_03_08"] == DagRunState.QUEUED + assert states["part_2026_03_10"] == DagRunState.QUEUED + assert states["part_2026_03_14"] == DagRunState.QUEUED + assert states["non_partitioned"] == DagRunState.SUCCESS + + @pytest.mark.usefixtures("seeded_partitioned_runs") + def test_chunks_on_run_boundaries_clears_each_run_once(self, parser): + """Across multiple chunks, each run is cleared once""" + from airflow.models.taskinstance import clear_task_instances + + call_count = 0 + + def counting_clear(tis, session, **kwargs): + nonlocal call_count + call_count += 1 + return clear_task_instances(tis, session, **kwargs) + + args = parser.parse_args( + [ + "dags", + "clear", + self.DAG_ID, + "--partition-date-start", + "2026-03-08T00:00:00", + "--partition-date-end", + "2026-03-14T00:00:00", + "--yes", + ] + ) + with ( + mock.patch.object(dag_command, "_RUN_CHUNK_SIZE", 2), + mock.patch( + "airflow.cli.commands.dag_command.clear_task_instances", + side_effect=counting_clear, + ), + ): + dag_command.dag_clear(args) + + # 3 runs with chunk size 2 → 2 calls. + assert call_count == 2 + + clear_numbers = self._get_run_clear_numbers() + assert clear_numbers["part_2026_03_08"] == 1 + assert clear_numbers["part_2026_03_10"] == 1 + assert clear_numbers["part_2026_03_14"] == 1 + assert clear_numbers["non_partitioned"] == 0 + + @pytest.mark.usefixtures("seeded_partitioned_runs") + def test_does_not_clear_runs_of_other_dags(self, parser, dag_maker): + """A run_id collision across DAGs must not clear the other DAG's task instances.""" + other_dag_id = "test_dags_clear_other_dag" + with dag_maker( + other_dag_id, + schedule=CronPartitionTimetable("0 0 * * *", timezone=pendulum.UTC), + start_date=datetime(2026, 3, 1, tzinfo=pendulum.UTC), + catchup=True, + serialized=True, + ): + EmptyOperator(task_id="t1") + # Same run_id and partition_date as a run cleared below, but a different DAG. 
+ dag_maker.create_dagrun( + run_id="part_2026_03_08", + state=DagRunState.SUCCESS, + logical_date=None, + partition_date=datetime(2026, 3, 8, tzinfo=pendulum.UTC), + partition_key="2026-03-08T00:00:00", + ) + dag_maker.sync_dagbag_to_db() + # If dag_id is not filtered, clearing the other DAG would reset this TI to None. + with create_session() as session: + session.execute( + TaskInstance.__table__.update() + .where(TaskInstance.dag_id == other_dag_id) + .values(state=TaskInstanceState.SUCCESS) + ) + + args = parser.parse_args( + [ + "dags", + "clear", + self.DAG_ID, + "--partition-date-start", + "2026-03-08T00:00:00", + "--partition-date-end", + "2026-03-14T00:00:00", + "--yes", + ] + ) + dag_command.dag_clear(args) + + # The target DAG's same-named run must be cleared. + assert self._get_run_states()["part_2026_03_08"] == DagRunState.QUEUED + + # The other DAG's same-named run must be left untouched. + with create_session() as session: + other_run = session.scalars( + select(DagRun).where(DagRun.dag_id == other_dag_id, DagRun.run_id == "part_2026_03_08") + ).one() + assert other_run.state == DagRunState.SUCCESS + assert other_run.clear_number == 0 + other_ti = session.scalars(select(TaskInstance).where(TaskInstance.dag_id == other_dag_id)).one() + assert other_ti.state == TaskInstanceState.SUCCESS + class TestDagDetailsIsBackfillable: """Tests for the is_backfillable computation in _get_dagbag_dag_details."""