Skip to content
Draft
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
TaskGraph Release History
=========================

..
Unreleased Changes
------------------
Unreleased Changes
------------------
* When using ``n_workers >= 1``, the ``TaskGraph`` object will now monitor the
underlying ``multiprocessing.Pool`` object for any changes to the PIDs of its
processes. If a change is detected, the graph is shut down to avoid a
deadlock. https://github.com/natcap/taskgraph/issues/109


0.11.2 (2025-05-21)
-------------------
Expand Down
45 changes: 45 additions & 0 deletions taskgraph/Task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import sqlite3
import threading
import time

try:
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version
Expand Down Expand Up @@ -396,8 +397,16 @@ def __init__(
target=_logging_queue_monitor,
args=(self._logging_queue,))

self._process_pool_monitor_wait_event = threading.Event()
self._process_pool_monitor_thread = threading.Thread(
target=self._process_pool_monitor,
args=(self._process_pool_monitor_wait_event,))

self._logging_monitor_thread.daemon = True
self._logging_monitor_thread.start()
self._process_pool_monitor_thread.daemon = True
self._process_pool_monitor_thread.start()

if HAS_PSUTIL:
parent = psutil.Process()
parent.nice(PROCESS_LOW_PRIORITY)
Expand Down Expand Up @@ -763,6 +772,42 @@ def _execution_monitor(self, monitor_wait_event):
(time.time() - start_time)) % self._reporting_interval)
LOGGER.debug("_execution monitor shutting down")

def _process_pool_monitor(self, pool_monitor_wait_event):
    """Monitor the state of the multiprocessing pool's workers.

    Python's multiprocessing.Pool has a bunch of logic to make sure that
    the pool always has the same number of workers, and it can even limit
    the lifespan of the pool's worker processes. In our case, worker
    processes have multiprocessing.Event objects on them, which means that
    if a worker process dies for any reason, the whole TaskGraph object
    will hang. This monitor thread polls the PIDs of the
    multiprocessing.Pool's worker processes and terminates the graph as
    soon as they differ from the PIDs observed when the thread started.

    The thread exits when the graph is marked as terminated, or
    immediately after it has triggered termination itself.

    Args:
        pool_monitor_wait_event (threading.Event): used as an
            interruptible 0.5 second sleep between polls. Setting the
            event only wakes the thread early; it does not stop it.

    Returns:
        None
    """
    # Snapshot of the worker PIDs at thread start; any later difference
    # means a worker exited (and was possibly replaced by the pool).
    # NOTE(review): ``_pool`` is a private attribute of
    # ``multiprocessing.Pool``; confirm on Python upgrades.
    starting_pool_pids = {proc.pid for proc in self._worker_pool._pool}

    while not self._terminated:
        current_pids = {proc.pid for proc in self._worker_pool._pool}

        if current_pids != starting_pool_pids:
            LOGGER.error(
                "A change in process pool PIDs has been detected! "
                "Shutting down the task graph. "
                f"{starting_pool_pids} changed to {current_pids }")
            self._terminate()
            # The monitor's job is done once termination is requested.
            # Leaving now avoids a pointless extra sleep and guarantees
            # the error is reported (and the graph terminated) only
            # once, without relying on ``_terminate`` to flip
            # ``self._terminated``.
            break

        # Wait 0.5s before looping.
        pool_monitor_wait_event.wait(timeout=0.5)
    LOGGER.debug('_process_pool_monitor shutting down')

def join(self, timeout=None):
"""Join all threads in the graph.

Expand Down
33 changes: 33 additions & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pickle
import re
import shutil
import signal
import sqlite3
import subprocess
import tempfile
Expand Down Expand Up @@ -140,6 +141,19 @@ def _log_from_another_process(logger_name, log_message):
logger.info(log_message)


def _kill_current_process():
    """Send SIGTERM to the process this function is running in.

    Intended to be run as a taskgraph task inside a worker process so
    that the worker dies mid-task.

    Raises:
        AssertionError: if this module is being executed as ``__main__``.
    """
    running_as_main = (__name__ == '__main__')
    if running_as_main:
        raise AssertionError(
            "This function is only supposed to be called in a subprocess")

    # signal.SIGTERM is accepted by os.kill on both *NIX and Windows.
    own_pid = os.getpid()
    os.kill(own_pid, signal.SIGTERM)


class TaskGraphTests(unittest.TestCase):
"""Tests for the taskgraph."""

Expand Down Expand Up @@ -1480,6 +1494,25 @@ def test_mtime_mismatch(self):
with open(target_path) as target_file:
self.assertEqual(target_file.read(), content)

def test_multiprocessing_deadlock(self):
    """Check that the graph shuts down when a worker process is killed.

    A single task that kills its own process is added to a one-worker
    graph. The graph is expected to log exactly one ERROR record about
    the change in pool PIDs. If the shutdown-on-killed-worker behavior
    is missing, this test deadlocks rather than failing.

    See https://github.com/natcap/taskgraph/issues/109
    """
    expected_message = 'A change in process pool PIDs has been detected!'

    graph = taskgraph.TaskGraph(self.workspace_dir, n_workers=1)
    with self.assertLogs('taskgraph', level='ERROR') as captured_logs:
        graph.add_task(_kill_current_process)
        graph.join()
        graph.close()

    error_lines = captured_logs.output
    self.assertEqual(len(error_lines), 1)

    only_line = error_lines[0]
    self.assertTrue(only_line.startswith('ERROR'))
    self.assertIn(expected_message, only_line)


def Fail(n_tries, result_path):
"""Create a function that fails after ``n_tries``."""
Expand Down
Loading