Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@
def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Path:
    """Unpack an in-memory zip archive into a per-job directory under the system temp dir.

    Every archive entry is resolved against the extraction root before anything
    is written; the archive is rejected as a whole if any entry would land
    outside that root (ZipSlip).

    :param job_definition: Job definition whose ``name`` selects the extraction sub-directory.
    :type job_definition: JobBaseData
    :param zip_content: Raw content of the zip archive (wrapped in ``io.BytesIO``).
    :type zip_content: Any
    :return: The directory the archive was extracted into.
    :rtype: Path
    :raises ValueError: If an archive entry resolves to a location outside the extraction directory.
    """
    temp_dir = Path(tempfile.gettempdir(), AZUREML_RUNS_DIR, job_definition.name)
    temp_dir.mkdir(parents=True, exist_ok=True)

    extraction_root = temp_dir.resolve()
    # Trailing separator so that a sibling such as "<root>-evil" is not mistaken for a child of "<root>".
    root_prefix = str(extraction_root) + os.sep

    def _stays_inside_root(entry_name: str) -> bool:
        target = (extraction_root / entry_name).resolve()
        # The root itself is allowed so that plain directory entries pass.
        return target == extraction_root or str(target).startswith(root_prefix)

    with zipfile.ZipFile(io.BytesIO(zip_content)) as archive:
        for entry_name in archive.namelist():
            if _stays_inside_root(entry_name):
                continue
            raise ValueError(
                f"Zip archive contains a path traversal entry and cannot be extracted safely: {entry_name}"
            )
        # Only reached once every entry has been validated.
        archive.extractall(temp_dir)
    return temp_dir

Expand Down Expand Up @@ -166,6 +174,38 @@ def is_local_run(job_definition: JobBaseData) -> bool:
local = job_definition.properties.services.get("Local", None)
return local is not None and EXECUTION_SERVICE_URL_KEY in local.endpoint

def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None:
    """Extract tar archive members safely, preventing path traversal (TarSlip).

    When the running interpreter exposes ``tarfile.data_filter`` (Python 3.12+,
    as well as the security releases of older versions that backported
    extraction filters), extraction is delegated to the built-in 'data' filter.
    Otherwise each member is validated manually to ensure there is no path
    traversal, symlink, or hard link that could write outside the destination
    directory.

    Rejections coming from the built-in filter (``tarfile.FilterError``) are
    translated to ``ValueError`` so that callers see the same exception type
    regardless of which code path ran.

    :param tar: An opened tarfile.TarFile object.
    :type tar: tarfile.TarFile
    :param dest_dir: The destination directory for extraction.
    :type dest_dir: str
    :raises ValueError: If a tar member would escape the destination directory
        or is rejected as an unsafe symlink/hard link.
    """
    resolved_dest = os.path.realpath(dest_dir)

    # Prefer the interpreter's own extraction filter whenever it is available.
    if hasattr(tarfile, "data_filter"):
        try:
            tar.extractall(resolved_dest, filter="data")
        except tarfile.FilterError as err:
            # FilterError is not a ValueError; convert it to keep the documented contract.
            raise ValueError(
                f"Tar archive contains an unsafe entry and cannot be extracted safely: {err}"
            ) from err
        return

    # Fallback for interpreters without extraction filters: validate everything first.
    for member in tar.getmembers():
        if member.issym() or member.islnk():
            raise ValueError(
                f"Tar archive contains a symbolic or hard link and cannot be extracted safely: {member.name}"
            )
        member_path = os.path.realpath(os.path.join(resolved_dest, member.name))
        # The destination itself is allowed (directory entries); anything else must live below it.
        if member_path != resolved_dest and not member_path.startswith(resolved_dest + os.sep):
            raise ValueError(
                f"Tar archive contains a path traversal entry and cannot be extracted safely: {member.name}"
            )
    # All members validated; safe to extract
    tar.extractall(resolved_dest)
Comment on lines +197 to +208

