55import functools
66import logging
77import os
8+ import shutil
89import socket
910import subprocess
1011import sys
1112import tempfile
1213import uuid
1314import zipfile
1415from collections .abc import Awaitable
16+ from contextlib import contextmanager
17+ from importlib .util import find_spec
18+ from io import BytesIO
1519from typing import TYPE_CHECKING , Any , Callable , ClassVar
20+ from types import ModuleType
21+ from typing import Any , Tuple
22+ from pathlib import Path
1623
1724from dask .typing import Key
1825from dask .utils import funcname , tmpfile
2936 from distributed .scheduler import TaskStateState as SchedulerTaskStateState
3037 from distributed .worker import Worker
3138 from distributed .worker_state_machine import TaskStateState as WorkerTaskStateState
39+ from distributed .node import ServerNode
3240
3341logger = logging .getLogger (__name__ )
3442
@@ -1051,3 +1059,165 @@ def setup(self, worker):
10511059
10521060 def teardown (self , worker ):
10531061 self ._exit_stack .close ()
1062+
1063+
1064+ @contextmanager
1065+ def serialize_module (
1066+ module : ModuleType , exclude : Tuple [str ] = ("__pycache__" , ".DS_Store" )
1067+ ) -> Path :
1068+ module_path = Path (module .__file__ )
1069+
1070+ if module_path .stem == "__init__" :
1071+ # In case of package we serialize the whole package
1072+ module_path = module_path .parent
1073+ if "." in module .__name__ :
1074+ # TODO: the problem is that we serialize the `package.module`, as module.egg that contains module.py,
1075+ # but it should contain the whole structure of the package (package/module.py)
1076+ raise Exception (
1077+ f"Plugin supports only top-level packages or single-file modules. You provided `{ module .__name__ } `, try `{ module .__name__ .split ('.' )[0 ]} `."
1078+ )
1079+
1080+ # In case of single file we don't need to serialize anything
1081+
1082+ with tempfile .TemporaryDirectory () as tmp :
1083+ package_name = module_path .name
1084+
1085+ package_copy_path = Path (tmp ).joinpath (package_name )
1086+ if module_path .is_dir ():
1087+ copied_package = Path (
1088+ shutil .copytree (
1089+ module_path ,
1090+ package_copy_path ,
1091+ ignore = shutil .ignore_patterns (f"{ package_name } .zip" , * exclude ),
1092+ )
1093+ )
1094+ else :
1095+ copied_package = Path (shutil .copy2 (module_path , package_copy_path ))
1096+
1097+ archive_path = shutil .make_archive (
1098+ # output path including a name w/o extension
1099+ base_name = str (copied_package ),
1100+ format = "zip" ,
1101+ # chroot
1102+ root_dir = copied_package .parent ,
1103+ # Name of the directory to archive and a common prefix of all files and directories in the archive
1104+ base_dir = package_name ,
1105+ )
1106+
1107+ egg_file = shutil .move (archive_path , package_copy_path .with_suffix (".egg" ))
1108+
1109+ # zip file handler
1110+ zip = zipfile .ZipFile (egg_file )
1111+ # list available files in the container
1112+ logger .debug (
1113+ "The egg file %s contains the following files %s" ,
1114+ str (egg_file ),
1115+ str (zip .namelist ()),
1116+ )
1117+
1118+ logger .info ("Created an egg file %s from %s" , str (egg_file ), str (module_path ))
1119+
1120+ yield Path (egg_file )
1121+
1122+
1123+ class AbstractUploadModulePlugin :
1124+ def __init__ (self , module : ModuleType ):
1125+ self ._module_name = module .__name__
1126+ self ._data : bytes
1127+ self ._filepath : Path
1128+ self ._filename : str
1129+ with serialize_module (module ) as filepath :
1130+ self ._filename = filepath .name
1131+ with open (filepath , "rb" ) as f :
1132+ self ._data = f .read ()
1133+
1134+ async def _upload_file (self , node : ServerNode ):
1135+ response = await node .upload_file (self ._filename , self ._data , load = True )
1136+ assert len (self ._data ) == response ["nbytes" ]
1137+
1138+ async def _upload (self , node : ServerNode ):
1139+ import zipfile
1140+ import sys
1141+ try :
1142+ from IPython .extensions .autoreload import superreload
1143+ except ImportError :
1144+ superreload = lambda x : x
1145+
1146+ # Try to find already loaded module
1147+ module = (
1148+ sys .modules [self ._module_name ] if self ._module_name in sys .modules else None
1149+ )
1150+ # Try to find module on disk
1151+ module_spec = find_spec (self ._module_name )
1152+
1153+ if not module_spec and not module :
1154+ # If module does not exist we keep it as egg file and load it.
1155+ logger .info (
1156+ 'Uploading a new module "%s" to "%s" on %s "%s"' ,
1157+ self ._module_name ,
1158+ str (self ._filename ),
1159+ "worker" if isinstance (node , Worker ) else "scheduler" ,
1160+ node .id ,
1161+ )
1162+ await self ._upload_file (node )
1163+ return
1164+
1165+ if module :
1166+ module_path = self ._get_module_dir (module )
1167+ else :
1168+ module_path = Path (module_spec .origin )
1169+
1170+ if ".egg" in str (module_path ):
1171+ # Update the previously uploaded egg module and reload it.
1172+ logger .info (
1173+ 'Uploading an update for a previously uploaded a new module "%s" to "%s" on %s "%s"' ,
1174+ self ._module_name ,
1175+ str (self ._filename ),
1176+ "worker" if isinstance (node , Worker ) else "scheduler" ,
1177+ node .id ,
1178+ )
1179+ await self ._upload_file (node )
1180+ return
1181+
1182+ with zipfile .ZipFile (BytesIO (self ._data ), "r" ) as zip_ref :
1183+ # In case, we received egg file for module that exists on node in source code,
1184+ # we overwrite each file separately by extracting it from the egg.
1185+ logger .info (
1186+ 'Uploading an update for an existing module "%s" in "%s" on %s "%s"' ,
1187+ self ._module_name ,
1188+ str (module_path .parent ),
1189+ "worker" if isinstance (node , Worker ) else "scheduler" ,
1190+ node .id ,
1191+ )
1192+ zip_ref .extractall (module_path .parent )
1193+
1194+ # TODO: Do we really need Jupyter's `superreload` here instead of built-in Python's function?
1195+ if self ._module_name in sys .modules :
1196+ # Reload module if it is already loaded
1197+ superreload (sys .modules [self ._module_name ])
1198+
1199+ @classmethod
1200+ def _get_module_dir (cls , module : ModuleType ) -> Path :
1201+ """Get the directory of the module."""
1202+ module_path = Path (sys .modules [module .__name__ ].__file__ )
1203+
1204+ if module_path .stem == "__init__" :
1205+ # In case of package we serialize the whole package
1206+ return module_path .parent
1207+
1208+ # In case of single file we don't need to serialize anything
1209+ return module_path
1210+
1211+
1212+ class UploadModule (WorkerPlugin , AbstractUploadModulePlugin ):
1213+ name = "upload_module"
1214+
1215+ async def setup (self , worker : Worker ):
1216+ await self ._upload (worker )
1217+
1218+
1219+ class SchedulerUploadModule (SchedulerPlugin , AbstractUploadModulePlugin ):
1220+ name = "upload_module"
1221+
1222+ async def start (self , scheduler : Scheduler ) -> None :
1223+ await self ._upload (scheduler )
0 commit comments