29 commits
f607d24
adding download task and creating seperate download pool
HansVRP Nov 26, 2025
7478990
include initial unit testing
HansVRP Nov 28, 2025
8e3ae8b
updated unit tests
HansVRP Dec 10, 2025
24989cf
including two simple unit tests and unifying pool usage
HansVRP Dec 11, 2025
3293327
changes to job manager
HansVRP Dec 11, 2025
12277de
adding easy callback to check number of pending tasks on thread worke…
HansVRP Dec 11, 2025
bade858
process updates through job update loop
HansVRP Dec 11, 2025
14585c9
remove folder creation logic from thread to resprect optional downloa…
HansVRP Dec 11, 2025
d0a7fbf
fix stop_job_thread
HansVRP Dec 11, 2025
d4d0110
working on fix for indefinete loop
HansVRP Dec 11, 2025
086a30b
fix infinite loop
HansVRP Dec 11, 2025
2603c30
wrapper to abstract multiple threadpools
HansVRP Dec 15, 2025
188ab5d
coupling task type to seperate pool
HansVRP Dec 15, 2025
24cf000
include unit test for dict of pools
HansVRP Dec 15, 2025
8a3aa20
tmp_path usage and renaming
HansVRP Dec 16, 2025
2e0d008
fix documentation
HansVRP Dec 16, 2025
1894ef6
keep track of number of assets
HansVRP Dec 19, 2025
1d2020b
avoid abreviation of number
HansVRP Dec 19, 2025
62a4bf4
do not expose number of remaining jobs
HansVRP Dec 19, 2025
fdcd047
abstract task name in thread pool
HansVRP Dec 19, 2025
9b978d3
not use remaing in unit test
HansVRP Dec 19, 2025
d629637
fix unit tests
HansVRP Dec 19, 2025
2ab7741
fix
HansVRP Dec 19, 2025
c999d1f
move towards get_results to avoid deprecation
HansVRP Jan 6, 2026
3a215b5
Merge remote-tracking branch 'origin/master' into issue816_threadeddo…
soxofaan Jan 23, 2026
87a32e5
re-adding len(to_keep)
HansVRP Jan 27, 2026
8c004c2
improved tracking of pending postprocessing in for future
HansVRP Jan 28, 2026
6f29682
clean some text
HansVRP Jan 28, 2026
09df02a
wait slightly longer
HansVRP Jan 28, 2026
57 changes: 42 additions & 15 deletions openeo/extra/job_management/_manager.py
@@ -32,6 +32,7 @@
from openeo.extra.job_management._thread_worker import (
_JobManagerWorkerThreadPool,
_JobStartTask,
_JobDownloadTask
)
from openeo.rest import OpenEoApiError
from openeo.rest.auth.auth import BearerAuth
@@ -175,6 +176,7 @@ def start_job(

.. versionchanged:: 0.47.0
Added ``download_results`` parameter.

"""

# Expected columns in the job DB dataframes.
@@ -373,6 +375,9 @@ def run_loop():
).values()
)
> 0

or (self._worker_pool is not None and self._worker_pool.has_unprocessed_tasks())

and not self._stop_thread
):
self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
@@ -398,7 +403,10 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET):

.. versionadded:: 0.32.0
"""
self._worker_pool.shutdown()
if self._worker_pool is not None:
self._worker_pool.shutdown()
self._worker_pool = None


if self._thread is not None:
self._stop_thread = True
@@ -504,13 +512,15 @@ def run_jobs(

self._worker_pool = _JobManagerWorkerThreadPool()


while (
sum(
job_db.count_by_status(
statuses=["not_started", "created", "queued_for_start", "queued", "running"]
).values()
)
> 0
).values()) > 0

or (self._worker_pool is not None and self._worker_pool.has_unprocessed_tasks())

):
self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
stats["run_jobs loop"] += 1
@@ -520,8 +530,10 @@
time.sleep(self.poll_sleep)
stats["sleep"] += 1

# TODO: run post-processing after shutdown once more to ensure completion?


self._worker_pool.shutdown()
self._worker_pool = None

return stats

@@ -567,7 +579,9 @@ def _job_update_loop(
stats["job_db persist"] += 1
total_added += 1

self._process_threadworker_updates(self._worker_pool, job_db=job_db, stats=stats)
if self._worker_pool is not None:
self._process_threadworker_updates(worker_pool=self._worker_pool, job_db=job_db, stats=stats)


# TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
for job, row in jobs_done:
@@ -579,6 +593,7 @@
for job, row in jobs_cancel:
self.on_job_cancel(job, row)


def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None):
"""Helper method for launching jobs

@@ -643,7 +658,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None):
df_idx=i,
)
_log.info(f"Submitting task {task} to thread pool")
self._worker_pool.submit_task(task)
self._worker_pool.submit_task(task=task, pool_name="job_start")

stats["job_queued_for_start"] += 1
df.loc[i, "status"] = "queued_for_start"
@@ -689,7 +704,7 @@ def _process_threadworker_updates(
:param stats: Dictionary accumulating statistic counters
"""
# Retrieve completed task results immediately
results, _ = worker_pool.process_futures(timeout=0)
results, _ = worker_pool.process_futures(timeout=0)

# Collect update dicts
updates: List[Dict[str, Any]] = []
@@ -735,17 +750,28 @@ def on_job_done(self, job: BatchJob, row):
:param job: The job that has finished.
:param row: DataFrame row containing the job's metadata.
"""
# TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use?
if self._download_results:
job_metadata = job.describe()
job_dir = self.get_job_dir(job.job_id)
metadata_path = self.get_job_metadata_path(job.job_id)

job_dir = self.get_job_dir(job.job_id)
self.ensure_job_dir_exists(job.job_id)
job.get_results().download_files(target=job_dir)

with metadata_path.open("w", encoding="utf-8") as f:
json.dump(job_metadata, f, ensure_ascii=False)
# Proactively refresh the bearer token
job_con = job.connection
self._refresh_bearer_token(connection=job_con)

task = _JobDownloadTask(
job_id=job.job_id,
df_idx=row.name,
root_url=job_con.root_url,
bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
download_dir=job_dir,
)
_log.info(f"Submitting download task {task} to download thread pool")

if self._worker_pool is None:
self._worker_pool = _JobManagerWorkerThreadPool()

self._worker_pool.submit_task(task=task, pool_name="job_download")

def on_job_error(self, job: BatchJob, row):
"""
@@ -797,6 +823,7 @@ def _cancel_prolonged_job(self, job: BatchJob, row):
except Exception as e:
_log.error(f"Unexpected error while handling job {job.job_id}: {e}")

# TODO: pull this functionality away from the manager to a general utility class? Job dir creation could be reused for the _JobDownloadTask.
def get_job_dir(self, job_id: str) -> Path:
"""Path to directory where job metadata, results and error logs are be saved."""
return self._root_dir / f"job_{job_id}"
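The on_job_done changes above replace the inline download with a task handed to a dedicated pool: the method proactively refreshes the bearer token, packages the connection details and target directory into a _JobDownloadTask, and submits it under the "job_download" pool name. Below is a minimal sketch of that flow using the internal classes added in this PR; the job id, backend URL and paths are hypothetical, and against an unreachable backend the task simply resolves to a "job download error" result rather than raising.

from pathlib import Path

from openeo.extra.job_management._thread_worker import (
    _JobDownloadTask,
    _JobManagerWorkerThreadPool,
)

pool = _JobManagerWorkerThreadPool()
task = _JobDownloadTask(
    job_id="j-2025",                      # hypothetical job id
    df_idx=0,                             # row index in the job DB dataframe
    root_url="https://openeo.example",    # hypothetical backend root URL
    bearer_token=None,                    # or a bearer token for authenticated backends
    download_dir=Path("job_j-2025"),      # results and job_<id>.json are written here
)
pool.submit_task(task=task, pool_name="job_download")

# timeout=None blocks until all submitted tasks have finished
results, remaining_per_pool = pool.process_futures(timeout=None)
pool.shutdown()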
131 changes: 128 additions & 3 deletions openeo/extra/job_management/_thread_worker.py
@@ -7,7 +7,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path

import json
import urllib3.util

import openeo
@@ -99,7 +101,7 @@ def get_connection(self, retry: Union[urllib3.util.Retry, dict, bool, None] = None):
connection.authenticate_bearer_token(self.bearer_token)
return connection


@dataclass(frozen=True)
class _JobStartTask(ConnectedTask):
"""
Task for starting an openEO batch job (the `POST /jobs/<job_id>/result` request).
@@ -139,9 +141,51 @@ def execute(self) -> _TaskResult:
db_update={"status": "start_failed"},
stats_update={"start_job error": 1},
)

@dataclass(frozen=True)
class _JobDownloadTask(ConnectedTask):
"""
Task for downloading job results and metadata.

:param download_dir:
Root directory where job results and metadata will be downloaded.
"""
download_dir: Path = field(default=None, repr=False)

class _JobManagerWorkerThreadPool:
def execute(self) -> _TaskResult:

try:
job = self.get_connection(retry=True).job(self.job_id)

# Count assets (files to download)
file_count = len(job.get_results().get_assets())

# Download results
job.get_results().download_files(target=self.download_dir)

# Download metadata
job_metadata = job.describe()
metadata_path = self.download_dir / f"job_{self.job_id}.json"
with metadata_path.open("w", encoding="utf-8") as f:
json.dump(job_metadata, f, ensure_ascii=False)

_log.info(f"Job {self.job_id!r} results downloaded successfully")
return _TaskResult(
job_id=self.job_id,
df_idx=self.df_idx,
db_update={},  # TODO: consider db updates?
stats_update={"job download": 1, "files downloaded": file_count},
)
except Exception as e:
_log.error(f"Failed to download results for job {self.job_id!r}: {e!r}")
return _TaskResult(
job_id=self.job_id,
df_idx=self.df_idx,
db_update={},
stats_update={"job download error": 1, "files downloaded": 0},
)

class _TaskThreadPool:
"""
Thread pool-based worker that manages the execution of asynchronous tasks.

@@ -156,6 +200,8 @@ class _JobManagerWorkerThreadPool:
def __init__(self, max_workers: int = 2):
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = []
self._total_submitted = 0
self._total_processed = 0

def submit_task(self, task: Task) -> None:
"""
@@ -169,6 +215,8 @@ def submit_task(self, task: Task) -> None:
"""
future = self._executor.submit(task.execute)
self._future_task_pairs.append((future, task)) # Track pairs
self._total_submitted += 1


def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], int]:
"""
@@ -186,7 +234,11 @@ def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], int]:
results = []
to_keep = []

done, _ = concurrent.futures.wait([f for f, _ in self._future_task_pairs], timeout=timeout)
if not self._future_task_pairs:
return results, 0

futures = [f for f, _ in self._future_task_pairs]
done, _ = concurrent.futures.wait(futures, timeout=timeout)

for future, task in self._future_task_pairs:
if future in done:
@@ -206,9 +258,82 @@ def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], int]:
_log.info("process_futures: %d tasks done, %d tasks remaining", len(results), len(to_keep))

self._future_task_pairs = to_keep
self._total_processed += len(results)

return results, len(to_keep)

def get_unprocessed_count(self) -> int:
"""Get number of tasks that haven't been processed yet."""
return self._total_submitted - self._total_processed

def has_unprocessed_tasks(self) -> bool:
"""Check if there are tasks that haven't been processed yet."""
return self._total_submitted > self._total_processed

def shutdown(self) -> None:
"""Shuts down the thread pool gracefully."""
_log.info("Shutting down thread pool")
self._executor.shutdown(wait=True)


class _JobManagerWorkerThreadPool:
"""
Generic wrapper that manages multiple named thread pools, keyed by task type (e.g. "job_start", "job_download").
"""

def __init__(self, pool_configs: Optional[Dict[str, int]] = None):
self._pools: Dict[str, _TaskThreadPool] = {}
self._pool_configs = pool_configs or {}

def list_pools(self) -> List[str]:
"""List all active pool names."""
return list(self._pools.keys())

def submit_task(self, task: Task, pool_name: str = "default") -> None:
if pool_name not in self._pools:
max_workers = self._pool_configs.get(pool_name, 2)
self._pools[pool_name] = _TaskThreadPool(max_workers=max_workers)
_log.info(f"Created pool '{pool_name}' with {max_workers} workers")

self._pools[pool_name].submit_task(task)

def get_unprocessed_counts(self) -> Dict[str, int]:
"""Get unprocessed (submitted but not processed) task counts per pool."""
return {name: pool.get_unprocessed_count() for name, pool in self._pools.items()}

def has_unprocessed_tasks(self) -> bool:
"""Check if any pool has unprocessed (submitted but not processed) tasks."""
return any(pool.has_unprocessed_tasks() for pool in self._pools.values())

def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], Dict[str, int]]:
"""
Process completed futures from all pools.

:return: tuple of (list of all results, dict of remaining task counts per pool)
"""
all_results = []
to_keep = {}

for pool_name, pool in self._pools.items():
results, remaining = pool.process_futures(timeout)
all_results.extend(results)

to_keep[pool_name] = remaining

return all_results, to_keep

def shutdown(self, pool_name: Optional[str] = None) -> None:
"""
Shut down pools.

:param pool_name: name of a single pool to shut down; if None, all pools are shut down.
"""
if pool_name:
if pool_name in self._pools:
self._pools[pool_name].shutdown()
del self._pools[pool_name]
else:
for pool_name, pool in list(self._pools.items()):
pool.shutdown()
del self._pools[pool_name]



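The renamed _TaskThreadPool keeps the original single-pool behavior (plus the new submitted/processed counters), while _JobManagerWorkerThreadPool is now a thin router over named pools: a pool is created lazily on first submit, sized via pool_configs (defaulting to 2 workers), and all pools are swept and shut down collectively. A short sketch of that surface, assuming only the behavior shown in this diff:

from openeo.extra.job_management._thread_worker import _JobManagerWorkerThreadPool

# Cap the download pool at 4 workers; names missing from pool_configs default to 2.
pool = _JobManagerWorkerThreadPool(pool_configs={"job_download": 4})

print(pool.list_pools())              # [] -- pools only exist after the first submit_task
print(pool.has_unprocessed_tasks())   # False
print(pool.get_unprocessed_counts())  # {} -- per-pool "submitted minus processed" counts

# pool.submit_task(task=some_task, pool_name="job_download")  # would create the pool on demand
results, remaining_per_pool = pool.process_futures(timeout=0)  # non-blocking sweep of all pools
pool.shutdown()                       # no pool_name given -> shuts down every pool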
8 changes: 4 additions & 4 deletions tests/extra/job_management/test_manager.py
@@ -729,7 +729,7 @@ def get_status(job_id, current_status):
assert isinstance(rfc3339.parse_datetime(filled_running_start_time), datetime.datetime)

def test_process_threadworker_updates(self, tmp_path, caplog):
pool = _JobManagerWorkerThreadPool(max_workers=2)
pool = _JobManagerWorkerThreadPool()
stats = collections.defaultdict(int)

# Submit tasks covering all cases
@@ -769,7 +769,7 @@ def test_process_threadworker_updates(self, tmp_path, caplog):
assert caplog.messages == []

def test_process_threadworker_updates_unknown(self, tmp_path, caplog):
pool = _JobManagerWorkerThreadPool(max_workers=2)
pool = _JobManagerWorkerThreadPool()
stats = collections.defaultdict(int)

pool.submit_task(DummyResultTask("j-123", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1}))
@@ -806,7 +806,7 @@ def test_process_threadworker_updates_unknown(self, tmp_path, caplog):
assert caplog.messages == [dirty_equals.IsStr(regex=".*Ignoring unknown.*indices.*4.*")]

def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
pool = _JobManagerWorkerThreadPool(max_workers=2)
pool = _JobManagerWorkerThreadPool()
stats = collections.defaultdict(int)

df_initial = pd.DataFrame({"id": ["j-0"], "status": ["created"]})
@@ -820,7 +820,7 @@ def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
assert stats == {}

def test_logs_on_invalid_update(self, tmp_path, caplog):
pool = _JobManagerWorkerThreadPool(max_workers=2)
pool = _JobManagerWorkerThreadPool()
stats = collections.defaultdict(int)

# Malformed db_update (not a dict unpackable via **)
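Finally, a hedged sketch of how the new submitted/processed bookkeeping could be unit-tested at the _TaskThreadPool level. It assumes Task subclasses only need an execute() returning a _TaskResult, as the task implementations above suggest; the _NoopTask helper is hypothetical and not part of this PR.

from dataclasses import dataclass

from openeo.extra.job_management._thread_worker import (
    Task,
    _TaskResult,
    _TaskThreadPool,
)


@dataclass(frozen=True)
class _NoopTask(Task):
    """Hypothetical task that succeeds without touching any backend."""
    job_id: str
    df_idx: int

    def execute(self) -> _TaskResult:
        return _TaskResult(
            job_id=self.job_id,
            df_idx=self.df_idx,
            db_update={},
            stats_update={"noop": 1},
        )


def test_unprocessed_task_tracking():
    pool = _TaskThreadPool(max_workers=1)
    pool.submit_task(_NoopTask(job_id="j-0", df_idx=0))
    assert pool.has_unprocessed_tasks()
    assert pool.get_unprocessed_count() == 1

    # timeout=None blocks until the submitted task completes
    results, remaining = pool.process_futures(timeout=None)

    assert [r.job_id for r in results] == ["j-0"]
    assert remaining == 0
    assert not pool.has_unprocessed_tasks()
    pool.shutdown()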