class CommonRuntimeHelper:
COMMON_RUNTIME_BOOTSTRAPPER_INFO = "common_runtime_bootstrapper_info.json"
Expand Down Expand Up @@ -266,8 +306,7 @@ def copy_bootstrapper_from_container(self, container: "docker.models.containers.
for chunk in data_stream:
f.write(chunk)
with tarfile.open(tar_file, mode="r") as tar:
for file_name in tar.getnames():
tar.extract(file_name, os.path.dirname(path_in_host))
_safe_tar_extractall(tar, os.path.dirname(path_in_host))
os.remove(tar_file)
except docker.errors.APIError as e:
msg = f"Copying {path_in_container} from container has failed. Detailed message: {e}"
Expand Down
Empty file added sdk/ml/azure-ai-ml/python
Empty file.
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import io
import os
import tarfile
import tempfile
from pathlib import Path
from unittest.mock import MagicMock

import pytest

from azure.ai.ml.operations._local_job_invoker import (
_get_creationflags_and_startupinfo_for_background_process,
_safe_tar_extractall,
patch_invocation_script_serialization,
unzip_to_temporary_file,
)

import zipfile
Comment on lines 10 to +16

@pytest.mark.unittest
@pytest.mark.training_experiences_test
Expand Down Expand Up @@ -61,3 +66,114 @@ def test_creation_flags(self):
flags = _get_creationflags_and_startupinfo_for_background_process("linux")

assert flags == {"stderr": -2, "stdin": -3, "stdout": -3}

def _make_job_definition(name="test-run"):
    """Build a stand-in job definition exposing only a ``name`` attribute.

    :param name: Value to expose as the mock's ``name`` attribute.
    :type name: str
    :return: A ``MagicMock`` whose ``name`` attribute equals ``name``.
    :rtype: MagicMock
    """
    stub = MagicMock()
    # ``name`` cannot be passed to the constructor: there it names the mock itself
    # instead of creating an attribute, so it is configured afterwards.
    stub.configure_mock(name=name)
    return stub

@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestUnzipPathTraversalPrevention:
    """Tests for ZIP path traversal prevention in unzip_to_temporary_file."""

    @pytest.fixture(autouse=True)
    def _isolated_temp_root(self, tmp_path, monkeypatch):
        """Point ``tempfile.gettempdir()`` at a per-test directory.

        Without this the tests write into the shared system temp directory under
        fixed run names, leave files behind, and can pass on leftovers from an
        earlier run.
        """
        monkeypatch.setattr(tempfile, "tempdir", str(tmp_path))

    def test_normal_zip_extracts_successfully(self, tmp_path):
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w") as zf:
            zf.writestr("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n")
            zf.writestr("azureml-setup/config.json", '{"key": "value"}')
        zip_bytes = buf.getvalue()

        job_def = _make_job_definition("safe-run")
        result = unzip_to_temporary_file(job_def, zip_bytes)

        assert result.exists()
        # The extraction directory must live under the redirected temp root.
        assert tmp_path.resolve() in result.resolve().parents
        assert (result / "azureml-setup" / "invocation.sh").exists()
        assert (result / "azureml-setup" / "config.json").exists()

    def test_zip_with_path_traversal_is_rejected(self):
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w") as zf:
            zf.writestr("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n")
            zf.writestr("../../etc/evil.sh", "#!/bin/bash\necho pwned\n")
        zip_bytes = buf.getvalue()

        job_def = _make_job_definition("traversal-run")
        with pytest.raises(ValueError, match="path traversal"):
            unzip_to_temporary_file(job_def, zip_bytes)

    def test_zip_with_absolute_path_is_rejected(self):
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w") as zf:
            # Use an absolute path form that is meaningful on the current platform.
            if os.name == "nt":
                zf.writestr("C:/Windows/Temp/evil.sh", "#!/bin/bash\necho pwned\n")
            else:
                zf.writestr("/tmp/evil.sh", "#!/bin/bash\necho pwned\n")
        zip_bytes = buf.getvalue()

        job_def = _make_job_definition("absolute-path-run")
        with pytest.raises(ValueError, match="path traversal"):
            unzip_to_temporary_file(job_def, zip_bytes)


@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestSafeTarExtract:
    """Tests for tar path traversal prevention in _safe_tar_extractall.

    Rejections are expected as ``ValueError`` (manual validation) or
    ``tarfile.TarError`` (parent of the built-in filter's ``FilterError``).
    Catching bare ``Exception`` would let unrelated failures such as a
    ``TypeError`` pass as a successful rejection.
    """

    def test_normal_tar_extracts_successfully(self):
        with tempfile.TemporaryDirectory() as dest:
            buf = io.BytesIO()
            with tarfile.open(fileobj=buf, mode="w") as tar:
                data = b"normal content"
                info = tarfile.TarInfo(name="vm-bootstrapper")
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))
            buf.seek(0)

            with tarfile.open(fileobj=buf, mode="r") as tar:
                _safe_tar_extractall(tar, dest)

            extracted = os.path.join(dest, "vm-bootstrapper")
            assert os.path.exists(extracted)
            with open(extracted, "rb") as f:
                assert f.read() == b"normal content"

    def test_tar_with_path_traversal_is_rejected(self):
        with tempfile.TemporaryDirectory() as sandbox:
            # Nest the destination two levels deep so that "../../" would land
            # inside the sandbox, where the absence of the file can be asserted.
            dest = os.path.join(sandbox, "a", "b")
            os.makedirs(dest)

            buf = io.BytesIO()
            with tarfile.open(fileobj=buf, mode="w") as tar:
                data = b"evil content"
                info = tarfile.TarInfo(name="../../evil_script.sh")
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))
            buf.seek(0)

            with tarfile.open(fileobj=buf, mode="r") as tar:
                with pytest.raises((ValueError, tarfile.TarError)):
                    _safe_tar_extractall(tar, dest)

            assert not os.path.exists(os.path.join(sandbox, "evil_script.sh"))

    def test_tar_with_symlink_is_rejected(self):
        with tempfile.TemporaryDirectory() as dest:
            buf = io.BytesIO()
            with tarfile.open(fileobj=buf, mode="w") as tar:
                info = tarfile.TarInfo(name="evil_link")
                info.type = tarfile.SYMTYPE
                info.linkname = "/etc/passwd"
                tar.addfile(info)
            buf.seek(0)

            with tarfile.open(fileobj=buf, mode="r") as tar:
                with pytest.raises((ValueError, tarfile.TarError)):
                    _safe_tar_extractall(tar, dest)

            assert not os.path.lexists(os.path.join(dest, "evil_link"))

    def test_tar_with_hardlink_is_rejected(self):
        with tempfile.TemporaryDirectory() as dest:
            buf = io.BytesIO()
            with tarfile.open(fileobj=buf, mode="w") as tar:
                info = tarfile.TarInfo(name="evil_hardlink")
                info.type = tarfile.LNKTYPE
                info.linkname = "/etc/shadow"
                tar.addfile(info)
            buf.seek(0)

            with tarfile.open(fileobj=buf, mode="r") as tar:
                with pytest.raises((ValueError, tarfile.TarError)):
                    _safe_tar_extractall(tar, dest)

            assert not os.path.lexists(os.path.join(dest, "evil_hardlink"))
Comment on lines +177 to +179
Loading