diff --git a/src/madengine/core/constants.py b/src/madengine/core/constants.py index d1afa4c9..1a8b4fc3 100644 --- a/src/madengine/core/constants.py +++ b/src/madengine/core/constants.py @@ -26,17 +26,28 @@ import os import json import logging +import sys # Utility function for optional verbose logging of configuration def _log_config_info(message: str, force_print: bool = False): """Log configuration information either to logger or print if specified.""" + # Keep --version/--help output clean even if MAD_VERBOSE_CONFIG=true. + if any(arg in {"--version", "-V", "--help", "-h"} for arg in sys.argv[1:]): + logging.debug(message) + return if force_print or os.environ.get("MAD_VERBOSE_CONFIG", "").lower() == "true": print(message) else: logging.debug(message) +def _is_lightweight_cli_invocation() -> bool: + """Return True for metadata/help invocations that should avoid side effects.""" + lightweight_flags = {"--version", "-V", "--help", "-h"} + return any(arg in lightweight_flags for arg in sys.argv[1:]) + + # third-party modules from madengine.core.console import Console @@ -65,9 +76,12 @@ def _setup_model_dir(): _log_config_info(f"Model dir: {MODEL_DIR} copied to current dir: {cwd_abs}") -# Only setup model directory if explicitly requested (when not just importing for constants) +# Only setup model directory if explicitly requested and invocation is not metadata-only. 
if os.environ.get("MAD_SETUP_MODEL_DIR", "").lower() == "true": - _setup_model_dir() + if _is_lightweight_cli_invocation(): + _log_config_info("Skipping MODEL_DIR setup for lightweight CLI invocation (--version/--help).") + else: + _setup_model_dir() # madengine credentials configuration CRED_FILE = "credential.json" diff --git a/src/madengine/core/context.py b/src/madengine/core/context.py index eb129b82..a766ff2c 100644 --- a/src/madengine/core/context.py +++ b/src/madengine/core/context.py @@ -758,12 +758,6 @@ def get_gpu_renderD_nodes(self) -> typing.Optional[typing.List[int]]: print(f"Warning: Failed to parse unique_id from line '{item}': {e}") continue - if kfd_renderDs is None: - raise RuntimeError( - "KFD topology not accessible and required for ROCm < 6.4.1 GPU mapping. " - "Check permissions on /sys/devices/virtual/kfd/kfd/topology/nodes" - ) - if len(kfd_unique_ids) != len(kfd_renderDs): raise RuntimeError( f"Mismatch between unique_ids count ({len(kfd_unique_ids)}) " @@ -816,7 +810,10 @@ def get_gpu_renderD_nodes(self) -> typing.Optional[typing.List[int]]: json_output = output try: - data = json.loads(json_output) + # Use raw_decode so we tolerate any trailing non-JSON + # text (deprecation banners, double-emitted blocks under + # concurrent slurm tasks, stray newlines, etc.). + data, _end = json.JSONDecoder().raw_decode(json_output.lstrip()) except json.JSONDecodeError as e: raise ValueError(f"Failed to parse amd-smi JSON output: {e}. Output was: {output[:200]}") diff --git a/src/madengine/deployment/slurm.py b/src/madengine/deployment/slurm.py index 4efc5855..dd604888 100644 --- a/src/madengine/deployment/slurm.py +++ b/src/madengine/deployment/slurm.py @@ -1621,6 +1621,60 @@ def _build_common_info_dict( flatten_tags(result) return result + def _select_best_multiple_results_csv(self, candidates: List[Path]) -> Optional[Path]: + """Pick the CSV with the most non-empty performance entries. 
+ + In multi-node SLURM runs every node copies its local multi-results CSV + into job_dir/node_/. Only some nodes observe the final throughput + and populate the performance column; others have the file but with empty + values (e.g. master perf empty while a worker has the real numbers). + Ranking candidates by the count of non-empty performance rows lets + downstream aggregation use the richest data instead of depending on + node-0 winning the race or being non-empty. Header/row keys are stripped + so a leading space in the CSV header (some Primus configs) still matches. + Ties break on total row count; candidates[0] is the ultimate fallback. + """ + if not candidates: + return None + if len(candidates) == 1: + return candidates[0] + import csv as _csv + best_candidate: Optional[Path] = None + best_score = -1 + best_rows = -1 + for candidate in candidates: + non_empty_perf = 0 + total_rows = 0 + has_perf_column = False + try: + with open(candidate, "r", encoding="utf-8", errors="ignore") as f: + reader = _csv.DictReader(f) + fieldnames = reader.fieldnames or [] + stripped_fields = [fn.strip() for fn in fieldnames] + has_perf_column = "performance" in stripped_fields + for row in reader: + total_rows += 1 + if has_perf_column: + normalized_row = {(k.strip() if isinstance(k, str) else k): v for k, v in row.items()} + value = (normalized_row.get("performance") or "").strip() + if value: + non_empty_perf += 1 + except Exception: + continue + score = non_empty_perf if has_perf_column else 0 + if score > best_score or (score == best_score and total_rows > best_rows): + best_score = score + best_rows = total_rows + best_candidate = candidate + if best_candidate is None: + return candidates[0] + if best_score > 0: + self.console.print( + f"[dim] Selected multiple_results CSV with {best_score} non-empty performance rows: {best_candidate}[/dim]" + ) + return best_candidate + + def collect_results(self, deployment_id: str) -> Dict[str, Any]: """Collect performance results 
from SLURM output files. @@ -1764,14 +1818,19 @@ def collect_results(self, deployment_id: str) -> Dict[str, Any]: mult_res = model_info_for_entry.get("multiple_results") if mult_res: resolved_csv: Optional[Path] = None + # Multi-node: gather all node CSVs and pick the one with the most + # non-empty performance rows (master CSV may be empty while a worker + # holds the real numbers) instead of taking the first node that has + # the file. + candidates: List[Path] = [] if (job_dir / mult_res).is_file(): - resolved_csv = job_dir / mult_res - else: - for i in range(self.nodes): - candidate = job_dir / f"node_{i}" / mult_res - if candidate.is_file(): - resolved_csv = candidate - break + candidates.append(job_dir / mult_res) + for i in range(self.nodes): + per_node_candidate = job_dir / f"node_{i}" / mult_res + if per_node_candidate.is_file(): + candidates.append(per_node_candidate) + if candidates: + resolved_csv = self._select_best_multiple_results_csv(candidates) if not resolved_csv and Path(mult_res).is_file(): resolved_csv = Path(mult_res) if not resolved_csv and Path("run_directory", mult_res).is_file(): diff --git a/src/madengine/execution/container_runner.py b/src/madengine/execution/container_runner.py index 26110c41..9b6dfc0c 100644 --- a/src/madengine/execution/container_runner.py +++ b/src/madengine/execution/container_runner.py @@ -11,8 +11,10 @@ import re import shlex import subprocess +import socket import time import json +import hashlib import typing import warnings from rich.console import Console as RichConsole @@ -706,14 +708,22 @@ def get_env_arg(self, run_env: typing.Dict) -> str: return env_args - def get_mount_arg(self, mount_datapaths: typing.List) -> str: - """Get the mount arguments for docker run.""" + def get_mount_arg(self, mount_datapaths: typing.List, excluded_container_targets: typing.Optional[typing.Set[str]] = None) -> str: + """Get the mount arguments for docker run. 
+ + excluded_container_targets lists container-side paths already mounted via + additional_docker_run_options so we do not emit a duplicate -v (docker + rejects \"Duplicate mount point\"). + """ mount_args = "" + excluded_container_targets = excluded_container_targets or set() # Mount data paths if mount_datapaths: for mount_datapath in mount_datapaths: if mount_datapath: + if mount_datapath["home"] in excluded_container_targets: + continue mount_args += ( f"-v {shlex.quote(mount_datapath['path'])}:{shlex.quote(mount_datapath['home'])}" ) @@ -728,12 +738,13 @@ def get_mount_arg(self, mount_datapaths: typing.List) -> str: # Mount context paths if "docker_mounts" in self.context.ctx: for mount_arg in self.context.ctx["docker_mounts"].keys(): + if mount_arg in excluded_container_targets: + continue mount_args += ( f"-v {shlex.quote(self.context.ctx['docker_mounts'][mount_arg])}:{shlex.quote(mount_arg)} " ) return mount_args - def apply_tools( self, pre_encapsulate_post_scripts: typing.Dict, @@ -1315,9 +1326,11 @@ def run_container( self.context.ctx["docker_env_vars"]["MIOPEN_USER_DB_PATH"] = os.environ["MIOPEN_USER_DB_PATH"] print(f"â„šī¸ Added MIOPEN_USER_DB_PATH to docker_env_vars: {os.environ['MIOPEN_USER_DB_PATH']}") + additional_opts = model_info.get('additional_docker_run_options', '') + excluded_mount_targets = self._extract_additional_mount_targets(additional_opts) docker_options += self.get_env_arg(run_env) - docker_options += self.get_mount_arg(mount_datapaths) - docker_options += f" {model_info.get('additional_docker_run_options', '')}" + docker_options += self.get_mount_arg(mount_datapaths, excluded_container_targets=excluded_mount_targets) + docker_options += f" {additional_opts}" # Generate container name base_container_name = "container_" + re.sub( @@ -1551,10 +1564,63 @@ def run_container( ) # Use the container timeout (default 7200s) for script execution # to prevent indefinite hangs - model_output = model_docker.sh( - f"cd {model_dir} && 
{script_name} {model_args}", - timeout=timeout, - ) + try: + model_output = model_docker.sh( + f"cd {model_dir} && {script_name} {model_args}", + timeout=timeout, + ) + except RuntimeError as run_err: + # On script failure, collect lightweight diagnostics from the + # running container (process table, listening ports, log tails). + # These are printed via Console.sh so they land in the run log + # alongside the failure. Failures here are non-fatal. + run_err_str = str(run_err) + container_id_match = re.search( + r"docker exec\s+([a-f0-9]+)\s+bash", + run_err_str, + ) + failed_container_id = ( + container_id_match.group(1) + if container_id_match + else None + ) + if failed_container_id: + try: + self.console.sh( + f"docker exec {failed_container_id} bash -lc " + f"\"ps -eo pid,ppid,stat,etime,cmd | sed -n '1,160p'\"", + timeout=20, + ) + except Exception: + pass + try: + self.console.sh( + f"docker exec {failed_container_id} bash -lc " + f"\"(ss -lntp 2>/dev/null || netstat -lntp 2>/dev/null " + f"|| lsof -nP -iTCP -sTCP:LISTEN 2>/dev/null || true) " + f"| sed -n '1,200p'\"", + timeout=20, + ) + except Exception: + pass + try: + _md_q = _bash_quote_path(model_dir) + self.console.sh( + f"docker exec {failed_container_id} bash -lc " + f"\"for d in /run_logs /run_logs/${{SLURM_JOB_ID:-}} " + f"/myworkspace/{_md_q}; do " + f"if [ -d \\\"$d\\\" ]; then echo ===DIR:$d===; " + f"ls -lah \\\"$d\\\" | sed -n '1,80p'; fi; done; " + f"for f in /run_logs/*.log /run_logs/${{SLURM_JOB_ID:-}}/*.log " + f"/myworkspace/{_md_q}/*.log; do " + f"if [ -f \\\"$f\\\" ]; then echo ===$f===; " + f"tail -n 80 \\\"$f\\\"; fi; done\"", + timeout=30, + ) + except Exception: + pass + raise + # When live_output is True, Console.sh() already streamed the output; avoid duplicate print. 
if not self.live_output: print(model_output) @@ -1784,6 +1850,17 @@ def run_container( self.rich_console.print( f"[green]Status: SUCCESS (worker node, no errors detected)[/green]" ) + elif self.additional_context.get("skip_perf_collection", False): + # Multi-node/SLURM in-job run: the login node aggregates the richest + # per-node multiple_results CSV and writes the authoritative perf/status + # record (slurm.collect_results + _select_best_multiple_results_csv). + # Primus emits throughput only on the last global rank, which may land on a + # different node than the designated collector (MAD_COLLECT_METRICS=true), + # so an empty local perf here is not authoritative and must not fail the job. + run_results["status"] = "SUCCESS" + self.rich_console.print( + f"[green]Status: SUCCESS (perf collection deferred to login-node aggregation)[/green]" + ) else: run_results["status"] = "FAILURE" self.rich_console.print(f"[red]Status: FAILURE (no performance metrics)[/red]") @@ -1795,7 +1872,9 @@ def run_container( is_worker_node = os.environ.get("MAD_COLLECT_METRICS", "true").lower() == "false" run_results["status"] = ( "SUCCESS" - if run_results.get("performance") or is_worker_node + if run_results.get("performance") + or is_worker_node + or self.additional_context.get("skip_perf_collection", False) else "FAILURE" ) @@ -2039,6 +2118,597 @@ def set_credentials(self, credentials: typing.Dict) -> None: """ self.credentials = credentials + def _get_build_args(self) -> str: + """Build ``docker build --build-arg`` string from ``docker_build_arg`` context. + + Values are passed to ``Console.sh`` (``shell=True``); the key and the + value of each ``--build-arg`` are wrapped with :func:`shlex.quote` + individually so quotes / whitespace / shell metacharacters in either + component cannot break the build command or be injected when + ``docker_build_arg`` comes from manifests or user context. 
+ """ + docker_build_arg = self.context.ctx.get("docker_build_arg", {}) if self.context else {} + if not docker_build_arg: + return "" + build_args = "" + for key, value in docker_build_arg.items(): + build_args += ( + "--build-arg " + + shlex.quote(str(key)) + + "=" + + shlex.quote(str(value)) + + " " + ) + return build_args + + def _get_node_rank(self) -> int: + """Return the current node rank for distributed runs. + + Raises RuntimeError when NODE_RANK / RANK is set but cannot be parsed + as an integer. Treating a malformed rank as 0 would let a worker take + the primary code path (image build, tar save) and diverge. + """ + node_rank_raw = os.environ.get("NODE_RANK") or os.environ.get("RANK") or "0" + try: + return int(node_rank_raw) + except (TypeError, ValueError) as e: + raise RuntimeError(f"Invalid NODE_RANK/RANK env value {node_rank_raw!r}: {e}") + + def _local_image_exists(self, run_image: str) -> bool: + """Check whether a Docker image already exists locally.""" + try: + self.console.sh(f"docker image inspect {shlex.quote(run_image)} > /dev/null 2>&1") + return True + except (subprocess.CalledProcessError, RuntimeError): + return False + + # Label baked into locally-built images so the run phase can tell whether an + # image found under a reused tag was built from the *current* manifest inputs + # (Dockerfile + build args) rather than a stale build sharing the same tag. + BUILD_FINGERPRINT_LABEL = "mad.build_fingerprint" + + def _local_image_id(self, run_image: str) -> typing.Optional[str]: + """Return the local Docker image ID (``sha256:...``) for *run_image*, or None. + + Used to compare image identity across nodes: two images sharing a tag may + still differ in content (e.g. built from different RCCL commits), so the + tag alone is not a safe equality check. 
+ """ + try: + out = self.console.sh( + f"docker image inspect --format '{{{{.Id}}}}' {shlex.quote(run_image)} 2>/dev/null", + canFail=True, + secret=True, + ) + except (subprocess.CalledProcessError, RuntimeError): + return None + for line in reversed((out or "").splitlines()): + stripped = line.strip() + if stripped.startswith("sha256:"): + return stripped + return None + + def _image_label(self, run_image: str, label: str) -> typing.Optional[str]: + """Return the value of *label* on *run_image*, or None if absent/unset.""" + fmt = "{{ index .Config.Labels " + json.dumps(label) + " }}" + try: + out = self.console.sh( + f"docker image inspect --format {shlex.quote(fmt)} {shlex.quote(run_image)} 2>/dev/null", + canFail=True, + secret=True, + ) + except (subprocess.CalledProcessError, RuntimeError): + return None + lines = [ln.strip() for ln in (out or "").splitlines() if ln.strip()] + if not lines: + return None + value = lines[-1] + if not value or value == "": + return None + return value + + def _build_fingerprint(self, build_info: typing.Dict) -> str: + """Deterministic fingerprint of the build inputs (Dockerfile + build args). + + Returns "" when there is no buildable dockerfile (e.g. pull/registry or + local-image-mode entries) so callers skip content-staleness checks for + images madengine does not build. 
+ """ + dockerfile = (build_info or {}).get("dockerfile", "") + if not dockerfile or dockerfile == "N/A (local image mode)": + return "" + payload: typing.Dict[str, typing.Any] = { + "docker_build_arg": self.context.ctx.get("docker_build_arg", {}) if self.context else {}, + } + try: + with open(dockerfile, "rb") as handle: + payload["dockerfile_sha256"] = hashlib.sha256(handle.read()).hexdigest() + except OSError: + payload["dockerfile_sha256"] = "" + blob = json.dumps(payload, sort_keys=True, default=str) + return hashlib.sha256(blob.encode("utf-8")).hexdigest() + + def _get_local_image_tar_path(self, run_image: str) -> typing.Optional[str]: + """Resolve the shared tar path for a local image, if configured. + + When MAD_DOCKER_BUILDS points at a shared directory (e.g. a network + filesystem visible to all nodes), this path is used to stage a + ``docker save`` tar of the pre-built local image so that worker nodes + can ``docker load`` it instead of rebuilding or pulling. + """ + builds_dir = (os.environ.get("MAD_DOCKER_BUILDS") or "").strip() + if not builds_dir: + return None + safe_image_name = re.sub(r"[^A-Za-z0-9_.-]+", "_", run_image).strip("._") + if not safe_image_name: + safe_image_name = "docker_image" + return os.path.join(builds_dir, f"{safe_image_name}.tar") + + def _load_local_image_from_tar(self, run_image: str, tar_path: str) -> None: + """Load a Docker image from a previously saved tar archive.""" + if not os.path.exists(tar_path): + raise RuntimeError(f"Image tar not found for {run_image}: {tar_path}") + self.rich_console.print(f"[yellow]đŸ“Ļ Loading local image tar:[/yellow] {tar_path}") + self.console.sh(f"docker load -i {shlex.quote(tar_path)}", timeout=None) + self.console.sh(f"docker image inspect {shlex.quote(run_image)} > /dev/null 2>&1") + self.rich_console.print(f"[green]✅ Loaded local image from tar:[/green] {run_image}") + + def _save_local_image_to_tar(self, run_image: str, tar_path: str) -> None: + """Persist a local Docker image into 
the shared tar cache. + + Written atomically: ``docker save`` streams into a sibling tmp file + which is renamed into place only on success, so peers never load a + half-written tar (POSIX rename is atomic within a single filesystem). + """ + tar_dir = os.path.dirname(tar_path) + if tar_dir: + os.makedirs(tar_dir, exist_ok=True) + tmp_path = f"{tar_path}.tmp.{os.getpid()}" + self.rich_console.print(f"[yellow]💾 Saving local image tar:[/yellow] {tar_path}") + try: + self.console.sh(f"docker save -o {shlex.quote(tmp_path)} {shlex.quote(run_image)}", timeout=None) + os.replace(tmp_path, tar_path) + except Exception: + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + except Exception: + pass + raise + self.rich_console.print(f"[green]✅ Saved local image tar:[/green] {tar_path}") + + def _build_or_pull_local_image(self, run_image: str, build_info: typing.Dict, model_info: typing.Dict) -> None: + """Ensure the local image exists by building it first and pulling as fallback.""" + self.rich_console.print(f"[yellow]âš ī¸ Image {run_image} not found on this node.[/yellow]") + try: + self._build_local_image_from_manifest(run_image=run_image, build_info=build_info, model_info=model_info) + except Exception as build_error: + self.rich_console.print("[yellow]âš ī¸ Local build failed, attempting pull as fallback...[/yellow]") + try: + self.pull_image(run_image) + except Exception as pull_error: + raise RuntimeError( + f"Failed to build or pull local image {run_image}: " + f"build_error={build_error}; pull_error={pull_error}" + ) + + def _build_local_image_from_manifest(self, run_image: str, build_info: typing.Dict, model_info: typing.Dict) -> None: + """Build run_image on the current compute node using its manifest dockerfile. + + Used by ``run --manifest-file`` in distributed mode when the local image + is not present on a compute node and pulling is not desired or possible. 
+ """ + dockerfile = build_info.get("dockerfile", "") + if not dockerfile or dockerfile == "N/A (local image mode)": + raise RuntimeError(f"Cannot build image {run_image}: dockerfile is missing in manifest") + if not os.path.exists(dockerfile): + raise RuntimeError(f"Cannot build image {run_image}: dockerfile not found at {dockerfile!r}") + docker_context = model_info.get("dockercontext", "") or "./docker" + if not os.path.exists(docker_context): + docker_context = os.path.dirname(dockerfile) or "." + build_args = self._get_build_args() + # Bake a content fingerprint so a later run can detect a stale image that + # reuses this tag but was built from different inputs (e.g. RCCL commit). + fingerprint = self._build_fingerprint(build_info) + label_arg = "" + if fingerprint: + label_arg = f"--label {shlex.quote(self.BUILD_FINGERPRINT_LABEL + '=' + fingerprint)} " + build_command = ( + f"docker build --network=host -t {shlex.quote(run_image)} {label_arg}--pull " + f"-f {shlex.quote(dockerfile)} {build_args}{shlex.quote(docker_context)}" + ) + self.rich_console.print(f"[yellow]🔨 Building missing local image on this node:[/yellow] {run_image}") + self.rich_console.print(f"[dim] Dockerfile: {dockerfile}[/dim]") + self.rich_console.print(f"[dim] Context: {docker_context}[/dim]") + self.console.sh(build_command, timeout=None) + self.console.sh(f"docker image inspect {shlex.quote(run_image)} > /dev/null 2>&1") + self.rich_console.print(f"[green]✅ Built local image on this node:[/green] {run_image}") + + def _sync_after_local_image_ready( + self, + run_image: str, + master_image_id: typing.Optional[str] = None, + timeout_s: int = 1800, + ) -> typing.Optional[str]: + """Barrier for multi-node local-image runs so all nodes continue together. + + Uses a TCP rendezvous between NODE_RANK=0 and worker nodes so no shared + filesystem visibility is required (more robust than the Stage B shared-FS + poll under NFS/Weka metadata lag). No-op for single-node runs (NNODES<=1). 
+ + Rank 0 broadcasts *master_image_id* (the ID of the image it ensured) over + the same rendezvous; the return value is that ID as seen by every node, so + workers can verify their local image matches rank 0 before the run starts. + """ + nnodes_raw = os.environ.get("NNODES") or os.environ.get("WORLD_SIZE") or "1" + node_rank = os.environ.get("NODE_RANK") or os.environ.get("RANK") or "0" + try: + nnodes = int(nnodes_raw) + except (TypeError, ValueError) as e: + raise RuntimeError(f"Invalid NNODES/WORLD_SIZE env value {nnodes_raw!r}: {e}") + if nnodes <= 1: + return master_image_id + return self._tcp_image_ready_barrier( + nnodes=nnodes, + node_rank=node_rank, + timeout_s=timeout_s, + master_image_id=master_image_id, + ) + def _image_is_stale(self, run_image: str, want_fingerprint: str) -> bool: + """True when *run_image* exists but was built from different inputs. + + Only reports stale when a fingerprint label is present AND differs from + *want_fingerprint*. Images without the label (pulled, or built before this + feature) are treated as not-stale so we never force spurious rebuilds. + """ + if not want_fingerprint: + return False + have = self._image_label(run_image, self.BUILD_FINGERPRINT_LABEL) + return bool(have) and have != want_fingerprint + + def _ensure_local_image_available(self, run_image: str, build_info: typing.Dict, model_info: typing.Dict) -> None: + """Prepare a correct local image on this node, optionally via a shared tar. + + Two correctness guarantees: + + * Content freshness (single- and multi-node): a tag reused for a different + build (e.g. a new RCCL commit) is detected via the + ``mad.build_fingerprint`` label and rebuilt instead of silently run. + * Cross-node identity (multi-node): rank 0 broadcasts the ID of the image + it ensured over the rendezvous barrier. When a shared tar cache + (MAD_DOCKER_BUILDS) is configured, workers whose local image ID differs + from rank 0 (e.g. 
a stale image already present under the same tag) + force-reload rank 0's tar so every node runs the identical image. + + Multi-node invariant: every rank reaches ``_sync_after_local_image_ready`` + exactly once. + """ + tar_path = self._get_local_image_tar_path(run_image) + node_rank = self._get_node_rank() + is_primary_node = node_rank == 0 + want_fingerprint = self._build_fingerprint(build_info) + + master_image_id: typing.Optional[str] = None + if is_primary_node: + image_exists = self._local_image_exists(run_image) + tar_exists = bool(tar_path) and os.path.exists(tar_path) + tar_dirty = False + if image_exists and self._image_is_stale(run_image, want_fingerprint): + self.rich_console.print( + f"[yellow]â™ģī¸ Local image {run_image} is stale " + f"(build inputs changed under the same tag); rebuilding.[/yellow]" + ) + self._build_or_pull_local_image(run_image=run_image, build_info=build_info, model_info=model_info) + tar_dirty = True # any existing tar is now stale + elif not image_exists: + if tar_path and tar_exists: + self._load_local_image_from_tar(run_image, tar_path) + if self._image_is_stale(run_image, want_fingerprint): + self.rich_console.print( + f"[yellow]â™ģī¸ Tar image for {run_image} is stale; rebuilding.[/yellow]" + ) + self._build_or_pull_local_image(run_image=run_image, build_info=build_info, model_info=model_info) + tar_dirty = True + else: + self._build_or_pull_local_image(run_image=run_image, build_info=build_info, model_info=model_info) + tar_dirty = True + if tar_path and (not tar_exists or tar_dirty): + self._save_local_image_to_tar(run_image, tar_path) + master_image_id = self._local_image_id(run_image) + + master_image_id = self._sync_after_local_image_ready( + run_image=run_image, master_image_id=master_image_id + ) + + if is_primary_node: + return + + local_id = self._local_image_id(run_image) + if tar_path: + need_reload = (master_image_id and local_id != master_image_id) or ( + not master_image_id and local_id is None + ) + if 
need_reload: + if not os.path.exists(tar_path): + raise RuntimeError(f"Node 0 did not produce image tar for {run_image}: {tar_path}") + self.rich_console.print( + f"[yellow]🔁 Worker image for {run_image} differs from rank 0 " + f"(local={local_id or 'none'}, master={master_image_id or 'unknown'}); " + f"reloading from master tar.[/yellow]" + ) + self._load_local_image_from_tar(run_image, tar_path) + local_id = self._local_image_id(run_image) + if master_image_id and local_id != master_image_id: + raise RuntimeError( + f"Image mismatch persists after tar reload for {run_image}: " + f"local={local_id} master={master_image_id}" + ) + else: + # No shared tar to reconcile by ID: keep each worker content-correct + # via the build fingerprint instead. + stale = (local_id is not None) and self._image_is_stale(run_image, want_fingerprint) + if local_id is None or stale: + self._build_or_pull_local_image(run_image=run_image, build_info=build_info, model_info=model_info) + elif master_image_id and local_id != master_image_id: + self.rich_console.print( + f"[yellow]âš ī¸ Worker image for {run_image} differs from rank 0 and " + f"MAD_DOCKER_BUILDS is unset, so it cannot be reconciled by tar. " + f"Proceeding with the content-verified local image.[/yellow]" + ) + + + @staticmethod + def _recv_line(sock: "socket.socket", max_len: int = 128) -> str: + """Read one newline-terminated line from a socket. + + socket.recv may return a partial read (TCP is a byte stream), so using + it directly to parse a protocol line can reject valid peers when the + line is split across reads, manifesting as flaky barrier timeouts. This + loops on recv until a newline or max_len bytes, honoring the socket + timeout. Trailing newline and surrounding whitespace are stripped; an + empty string is returned on EOF before any data. 
+ """ + buf = bytearray() + while len(buf) < max_len: + chunk = sock.recv(max_len - len(buf)) + if not chunk: + break + buf.extend(chunk) + nl = buf.find(b"\n") + if nl != -1: + buf = buf[:nl] + break + return buf.decode("utf-8", errors="ignore").strip() + + @staticmethod + def _parse_go_image_id( + ack: str, expected_prefix: str + ) -> typing.Tuple[bool, typing.Optional[str]]: + """Parse a barrier ``GO`` line, returning (matched, master_image_id). + + Accepts both the bare ``GO `` form and the extended + ``GO `` form; ``-`` (or an empty trailer) maps + to None so callers treat the master image ID as unknown. + """ + if ack == expected_prefix or ack.startswith(expected_prefix + " "): + image_id = ack[len(expected_prefix):].strip() + return True, (image_id if (image_id and image_id != "-") else None) + return False, None + + def _tcp_image_ready_barrier( + self, + nnodes: int, + node_rank: str, + timeout_s: int, + master_image_id: typing.Optional[str] = None, + ) -> typing.Optional[str]: + """TCP rendezvous barrier that does not require shared filesystem visibility. + + Node 0 listens on one of candidate_ports derived from MASTER_PORT and + SLURM_JOB_ID; workers send "READY " and wait for + "GO ". The port range and token defend against + multiple concurrent jobs reusing the same master host. MAD_BARRIER_TOKEN + can set an opaque secret token; otherwise it defaults to JOB. + The listener binds to MASTER_ADDR resolved IP first (skipping loopback) + and only falls back to 0.0.0.0. + + Rank 0 appends its ensured image ID (or ``-`` when unknown) to the GO + line; the method returns that ID on every rank so callers can reconcile + worker images against rank 0. 
+ """ + master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") + job_id_raw = os.environ.get("SLURM_JOB_ID", "0") + try: + job_id = int(job_id_raw) + except Exception: + job_id = 0 + token = (os.environ.get("MAD_BARRIER_TOKEN") or "").strip() or f"JOB{job_id}" + master_port_raw = os.environ.get("MASTER_PORT", "29500") + try: + master_port = int(master_port_raw) + except Exception: + master_port = 29500 + base_port = 43000 + ((master_port + job_id) % 1000) + candidate_ports = [base_port + i for i in range(0, 16)] + deadline = time.time() + timeout_s + try: + rank_int = int(node_rank) + except (TypeError, ValueError) as e: + raise RuntimeError(f"TCP barrier: invalid NODE_RANK/RANK value {node_rank!r}: {e}") + + if rank_int == 0: + accepted = 0 + peers: typing.Dict[int, str] = {} + waiting: typing.Dict[int, "socket.socket"] = {} + server = None + port = None + try: + master_ip = socket.gethostbyname(master_addr) + except Exception: + master_ip = "" + bind_hosts: typing.List[str] = [] + + def _is_loopback(ip: str) -> bool: + if not ip: + return True + if ip in ("0.0.0.0", "::"): + return True + if ip.startswith("127.") or ip == "::1": + return True + return False + + if master_ip and not _is_loopback(master_ip): + bind_hosts.append(master_ip) + if "0.0.0.0" not in bind_hosts: + bind_hosts.append("0.0.0.0") + try: + bind_errors = [] + for bind_host in bind_hosts: + for candidate in candidate_ports: + trial = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + trial.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + trial.bind((bind_host, candidate)) + server = trial + port = candidate + break + except Exception as e: + bind_errors.append({"host": bind_host, "port": candidate, "error": str(e)}) + try: + trial.close() + except Exception: + pass + if server is not None: + break + if server is None or port is None: + raise RuntimeError(f"TCP barrier bind failed on all candidate ports: {bind_errors}") + server.listen(max(1, nnodes - 1)) + server.settimeout(2.0) 
+ while accepted < max(0, nnodes - 1) and time.time() < deadline: + try: + conn, addr = server.accept() + conn.settimeout(2.0) + payload = self._recv_line(conn) + parts = payload.split() + if len(parts) != 3 or parts[0] != "READY" or parts[1] != token: + conn.close() + continue + try: + worker_rank = int(parts[2]) + except Exception: + conn.close() + continue + if worker_rank <= 0 or worker_rank >= nnodes: + conn.close() + continue + if worker_rank in waiting: + try: + waiting[worker_rank].close() + except Exception: + pass + waiting[worker_rank] = conn + peers[worker_rank] = f"{addr[0]}:r{worker_rank}" + accepted = len(waiting) + except socket.timeout: + continue + if accepted < max(0, nnodes - 1): + for conn in waiting.values(): + try: + conn.close() + except Exception: + pass + raise RuntimeError( + f"TCP barrier timeout on master: accepted={accepted}/{max(0, nnodes - 1)} port={port}" + ) + go_payload = master_image_id if master_image_id else "-" + for worker_rank, conn in waiting.items(): + try: + conn.sendall(f"GO {token} {worker_rank} {go_payload}\n".encode("utf-8")) + finally: + try: + conn.close() + except Exception: + pass + if peers: + pretty = ", ".join(peers[r] for r in sorted(peers)) + self.rich_console.print( + f"[dim]TCP barrier master: released {accepted} peer(s): {pretty} (port={port})[/dim]" + ) + return master_image_id + finally: + try: + if server is not None: + server.close() + except Exception: + pass + + expected_prefix = f"GO {token} {rank_int}" + last_error = "" + connect_attempts = 0 + while time.time() < deadline: + for candidate in candidate_ports: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + connect_attempts += 1 + try: + sock.settimeout(1.5) + sock.connect((master_addr, candidate)) + sock.sendall(f"READY {token} {rank_int}\n".encode("utf-8")) + remaining_s = max(1.0, deadline - time.time()) + sock.settimeout(remaining_s) + ack = self._recv_line(sock, max_len=256) + matched, image_id = self._parse_go_image_id(ack, 
expected_prefix) + if matched: + return image_id + last_error = f"unexpected_ack={ack!r} port={candidate}" + except Exception as e: + last_error = f"{e} port={candidate}" + finally: + try: + sock.close() + except Exception: + pass + time.sleep(1) + + raise RuntimeError( + f"TCP barrier timeout on worker rank={rank_int} master={master_addr} " + f"ports={candidate_ports} attempts={connect_attempts} last_error={last_error}" + ) + + def _extract_additional_mount_targets(self, additional_opts: str) -> typing.Set[str]: + """Extract container-side mount targets from free-form docker run options. + + Parses -v / --volume tokens from additional_docker_run_options and + returns the container paths already being mounted, so get_mount_arg can + skip duplicates that docker would reject ("Duplicate mount point"). + """ + targets: typing.Set[str] = set() + if not additional_opts: + return targets + try: + tokens = shlex.split(additional_opts) + except Exception: + return targets + i = 0 + while i < len(tokens): + token = tokens[i] + if token in ("-v", "--volume") and i + 1 < len(tokens): + spec = tokens[i + 1] + i += 2 + elif token.startswith("-v") and len(token) > 2: + spec = token[2:] + i += 1 + elif token.startswith("--volume="): + spec = token.split("=", 1)[1] + i += 1 + else: + i += 1 + continue + parts = spec.split(":") + if len(parts) >= 2: + targets.add(parts[1]) + return targets + + def run_models_from_manifest( self, manifest_file: str, @@ -2114,15 +2784,16 @@ def run_models_from_manifest( run_image = build_info.get("docker_image") self.rich_console.print(f"[yellow]🏠 Using local image: {run_image}[/yellow]") - # Verify image exists - try: - self.console.sh(f"docker image inspect {run_image} > /dev/null 2>&1") - except (subprocess.CalledProcessError, RuntimeError) as e: - self.rich_console.print(f"[yellow]âš ī¸ Image {run_image} not found, attempting to pull...[/yellow]") - try: - self.pull_image(run_image) - except Exception as e: - raise RuntimeError(f"Failed to find 
or pull local image {run_image}: {e}") + # Ensure the local image is available on this node. In a + # multi-node SLURM run only the primary may have the + # locally-built image; the shared-tar cache + # (MAD_DOCKER_BUILDS) lets workers load it instead of + # pulling (local images are not in any registry). + self._ensure_local_image_available( + run_image=run_image, + build_info=build_info, + model_info=model_info, + ) elif build_info.get("registry_image"): # Registry image: Pull from registry diff --git a/src/madengine/orchestration/run_orchestrator.py b/src/madengine/orchestration/run_orchestrator.py index 78ef898f..2a0eb070 100644 --- a/src/madengine/orchestration/run_orchestrator.py +++ b/src/madengine/orchestration/run_orchestrator.py @@ -390,14 +390,17 @@ def _create_manifest_from_local_image( self.rich_console.print(f"[yellow]🏠 Local Image Mode: Using {image_name}[/yellow]") self.rich_console.print(f"[dim]Skipping build phase, creating synthetic manifest...[/dim]\n") - # Validate that the image exists locally or can be pulled + # Validate that the image exists locally or can be pulled. + # image_name is interpolated into shell commands run with shell=True, + # so shell-escape it to avoid command injection / breakage on special chars. 
+ quoted_image_name = shlex.quote(image_name) try: - self.console.sh(f"docker image inspect {shlex.quote(image_name)} > /dev/null 2>&1") + self.console.sh(f"docker image inspect {quoted_image_name} > /dev/null 2>&1") self.rich_console.print(f"[green]✓ Image {image_name} found locally[/green]") except (subprocess.CalledProcessError, RuntimeError) as e: self.rich_console.print(f"[yellow]âš ī¸ Image {image_name} not found locally, attempting to pull...[/yellow]") try: - self.console.sh(f"docker pull {shlex.quote(image_name)}") + self.console.sh(f"docker pull {quoted_image_name}") self.rich_console.print(f"[green]✓ Successfully pulled {image_name}[/green]") except Exception as e: raise RuntimeError( @@ -564,6 +567,27 @@ def _execute_local(self, manifest_file: str, timeout: int) -> Dict: # Restore context from manifest if present if "context" in manifest: manifest_context = manifest["context"] + # Restore host-level runtime context fields from manifest. + # Keep runtime-detected values as priority; bring missing keys from manifest + # (especially docker_mounts for host path visibility on compute nodes). 
+ if "docker_mounts" in manifest_context: + if "docker_mounts" not in self.context.ctx: + self.context.ctx["docker_mounts"] = {} + for container_path, host_path in manifest_context["docker_mounts"].items(): + if container_path not in self.context.ctx["docker_mounts"]: + self.context.ctx["docker_mounts"][container_path] = host_path + if "docker_build_arg" in manifest_context: + if "docker_build_arg" not in self.context.ctx: + self.context.ctx["docker_build_arg"] = {} + for key, value in manifest_context["docker_build_arg"].items(): + if key not in self.context.ctx["docker_build_arg"]: + self.context.ctx["docker_build_arg"][key] = value + if "docker_gpus" in manifest_context and "docker_gpus" not in self.context.ctx: + self.context.ctx["docker_gpus"] = manifest_context["docker_gpus"] + if "gpu_vendor" in manifest_context and "gpu_vendor" not in self.context.ctx: + self.context.ctx["gpu_vendor"] = manifest_context["gpu_vendor"] + if "guest_os" in manifest_context and "guest_os" not in self.context.ctx: + self.context.ctx["guest_os"] = manifest_context["guest_os"] if "tools" in manifest_context: self.context.ctx["tools"] = manifest_context["tools"] if "pre_scripts" in manifest_context: @@ -572,12 +596,16 @@ def _execute_local(self, manifest_file: str, timeout: int) -> Dict: self.context.ctx["post_scripts"] = manifest_context["post_scripts"] if "encapsulate_script" in manifest_context: self.context.ctx["encapsulate_script"] = manifest_context["encapsulate_script"] - # Restore docker_env_vars from build context (e.g. MAD_SECRET_HFTOKEN for Primus HF-backed configs) + # Restore docker_env_vars from build context (e.g. MAD_SECRETS_HFTOKEN for Primus HF-backed configs). + # Keep runtime-detected values as priority (consistent with docker_mounts / docker_build_arg): + # values already populated by Context (e.g. MAD_SECRETS_* read from os.environ) must not be + # overwritten by manifest entries that may still contain unexpanded "${VAR}" placeholders. 
if "docker_env_vars" in manifest_context and manifest_context["docker_env_vars"]: if "docker_env_vars" not in self.context.ctx: self.context.ctx["docker_env_vars"] = {} for k, v in manifest_context["docker_env_vars"].items(): - self.context.ctx["docker_env_vars"][k] = v + if k not in self.context.ctx["docker_env_vars"]: + self.context.ctx["docker_env_vars"][k] = v # Merge runtime additional_context (takes precedence over manifest) # This allows users to override tools/scripts at runtime diff --git a/tests/unit/test_local_image_reconcile.py b/tests/unit/test_local_image_reconcile.py new file mode 100644 index 00000000..994eec90 --- /dev/null +++ b/tests/unit/test_local_image_reconcile.py @@ -0,0 +1,233 @@ +"""Unit tests for cross-node local-image reconciliation and content staleness. + +Covers two correctness fixes in ContainerRunner._ensure_local_image_available: + +* Bug 1 (cross-node identity): a worker that already holds a stale image under + the same tag must reload rank 0's tar so every node runs the identical image. +* Bug 2 (content staleness): a tag reused for a different build (changed + ``mad.build_fingerprint``) must be rebuilt instead of silently reused. 
import os
from unittest.mock import MagicMock

import pytest

from madengine.execution.container_runner import ContainerRunner


def _make_runner(context=None):
    """Build a ContainerRunner for tests, using a MagicMock as its console."""
    fake_console = MagicMock()
    return ContainerRunner(context=context, console=fake_console)


def _stub_io(runner):
    """Swap the runner's docker-touching helpers for mocks and return it."""
    mocked_helpers = (
        ("_build_or_pull_local_image", "build_or_pull"),
        ("_load_local_image_from_tar", "load_tar"),
        ("_save_local_image_to_tar", "save_tar"),
    )
    for attribute, mock_name in mocked_helpers:
        setattr(runner, attribute, MagicMock(name=mock_name))
    runner.rich_console = MagicMock()
    return runner


class TestParseGoImageId:
    """Expectations for ContainerRunner._parse_go_image_id(ack, expected_prefix)."""

    def test_bare_go_line_returns_none_id(self):
        # An ack that is exactly the prefix matches and carries no image id.
        outcome = ContainerRunner._parse_go_image_id("GO TOK 3", "GO TOK 3")
        matched, image_id = outcome
        assert matched is True
        assert image_id is None

    def test_dash_payload_maps_to_none(self):
        # "-" is the placeholder payload and must be reported as "no id".
        outcome = ContainerRunner._parse_go_image_id("GO TOK 3 -", "GO TOK 3")
        matched, image_id = outcome
        assert matched is True
        assert image_id is None

    def test_image_id_extracted(self):
        outcome = ContainerRunner._parse_go_image_id("GO TOK 3 sha256:abc", "GO TOK 3")
        matched, image_id = outcome
        assert matched is True
        assert image_id == "sha256:abc"

    def test_wrong_rank_does_not_match(self):
        # Rank 30 shares a textual prefix with rank 3 but must be rejected.
        outcome = ContainerRunner._parse_go_image_id("GO TOK 30", "GO TOK 3")
        matched, image_id = outcome
        assert matched is False
        assert image_id is None

    def test_unrelated_line_does_not_match(self):
        outcome = ContainerRunner._parse_go_image_id("nope", "GO TOK 3")
        matched, image_id = outcome
        assert matched is False


class TestBuildFingerprint:
    """Expectations for ContainerRunner._build_fingerprint(build_info)."""

    def test_empty_when_no_dockerfile(self):
        runner = _make_runner()
        for build_info in ({}, {"dockerfile": "N/A (local image mode)"}):
            assert runner._build_fingerprint(build_info) == ""

    def test_changes_with_dockerfile_content(self, tmp_path):
        runner = _make_runner()
        dockerfile = tmp_path / "Dockerfile"
        build_info = {"dockerfile": str(dockerfile)}
        fingerprints = []
        # Rewrite the same Dockerfile path and fingerprint it after each write.
        for content in (
            "FROM base\nARG RCCL_COMMIT=aaaa\n",
            "FROM base\nARG RCCL_COMMIT=bbbb\n",
        ):
            dockerfile.write_text(content)
            fingerprints.append(runner._build_fingerprint(build_info))
        before, after = fingerprints
        assert before and after and before != after

    def test_changes_with_build_args(self, tmp_path):
        dockerfile = tmp_path / "Dockerfile"
        dockerfile.write_text("FROM base\n")

        def fingerprint_for(commit):
            # Same Dockerfile, different docker_build_arg in the context.
            context = MagicMock()
            context.ctx = {"docker_build_arg": {"RCCL_COMMIT": commit}}
            return _make_runner(context)._build_fingerprint({"dockerfile": str(dockerfile)})

        fp_a = fingerprint_for("aaaa")
        fp_b = fingerprint_for("bbbb")
        assert fp_a != fp_b


class TestImageIsStale:
    """Expectations for ContainerRunner._image_is_stale(image, fingerprint)."""

    @staticmethod
    def _runner_with_label(label):
        """Return a runner whose _image_label lookup always yields ``label``."""
        runner = _make_runner()
        runner._image_label = MagicMock(return_value=label)
        return runner

    def test_no_fingerprint_never_stale(self):
        runner = self._runner_with_label("anything")
        assert runner._image_is_stale("img", "") is False

    def test_missing_label_not_stale(self):
        runner = self._runner_with_label(None)
        assert runner._image_is_stale("img", "fp1") is False

    def test_matching_label_not_stale(self):
        runner = self._runner_with_label("fp1")
        assert runner._image_is_stale("img", "fp1") is False

    def test_differing_label_is_stale(self):
        runner = self._runner_with_label("old")
        assert runner._image_is_stale("img", "fp1") is True


class TestPrimaryEnsure:
    """_ensure_local_image_available as seen from the primary (NODE_RANK=0)."""

    @staticmethod
    def _primary(monkeypatch, fingerprint, image_present, image_id, tar_on_disk, stale=None):
        """Prepare a rank-0 runner with mocked collaborators.

        ``stale`` is only installed as the ``_image_is_stale`` result when it
        is not None, so a test can leave the real method in place.
        """
        monkeypatch.setenv("NODE_RANK", "0")
        runner = _stub_io(_make_runner())
        runner._get_local_image_tar_path = MagicMock(return_value="/shared/img.tar")
        runner._build_fingerprint = MagicMock(return_value=fingerprint)
        runner._local_image_exists = MagicMock(return_value=image_present)
        if stale is not None:
            runner._image_is_stale = MagicMock(return_value=stale)
        runner._local_image_id = MagicMock(return_value=image_id)
        monkeypatch.setattr(os.path, "exists", lambda p: tar_on_disk)

        def echo_master_id(run_image, master_image_id):
            # The sync step hands back whatever image id it was given.
            return master_image_id

        runner._sync_after_local_image_ready = MagicMock(side_effect=echo_master_id)
        return runner

    def test_builds_and_saves_tar_when_missing(self, monkeypatch):
        runner = self._primary(
            monkeypatch,
            fingerprint="fp1",
            image_present=False,
            image_id="sha256:master",
            tar_on_disk=False,
        )

        runner._ensure_local_image_available("img", {"dockerfile": "D"}, {})

        runner._build_or_pull_local_image.assert_called_once()
        runner._save_local_image_to_tar.assert_called_once_with("img", "/shared/img.tar")
        # The id of the ensured image is the one passed on to the sync step.
        sync_call = runner._sync_after_local_image_ready.call_args
        assert sync_call.kwargs["master_image_id"] == "sha256:master"

    def test_stale_image_is_rebuilt_and_tar_refreshed(self, monkeypatch):
        runner = self._primary(
            monkeypatch,
            fingerprint="fp_new",
            image_present=True,
            image_id="sha256:new",
            tar_on_disk=True,  # a stale tar is already present
            stale=True,
        )

        runner._ensure_local_image_available("img", {"dockerfile": "D"}, {})

        runner._build_or_pull_local_image.assert_called_once()
        # The pre-existing tar has to be replaced by one of the rebuilt image.
        runner._save_local_image_to_tar.assert_called_once_with("img", "/shared/img.tar")

    def test_fresh_image_not_rebuilt_and_tar_kept(self, monkeypatch):
        runner = self._primary(
            monkeypatch,
            fingerprint="fp1",
            image_present=True,
            image_id="sha256:keep",
            tar_on_disk=True,  # the tar is already present
            stale=False,
        )

        runner._ensure_local_image_available("img", {"dockerfile": "D"}, {})

        runner._build_or_pull_local_image.assert_not_called()
        runner._save_local_image_to_tar.assert_not_called()


class TestWorkerReconcile:
    """_ensure_local_image_available as seen from a worker (NODE_RANK=1)."""

    def _worker(self, monkeypatch, tar_path, master_id, local_ids, tar_on_disk=True, stale=None):
        """Prepare a rank-1 runner with mocked collaborators.

        ``local_ids`` are returned one per ``_local_image_id`` call, in order.
        ``stale`` is only installed as the ``_image_is_stale`` result when it
        is not None.
        """
        monkeypatch.setenv("NODE_RANK", "1")
        runner = _stub_io(_make_runner())
        runner._get_local_image_tar_path = MagicMock(return_value=tar_path)
        runner._build_fingerprint = MagicMock(return_value="fp1")
        runner._local_image_id = MagicMock(side_effect=list(local_ids))
        runner._sync_after_local_image_ready = MagicMock(return_value=master_id)
        monkeypatch.setattr(os.path, "exists", lambda p: tar_on_disk)
        if stale is not None:
            runner._image_is_stale = MagicMock(return_value=stale)
        return runner

    @staticmethod
    def _ensure(runner):
        """Invoke the method under test with the arguments shared by all cases."""
        runner._ensure_local_image_available("img", {"dockerfile": "D"}, {})

    def test_mismatch_triggers_tar_reload(self, monkeypatch):
        # First lookup differs from the master id; the second one equals it.
        runner = self._worker(
            monkeypatch,
            "/shared/img.tar",
            "sha256:master",
            local_ids=["sha256:old", "sha256:master"],
        )
        self._ensure(runner)
        runner._load_local_image_from_tar.assert_called_once_with("img", "/shared/img.tar")

    def test_match_skips_reload(self, monkeypatch):
        runner = self._worker(
            monkeypatch,
            "/shared/img.tar",
            "sha256:master",
            local_ids=["sha256:master"],
        )
        self._ensure(runner)
        runner._load_local_image_from_tar.assert_not_called()
        runner._build_or_pull_local_image.assert_not_called()

    def test_persistent_mismatch_raises(self, monkeypatch):
        runner = self._worker(
            monkeypatch,
            "/shared/img.tar",
            "sha256:master",
            local_ids=["sha256:old", "sha256:still_old"],
        )
        with pytest.raises(RuntimeError, match="mismatch persists"):
            self._ensure(runner)

    def test_missing_tar_when_reload_needed_raises(self, monkeypatch):
        runner = self._worker(
            monkeypatch,
            "/shared/img.tar",
            "sha256:master",
            local_ids=["sha256:old"],
            tar_on_disk=False,
        )
        with pytest.raises(RuntimeError, match="did not produce image tar"):
            self._ensure(runner)

    def test_no_tar_missing_image_builds(self, monkeypatch):
        runner = self._worker(
            monkeypatch, None, "sha256:master", local_ids=[None], stale=False,
        )
        self._ensure(runner)
        runner._build_or_pull_local_image.assert_called_once()

    def test_no_tar_stale_image_rebuilds(self, monkeypatch):
        runner = self._worker(
            monkeypatch, None, "sha256:master", local_ids=["sha256:x"], stale=True,
        )
        self._ensure(runner)
        runner._build_or_pull_local_image.assert_called_once()

    def test_no_tar_fresh_image_kept(self, monkeypatch):
        runner = self._worker(
            monkeypatch, None, "sha256:master", local_ids=["sha256:x"], stale=False,
        )
        self._ensure(runner)
        runner._build_or_pull_local_image.assert_not_called()