Skip to content

Commit c0eaf07

Browse files
committed
Implement module upload plugin (#8698)
1 parent 36020d6 commit c0eaf07

1 file changed

Lines changed: 170 additions & 0 deletions

File tree

distributed/diagnostics/plugin.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,21 @@
55
import functools
66
import logging
77
import os
8+
import shutil
89
import socket
910
import subprocess
1011
import sys
1112
import tempfile
1213
import uuid
1314
import zipfile
1415
from collections.abc import Awaitable
16+
from contextlib import contextmanager
17+
from importlib.util import find_spec
18+
from io import BytesIO
1519
from typing import TYPE_CHECKING, Any, Callable, ClassVar
20+
from types import ModuleType
21+
from typing import Any, Tuple
22+
from pathlib import Path
1623

1724
from dask.typing import Key
1825
from dask.utils import funcname, tmpfile
@@ -29,6 +36,7 @@
2936
from distributed.scheduler import TaskStateState as SchedulerTaskStateState
3037
from distributed.worker import Worker
3138
from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState
39+
from distributed.node import ServerNode
3240

3341
logger = logging.getLogger(__name__)
3442

@@ -1051,3 +1059,165 @@ def setup(self, worker):
10511059

10521060
def teardown(self, worker):
10531061
self._exit_stack.close()
1062+
1063+
1064+
@contextmanager
1065+
def serialize_module(
1066+
module: ModuleType, exclude: Tuple[str] = ("__pycache__", ".DS_Store")
1067+
) -> Path:
1068+
module_path = Path(module.__file__)
1069+
1070+
if module_path.stem == "__init__":
1071+
# In case of package we serialize the whole package
1072+
module_path = module_path.parent
1073+
if "." in module.__name__:
1074+
# TODO: the problem is that we serialize the `package.module`, as module.egg that contains module.py,
1075+
# but it should contain the whole structure of the package (package/module.py)
1076+
raise Exception(
1077+
f"Plugin supports only top-level packages or single-file modules. You provided `{module.__name__}`, try `{module.__name__.split('.')[0]}`."
1078+
)
1079+
1080+
# In case of single file we don't need to serialize anything
1081+
1082+
with tempfile.TemporaryDirectory() as tmp:
1083+
package_name = module_path.name
1084+
1085+
package_copy_path = Path(tmp).joinpath(package_name)
1086+
if module_path.is_dir():
1087+
copied_package = Path(
1088+
shutil.copytree(
1089+
module_path,
1090+
package_copy_path,
1091+
ignore=shutil.ignore_patterns(f"{package_name}.zip", *exclude),
1092+
)
1093+
)
1094+
else:
1095+
copied_package = Path(shutil.copy2(module_path, package_copy_path))
1096+
1097+
archive_path = shutil.make_archive(
1098+
# output path including a name w/o extension
1099+
base_name=str(copied_package),
1100+
format="zip",
1101+
# chroot
1102+
root_dir=copied_package.parent,
1103+
# Name of the directory to archive and a common prefix of all files and directories in the archive
1104+
base_dir=package_name,
1105+
)
1106+
1107+
egg_file = shutil.move(archive_path, package_copy_path.with_suffix(".egg"))
1108+
1109+
# zip file handler
1110+
zip = zipfile.ZipFile(egg_file)
1111+
# list available files in the container
1112+
logger.debug(
1113+
"The egg file %s contains the following files %s",
1114+
str(egg_file),
1115+
str(zip.namelist()),
1116+
)
1117+
1118+
logger.info("Created an egg file %s from %s", str(egg_file), str(module_path))
1119+
1120+
yield Path(egg_file)
1121+
1122+
1123+
class AbstractUploadModulePlugin:
1124+
def __init__(self, module: ModuleType):
1125+
self._module_name = module.__name__
1126+
self._data: bytes
1127+
self._filepath: Path
1128+
self._filename: str
1129+
with serialize_module(module) as filepath:
1130+
self._filename = filepath.name
1131+
with open(filepath, "rb") as f:
1132+
self._data = f.read()
1133+
1134+
async def _upload_file(self, node: ServerNode):
1135+
response = await node.upload_file(self._filename, self._data, load=True)
1136+
assert len(self._data) == response["nbytes"]
1137+
1138+
async def _upload(self, node: ServerNode):
1139+
import zipfile
1140+
import sys
1141+
try:
1142+
from IPython.extensions.autoreload import superreload
1143+
except ImportError:
1144+
superreload = lambda x: x
1145+
1146+
# Try to find already loaded module
1147+
module = (
1148+
sys.modules[self._module_name] if self._module_name in sys.modules else None
1149+
)
1150+
# Try to find module on disk
1151+
module_spec = find_spec(self._module_name)
1152+
1153+
if not module_spec and not module:
1154+
# If module does not exist we keep it as egg file and load it.
1155+
logger.info(
1156+
'Uploading a new module "%s" to "%s" on %s "%s"',
1157+
self._module_name,
1158+
str(self._filename),
1159+
"worker" if isinstance(node, Worker) else "scheduler",
1160+
node.id,
1161+
)
1162+
await self._upload_file(node)
1163+
return
1164+
1165+
if module:
1166+
module_path = self._get_module_dir(module)
1167+
else:
1168+
module_path = Path(module_spec.origin)
1169+
1170+
if ".egg" in str(module_path):
1171+
# Update the previously uploaded egg module and reload it.
1172+
logger.info(
1173+
'Uploading an update for a previously uploaded a new module "%s" to "%s" on %s "%s"',
1174+
self._module_name,
1175+
str(self._filename),
1176+
"worker" if isinstance(node, Worker) else "scheduler",
1177+
node.id,
1178+
)
1179+
await self._upload_file(node)
1180+
return
1181+
1182+
with zipfile.ZipFile(BytesIO(self._data), "r") as zip_ref:
1183+
# In case, we received egg file for module that exists on node in source code,
1184+
# we overwrite each file separately by extracting it from the egg.
1185+
logger.info(
1186+
'Uploading an update for an existing module "%s" in "%s" on %s "%s"',
1187+
self._module_name,
1188+
str(module_path.parent),
1189+
"worker" if isinstance(node, Worker) else "scheduler",
1190+
node.id,
1191+
)
1192+
zip_ref.extractall(module_path.parent)
1193+
1194+
# TODO: Do we really need Jupyter's `superreload` here instead of built-in Python's function?
1195+
if self._module_name in sys.modules:
1196+
# Reload module if it is already loaded
1197+
superreload(sys.modules[self._module_name])
1198+
1199+
@classmethod
1200+
def _get_module_dir(cls, module: ModuleType) -> Path:
1201+
"""Get the directory of the module."""
1202+
module_path = Path(sys.modules[module.__name__].__file__)
1203+
1204+
if module_path.stem == "__init__":
1205+
# In case of package we serialize the whole package
1206+
return module_path.parent
1207+
1208+
# In case of single file we don't need to serialize anything
1209+
return module_path
1210+
1211+
1212+
class UploadModule(WorkerPlugin, AbstractUploadModulePlugin):
1213+
name = "upload_module"
1214+
1215+
async def setup(self, worker: Worker):
1216+
await self._upload(worker)
1217+
1218+
1219+
class SchedulerUploadModule(SchedulerPlugin, AbstractUploadModulePlugin):
1220+
name = "upload_module"
1221+
1222+
async def start(self, scheduler: Scheduler) -> None:
1223+
await self._upload(scheduler)

0 commit comments

Comments
 (0)