Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/executorlib/standalone/interactive/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,21 @@ def find_future_in_list(lst):

find_future_in_list(lst=args)
find_future_in_list(lst=kwargs.values())
boolean_flag = len([future for future in future_lst if future.done()]) == len(
future_lst
)
return future_lst, boolean_flag

return future_lst


def check_list_of_futures_is_done(future_lst: list[Future]) -> bool:
"""
Check if all future objects in the list of future objects are done

Args:
future_lst (list): list of future objects

Returns:
bool: True if all future objects in the list of future objects are done, False otherwise
"""
return len([future for future in future_lst if future.done()]) == len(future_lst)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def get_exception_lst(future_lst: list[Future]) -> list:
Expand Down
46 changes: 31 additions & 15 deletions src/executorlib/task_scheduler/interactive/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from executorlib.standalone.batched import batched_futures
from executorlib.standalone.interactive.arguments import (
check_exception_was_raised,
check_list_of_futures_is_done,
get_exception_lst,
get_future_objects_from_input,
update_futures_in_input,
Expand Down Expand Up @@ -185,6 +186,7 @@ def batched(
"args": (),
"kwargs": {"lst": iterable, "n": n, "skip_lst": skip_lst},
"future": f,
"future_lst": iterable,
"future_skip": f_skip,
"resource_dict": {},
}
Expand Down Expand Up @@ -249,7 +251,7 @@ def _execute_tasks_with_dependencies(
executor (TaskSchedulerBase): Executor to execute the tasks with after the dependencies are resolved.
refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.
"""
wait_lst: list = []
future_dependency_lst: list = []
while True:
try:
task_dict = future_queue.get_nowait()
Expand All @@ -258,10 +260,10 @@ def _execute_tasks_with_dependencies(
if ( # shutdown the executor
task_dict is not None and "shutdown" in task_dict and task_dict["shutdown"]
):
while len(wait_lst) > 0:
while len(future_dependency_lst) > 0:
# Check functions in the wait list and execute them if all future objects are now ready
wait_lst = _update_waiting_task(
wait_lst=wait_lst,
future_dependency_lst = _handle_future_dependencies(
future_dependency_lst=future_dependency_lst,
executor_queue=executor_queue,
refresh_rate=refresh_rate,
)
Expand All @@ -283,12 +285,24 @@ def _execute_tasks_with_dependencies(
task_dict["future"].set_result(False)
else:
task_dict["future"].set_result(True)
elif ( # handle batched function submitted to the executor
task_dict is not None
and "fn" in task_dict
and task_dict["fn"] == "batched"
and "future" in task_dict
):
future_dependency_lst.append(task_dict)
future_queue.task_done()
elif ( # handle function submitted to the executor
task_dict is not None and "fn" in task_dict and "future" in task_dict
task_dict is not None
and "fn" in task_dict
and task_dict["fn"] != "batched"
and "future" in task_dict
):
future_lst, ready_flag = get_future_objects_from_input(
future_lst = get_future_objects_from_input(
args=task_dict["args"], kwargs=task_dict["kwargs"]
)
ready_flag = check_list_of_futures_is_done(future_lst=future_lst)
exception_lst = get_exception_lst(future_lst=future_lst)
if not check_exception_was_raised(future_obj=task_dict["future"]):
if len(exception_lst) > 0:
Expand All @@ -301,12 +315,12 @@ def _execute_tasks_with_dependencies(
executor_queue.put(task_dict)
else: # Otherwise add the function to the wait list
task_dict["future_lst"] = future_lst
wait_lst.append(task_dict)
future_dependency_lst.append(task_dict)
future_queue.task_done()
elif len(wait_lst) > 0:
elif len(future_dependency_lst) > 0:
# Check functions in the wait list and execute them if all future objects are now ready
wait_lst = _update_waiting_task(
wait_lst=wait_lst,
future_dependency_lst = _handle_future_dependencies(
future_dependency_lst=future_dependency_lst,
executor_queue=executor_queue,
refresh_rate=refresh_rate,
)
Expand All @@ -315,22 +329,24 @@ def _execute_tasks_with_dependencies(
sleep(refresh_rate)


def _update_waiting_task(
wait_lst: list[dict], executor_queue: queue.Queue, refresh_rate: float = 0.01
def _handle_future_dependencies(
future_dependency_lst: list[dict],
executor_queue: queue.Queue,
refresh_rate: float = 0.01,
) -> list:
"""
Submit the waiting tasks, which future inputs have been completed, to the executor

Args:
wait_lst (list): List of waiting tasks
future_dependency_lst (list): List of waiting tasks
executor_queue (Queue): Queue of the internal executor
refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.

Returns:
list: list tasks which future inputs have not been completed
"""
wait_tmp_lst = []
for task_wait_dict in wait_lst:
for task_wait_dict in future_dependency_lst:
exception_lst = get_exception_lst(future_lst=task_wait_dict["future_lst"])
if len(exception_lst) > 0 and task_wait_dict["fn"] != "batched":
task_wait_dict["future"].set_exception(exception_lst[0])
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Expand Down Expand Up @@ -360,6 +376,6 @@ def _update_waiting_task(
task_wait_dict["future_skip"].set_result([id(f) for f in done_lst])
else:
wait_tmp_lst.append(task_wait_dict)
if len(wait_lst) == len(wait_tmp_lst):
if len(future_dependency_lst) == len(wait_tmp_lst):
sleep(refresh_rate)
return wait_tmp_lst
7 changes: 5 additions & 2 deletions tests/unit/standalone/interactive/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from executorlib.standalone.interactive.arguments import (
check_exception_was_raised,
check_list_of_futures_is_done,
get_exception_lst,
get_future_objects_from_input,
update_futures_in_input,
Expand All @@ -13,14 +14,16 @@ class TestSerial(unittest.TestCase):
def test_get_future_objects_from_input_with_future(self):
input_args = (1, 2, Future(), [Future()], {3: Future()})
input_kwargs = {"a": 1, "b": [Future()], "c": {"d": Future()}, "e": Future()}
future_lst, boolean_flag = get_future_objects_from_input(args=input_args, kwargs=input_kwargs)
future_lst = get_future_objects_from_input(args=input_args, kwargs=input_kwargs)
boolean_flag = check_list_of_futures_is_done(future_lst=future_lst)
self.assertEqual(len(future_lst), 6)
self.assertFalse(boolean_flag)

def test_get_future_objects_from_input_without_future(self):
input_args = (1, 2)
input_kwargs = {"a": 1}
future_lst, boolean_flag = get_future_objects_from_input(args=input_args, kwargs=input_kwargs)
future_lst = get_future_objects_from_input(args=input_args, kwargs=input_kwargs)
boolean_flag = check_list_of_futures_is_done(future_lst=future_lst)
self.assertEqual(len(future_lst), 0)
self.assertTrue(boolean_flag)

Expand Down
Loading