-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathdecorator.py
More file actions
159 lines (128 loc) · 7.12 KB
/
decorator.py
File metadata and controls
159 lines (128 loc) · 7.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from __future__ import annotations
import functools
import logging
import signal
import sys
import typing as t
from typing import List
import aiida
from aiida.engine.processes.functions import FunctionType
from aiida.manage import get_manager
from aiida.orm import ProcessNode
from node_graph.socket_spec import SocketSpec
from packaging.version import parse as parse_version
from aiida_pythonjob.calculations.pyfunction import PyFunction
from aiida_pythonjob.launch import create_inputs, prepare_pyfunction_inputs
# Module-level logger, named after this module per the standard logging convention.
LOGGER = logging.getLogger(__name__)
# Parse the installed aiida-core version once at import time for feature detection.
_AIIDA_VERSION = parse_version(aiida.__version__)
# aiida-core < 2.8.0rc0 requires a manual recursion-limit bump when running deeply
# nested process functions — presumably handled internally in newer versions (see
# the guarded workaround in `pyfunction` below).
_NEEDS_RECURSION_LIMIT_WORKAROUND = _AIIDA_VERSION < parse_version("2.8.0rc0")
if _NEEDS_RECURSION_LIMIT_WORKAROUND:
# `get_stack_size` is imported lazily because it is only needed (and only
# guaranteed to exist) on the aiida-core versions that need the workaround.
    from aiida.engine.processes.functions import get_stack_size
# The following code is modified from the aiida-core.engine.processes.functions module
def pyfunction(
    inputs: t.Optional[SocketSpec | List[str]] = None,
    outputs: t.Optional[t.List[SocketSpec | List[str]]] = None,
) -> t.Callable[[FunctionType], FunctionType]:
    """The base function decorator to create a FunctionProcess out of a normal python function.

    :param inputs: specification of the function inputs; if not provided, the inputs are
        derived from the function signature by ``prepare_pyfunction_inputs``.
    :param outputs: the outputs of the function, if not provided, we assume a single output named 'result'.
    :return: a decorator that wraps a plain Python function into a ``PyFunction`` process launcher.
    """

    def decorator(function: FunctionType) -> FunctionType:
        """Turn the decorated function into a FunctionProcess.

        :param callable function: the actual decorated function that the FunctionProcess represents
        :return callable: The decorated function.
        """

        def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, "ProcessNode"]:
            """Run the FunctionProcess with the supplied inputs in a local runner.

            :param args: input arguments to construct the FunctionProcess
            :param kwargs: input keyword arguments to construct the FunctionProcess
            :return: tuple of the outputs of the process and the process node
            """
            if _NEEDS_RECURSION_LIMIT_WORKAROUND:
                # Older aiida-core versions can exhaust the interpreter stack when process
                # functions are nested deeply; pre-emptively raise the recursion limit.
                frame_delta = 1000
                frame_count = get_stack_size()
                stack_limit = sys.getrecursionlimit()
                LOGGER.info(
                    "Executing process function, current stack status: %d frames of %d", frame_count, stack_limit
                )
                if frame_count > min(0.8 * stack_limit, stack_limit - 200):
                    LOGGER.warning(
                        "Current stack contains %d frames which is close to the limit of %d. Increasing the limit by %d",
                        frame_count,
                        stack_limit,
                        frame_delta,
                    )
                    sys.setrecursionlimit(stack_limit + frame_delta)
            manager = get_manager()
            runner = manager.get_runner()
            # Remove all the known inputs from the kwargs; explicit per-call specs take
            # precedence over the ones captured at decoration time.
            outputs_spec = kwargs.pop("outputs_spec", None) or outputs
            inputs_spec = kwargs.pop("inputs_spec", None) or inputs
            metadata = kwargs.pop("metadata", None)
            function_data = kwargs.pop("function_data", None)
            deserializers = kwargs.pop("deserializers", None)
            serializers = kwargs.pop("serializers", None)
            process_label = kwargs.pop("process_label", None)
            register_pickle_by_value = kwargs.pop("register_pickle_by_value", False)
            function_inputs = create_inputs(function, *args, **kwargs)
            process_inputs = prepare_pyfunction_inputs(
                function=function,
                function_inputs=function_inputs,
                inputs_spec=inputs_spec,
                outputs_spec=outputs_spec,
                metadata=metadata,
                process_label=process_label,
                function_data=function_data,
                deserializers=deserializers,
                serializers=serializers,
                register_pickle_by_value=register_pickle_by_value,
            )
            process = PyFunction(inputs=process_inputs, runner=runner)
            # Only add handlers for interrupt signal to kill the process if we are in a local and not a daemon runner.
            # Without this check, running process functions in a daemon worker would be killed if the daemon is shutdown
            original_handler = None
            handler_replaced = False
            kill_signal = signal.SIGINT
            if not runner.is_daemon_runner:

                def kill_process(_num, _frame):
                    """Send the kill signal to the process in the current scope."""
                    LOGGER.critical("runner received interrupt, killing process %s", process.pid)
                    result = process.kill(msg="Process was killed because the runner received an interrupt")
                    return result

                # Store the current handler on the signal such that it can be restored after the process has
                # terminated. NOTE: we track replacement with an explicit flag because a truthiness test on the
                # handler itself would fail to restore `signal.SIG_DFL`, which is the enum value 0 and thus falsy.
                original_handler = signal.getsignal(kill_signal)
                signal.signal(kill_signal, kill_process)
                handler_replaced = True
            try:
                result = process.execute()
            finally:
                # Restore the previous handler if we replaced it. `getsignal` returns `None` when the previous
                # handler was installed from C and cannot be restored from Python, so skip that case.
                if handler_replaced and original_handler is not None:
                    signal.signal(kill_signal, original_handler)
            store_provenance = process_inputs.get("metadata", {}).get("store_provenance", True)
            if not store_provenance:
                # Mark the node as unstorable so later attempts to store it fail loudly.
                process.node._storable = False
                process.node._unstorable_message = "cannot store node because it was run with `store_provenance=False`"
            return result, process.node

        def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]:
            """Recreate the `run_get_pk` utility launcher.

            :param args: input arguments to construct the FunctionProcess
            :param kwargs: input keyword arguments to construct the FunctionProcess
            :return: tuple of the outputs of the process and the process node pk
            """
            result, node = run_get_node(*args, **kwargs)
            assert node.pk is not None
            return result, node.pk

        @functools.wraps(function)
        def decorated_function(*args, **kwargs):
            """This wrapper function is the actual function that is called."""
            result, _ = run_get_node(*args, **kwargs)
            return result

        # Attach the launcher variants and process-function metadata expected by aiida tooling.
        decorated_function.func = function  # type: ignore[attr-defined]
        decorated_function.run = decorated_function  # type: ignore[attr-defined]
        decorated_function.run_get_pk = run_get_pk  # type: ignore[attr-defined]
        decorated_function.run_get_node = run_get_node  # type: ignore[attr-defined]
        decorated_function.is_process_function = True  # type: ignore[attr-defined]
        decorated_function.process_class = PyFunction  # type: ignore[attr-defined]
        decorated_function.recreate_from = PyFunction.recreate_from  # type: ignore[attr-defined]
        decorated_function.spec = PyFunction.spec  # type: ignore[attr-defined]
        return decorated_function  # type: ignore[return-value]

    return decorator