Skip to content

Commit 4bd2713

Browse files
seveinacdha
authored andcommitted
Clean up multiprocessing pools on failure
Ensure bagit-python reliably cleans up multiprocessing pools when parallel manifest generation or validation fails with processes > 1. Successful make_bag() calls already cleaned up correctly through the normal graceful pool path. The gap was failures inside Pool.map(): worker exceptions skipped close() and join(), allowing BagIt child processes to remain alive after the caller received the error. In long-running services, those leftover children can later become orphaned or defunct when the owning worker exits. This follows up on c451b24 ("Wait for validation Pool to finish"), the validation-pool cleanup work Douglas and I did, by extending the same reliability expectation to failure paths.
1 parent 462f6b0 commit 4bd2713

2 files changed

Lines changed: 60 additions & 11 deletions

File tree

src/bagit/__init__.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -896,12 +896,12 @@ def _validate_entries(self, processes):
896896
if processes == 1:
897897
hash_results = [_calc_hashes(i) for i in args]
898898
else:
899-
pool = multiprocessing.Pool(
900-
processes if processes else None, initializer=worker_init
899+
hash_results = _multiprocessing_pool_map(
900+
_calc_hashes,
901+
args,
902+
processes if processes else None,
903+
initializer=worker_init,
901904
)
902-
hash_results = pool.map(_calc_hashes, args)
903-
pool.close()
904-
pool.join()
905905

906906
# Any unhandled exceptions are probably fatal
907907
except:
@@ -1037,6 +1037,25 @@ def posix_multiprocessing_worker_initializer():
10371037
signal.signal(signal.SIGINT, signal.SIG_IGN)
10381038

10391039

1040+
def _multiprocessing_pool_map(func, iterable, processes, initializer=None):
1041+
"""Run ``Pool.map()`` and always clean up the pool.
1042+
1043+
This ensures worker processes are closed or terminated, then joined, under
1044+
all conditions.
1045+
"""
1046+
pool = multiprocessing.Pool(processes=processes, initializer=initializer)
1047+
try:
1048+
results = pool.map(func, iterable)
1049+
except BaseException:
1050+
pool.terminate()
1051+
raise
1052+
else:
1053+
pool.close()
1054+
return results
1055+
finally:
1056+
pool.join()
1057+
1058+
10401059
# The Unicode normalization form used here doesn't matter – all we care about
10411060
# is consistency since the input value will be preserved:
10421061

@@ -1245,10 +1264,9 @@ def make_manifests(data_dir, processes, algorithms=DEFAULT_CHECKSUMS, encoding="
12451264
manifest_line_generator = partial(generate_manifest_lines, algorithms=algorithms)
12461265

12471266
if processes > 1:
1248-
pool = multiprocessing.Pool(processes=processes)
1249-
checksums = pool.map(manifest_line_generator, _walk(data_dir))
1250-
pool.close()
1251-
pool.join()
1267+
checksums = _multiprocessing_pool_map(
1268+
manifest_line_generator, _walk(data_dir), processes=processes
1269+
)
12521270
else:
12531271
checksums = [manifest_line_generator(i) for i in _walk(data_dir)]
12541272

test.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
import tempfile
1414
import unicodedata
1515
import unittest
16+
from io import StringIO
1617
from os.path import join as j
17-
1818
from unittest import mock
19-
from io import StringIO
2019

2120
import bagit
2221

@@ -458,6 +457,23 @@ def validate(self, bag, *args, **kwargs):
458457
bag, *args, processes=2, **kwargs
459458
)
460459

460+
@mock.patch("bagit.multiprocessing.Pool")
461+
def test_validate_multiprocessing_terminates_and_joins_pool_on_failure(self, pool):
462+
pool.return_value.map.side_effect = RuntimeError("boom")
463+
bag = bagit.make_bag(self.tmpdir)
464+
465+
with self.assertRaises(RuntimeError):
466+
self.validate(bag)
467+
468+
self.assertEqual(
469+
pool.return_value.mock_calls,
470+
[
471+
mock.call.map(mock.ANY, mock.ANY),
472+
mock.call.terminate(),
473+
mock.call.join(),
474+
],
475+
)
476+
461477
@mock.patch("bagit.multiprocessing.Pool")
462478
def test_validate_pool_error(self, pool):
463479
# Simulate the Pool constructor raising a RuntimeError.
@@ -745,6 +761,21 @@ def test_make_bag_multiprocessing(self):
745761
bagit.make_bag(self.tmpdir, processes=2)
746762
self.assertTrue(os.path.isdir(j(self.tmpdir, "data")))
747763

764+
@mock.patch("bagit.multiprocessing.Pool")
765+
def test_make_bag_multiprocessing_terminates_and_joins_pool_on_failure(self, pool):
766+
pool.return_value.map.side_effect = RuntimeError("boom")
767+
with self.assertRaises(RuntimeError):
768+
bagit.make_bag(self.tmpdir, processes=2)
769+
770+
self.assertEqual(
771+
pool.return_value.mock_calls,
772+
[
773+
mock.call.map(mock.ANY, mock.ANY),
774+
mock.call.terminate(),
775+
mock.call.join(),
776+
],
777+
)
778+
748779
def test_multiple_meta_values(self):
749780
baginfo = {"Multival-Meta": [7, 4, 8, 6, 8]}
750781
bag = bagit.make_bag(self.tmpdir, baginfo)

0 commit comments

Comments
 (0)