From 857485fa94917a0826c1f02c6c84953daf75b6b2 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:58:28 -0400 Subject: [PATCH 1/2] fix: ModelBuilder with source_code + DJL LMI: /opt/ml/model becomes read-only, breaki (5698) --- .../sagemaker/serve/model_builder_servers.py | 9 +- .../tests/unit/servers/__init__.py | 0 .../unit/servers/test_djl_hf_cache_env.py | 318 ++++++++++++++++++ 3 files changed, 325 insertions(+), 2 deletions(-) create mode 100644 sagemaker-serve/tests/unit/servers/__init__.py create mode 100644 sagemaker-serve/tests/unit/servers/test_djl_hf_cache_env.py diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index 43af8b4f7a..d83fa54eda 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -320,7 +320,7 @@ def _build_for_djl(self) -> Model: if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for DJL - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) # Get model configuration for DJL optimization self.hf_model_config = _get_model_config_properties_from_hf( @@ -345,7 +345,9 @@ def _build_for_djl(self) -> Model: "SERVING_MAX_WORKERS": "1", "OPTION_MODEL_LOADING_TIMEOUT": "240", "OPTION_PREDICT_TIMEOUT": "60", - "TENSOR_PARALLEL_DEGREE": "1" # Default, will be overridden below + "TENSOR_PARALLEL_DEGREE": "1", # Default, will be overridden below + "HF_HOME": "/tmp", + "HUGGINGFACE_HUB_CACHE": "/tmp", } # Add HuggingFace authentication @@ -370,6 +372,9 @@ def _build_for_djl(self) -> Model: # Cache management based on mode if self.mode in LOCAL_MODES: self.env_vars.update({"HF_HUB_OFFLINE": "1"}) + else: + self.env_vars["HF_HOME"] = "/tmp" + self.env_vars["HUGGINGFACE_HUB_CACHE"] = "/tmp" # GPU-based tensor parallel calculation for SAGEMAKER_ENDPOINT mode if self.mode == Mode.SAGEMAKER_ENDPOINT: diff --git a/sagemaker-serve/tests/unit/servers/__init__.py b/sagemaker-serve/tests/unit/servers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-serve/tests/unit/servers/test_djl_hf_cache_env.py b/sagemaker-serve/tests/unit/servers/test_djl_hf_cache_env.py new file mode 100644 index 0000000000..3a7f5e5fe3 --- /dev/null +++ b/sagemaker-serve/tests/unit/servers/test_djl_hf_cache_env.py @@ -0,0 +1,318 @@ +"""Tests for DJL builder HF cache environment variables and HF_MODEL_ID handling. + +Verifies that _build_for_djl() correctly: +- Sets HF_HOME and HUGGINGFACE_HUB_CACHE to /tmp for writable cache +- Preserves user-provided HF_MODEL_ID values (uses setdefault) +- Sets HF_MODEL_ID when not provided by user +- Sets HF_HUB_OFFLINE in local modes +""" + +import unittest +from unittest.mock import Mock, patch, MagicMock +import tempfile +import os +import shutil + +from sagemaker.serve.model_builder import ModelBuilder +from sagemaker.serve.utils.types import ModelServer +from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.core.resources import Model + + +def _mock_sagemaker_session(): + """Create a mock SageMaker session.""" + session = Mock() + session.boto_region_name = "us-east-1" + session.sagemaker_config = {} + session.default_bucket.return_value = "mock-bucket" + session.upload_data.return_value = "s3://mock-bucket/model.tar.gz" + return session + + +MOCK_ROLE_ARN = "arn:aws:iam::123456789012:role/SageMakerRole" +MOCK_IMAGE_URI = "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.36.0-lmi22.0.0-cu129" +MOCK_HF_MODEL_CONFIG = {"model_type": "gpt2", "architectures": ["GPT2LMHeadModel"]} + + +class TestDjlHfCacheEnv(unittest.TestCase): + """Test DJL builder HF cache environment variable handling.""" + + def setUp(self): + """Set up test fixtures.""" + self.mock_session = _mock_sagemaker_session() + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up temp directory.""" + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + @patch('sagemaker.serve.model_builder_servers._get_gpu_info') + @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') + @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') + @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') + @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') + @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') + @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') + @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') + @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') + @patch('sagemaker.serve.model_builder_servers._get_nb_instance') + def test_build_for_djl_sets_hf_home_to_tmp( + self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, + mock_validate, mock_auto_detect, mock_prepare, mock_create, + mock_tp_degree, mock_gpu_info + ): + """Verify HF_HOME=/tmp is set in SAGEMAKER_ENDPOINT mode.""" + mock_nb.return_value = None + mock_is_js.return_value = False + mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG + mock_djl_config.return_value = ({}, 256) + mock_create.return_value = Mock(spec=Model) + mock_prepare.return_value = ("s3://bucket/model", None) + mock_gpu_info.return_value = 4 + mock_tp_degree.return_value = 4 + + builder = ModelBuilder( + model="chromadb/context-1", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + model_path=self.temp_dir, + mode=Mode.SAGEMAKER_ENDPOINT, + image_uri=MOCK_IMAGE_URI, + model_server=ModelServer.DJL_SERVING, + instance_type="ml.g6e.12xlarge", + ) + builder.schema_builder = Mock() + builder.schema_builder.sample_input = {"inputs": "Hello"} + builder._optimizing = False + builder.hf_model_config = MOCK_HF_MODEL_CONFIG + + builder._build_for_djl() + + self.assertEqual(builder.env_vars.get("HF_HOME"), "/tmp") + + @patch('sagemaker.serve.model_builder_servers._get_gpu_info') + @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') + @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') + @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') + @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') + @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') + @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') + @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') + @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') + @patch('sagemaker.serve.model_builder_servers._get_nb_instance') + def test_build_for_djl_sets_huggingface_hub_cache_to_tmp( + self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, + mock_validate, mock_auto_detect, mock_prepare, mock_create, + mock_tp_degree, mock_gpu_info + ): + """Verify HUGGINGFACE_HUB_CACHE=/tmp is set in SAGEMAKER_ENDPOINT mode.""" + mock_nb.return_value = None + mock_is_js.return_value = False + mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG + mock_djl_config.return_value = ({}, 256) + mock_create.return_value = Mock(spec=Model) + mock_prepare.return_value = ("s3://bucket/model", None) + mock_gpu_info.return_value = 4 + mock_tp_degree.return_value = 4 + + builder = ModelBuilder( + model="chromadb/context-1", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + model_path=self.temp_dir, + mode=Mode.SAGEMAKER_ENDPOINT, + image_uri=MOCK_IMAGE_URI, + model_server=ModelServer.DJL_SERVING, + instance_type="ml.g6e.12xlarge", + ) + builder.schema_builder = Mock() + builder.schema_builder.sample_input = {"inputs": "Hello"} + builder._optimizing = False + builder.hf_model_config = MOCK_HF_MODEL_CONFIG + + builder._build_for_djl() + + self.assertEqual(builder.env_vars.get("HUGGINGFACE_HUB_CACHE"), "/tmp") + + @patch('sagemaker.serve.model_builder_servers._get_gpu_info') + @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') + @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') + @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') + @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') + @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') + @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') + @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') + @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') + @patch('sagemaker.serve.model_builder_servers._get_nb_instance') + def test_build_for_djl_preserves_user_provided_hf_model_id( + self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, + mock_validate, mock_auto_detect, mock_prepare, mock_create, + mock_tp_degree, mock_gpu_info + ): + """Verify user-provided HF_MODEL_ID is NOT overridden.""" + mock_nb.return_value = None + mock_is_js.return_value = False + mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG + mock_djl_config.return_value = ({}, 256) + mock_create.return_value = Mock(spec=Model) + mock_prepare.return_value = ("s3://bucket/model", None) + mock_gpu_info.return_value = 4 + mock_tp_degree.return_value = 4 + + builder = ModelBuilder( + model="chromadb/context-1", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + model_path=self.temp_dir, + mode=Mode.SAGEMAKER_ENDPOINT, + image_uri=MOCK_IMAGE_URI, + model_server=ModelServer.DJL_SERVING, + instance_type="ml.g6e.12xlarge", + env_vars={"HF_MODEL_ID": "/opt/ml/model"}, + ) + builder.schema_builder = Mock() + builder.schema_builder.sample_input = {"inputs": "Hello"} + builder._optimizing = False + builder.hf_model_config = MOCK_HF_MODEL_CONFIG + + builder._build_for_djl() + + # User-provided value should be preserved, NOT overridden by model param + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "/opt/ml/model") + + @patch('sagemaker.serve.model_builder_servers._get_gpu_info') + @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') + @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') + @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') + @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') + @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') + @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') + @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') + @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') + @patch('sagemaker.serve.model_builder_servers._get_nb_instance') + def test_build_for_djl_sets_hf_model_id_when_not_provided( + self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, + mock_validate, mock_auto_detect, mock_prepare, mock_create, + mock_tp_degree, mock_gpu_info + ): + """Verify HF_MODEL_ID is set from model param when not user-provided.""" + mock_nb.return_value = None + mock_is_js.return_value = False + mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG + mock_djl_config.return_value = ({}, 256) + mock_create.return_value = Mock(spec=Model) + mock_prepare.return_value = ("s3://bucket/model", None) + mock_gpu_info.return_value = 4 + mock_tp_degree.return_value = 4 + + builder = ModelBuilder( + model="chromadb/context-1", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + model_path=self.temp_dir, + mode=Mode.SAGEMAKER_ENDPOINT, + image_uri=MOCK_IMAGE_URI, + model_server=ModelServer.DJL_SERVING, + instance_type="ml.g6e.12xlarge", + ) + builder.schema_builder = Mock() + builder.schema_builder.sample_input = {"inputs": "Hello"} + builder._optimizing = False + builder.hf_model_config = MOCK_HF_MODEL_CONFIG + + builder._build_for_djl() + + # When no user-provided HF_MODEL_ID, it should be set from model param + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "chromadb/context-1") + + @patch('sagemaker.serve.model_builder_servers._get_gpu_info') + @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') + @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') + @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') + @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') + @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') + @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') + @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') + @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') + @patch('sagemaker.serve.model_builder_servers._get_nb_instance') + def test_build_for_djl_with_source_code_and_hf_model_id( + self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, + mock_validate, mock_auto_detect, mock_prepare, mock_create, + mock_tp_degree, mock_gpu_info + ): + """Verify HF cache env vars are set to /tmp when source_code is provided. + + This is the key scenario from the bug: source_code makes /opt/ml/model + read-only, so HF cache must be redirected to /tmp. + """ + mock_nb.return_value = None + mock_is_js.return_value = False + mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG + mock_djl_config.return_value = ({}, 256) + mock_create.return_value = Mock(spec=Model) + mock_prepare.return_value = ("s3://bucket/model", None) + mock_gpu_info.return_value = 4 + mock_tp_degree.return_value = 4 + + builder = ModelBuilder( + model="chromadb/context-1", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + model_path=self.temp_dir, + mode=Mode.SAGEMAKER_ENDPOINT, + image_uri=MOCK_IMAGE_URI, + model_server=ModelServer.DJL_SERVING, + instance_type="ml.g6e.12xlarge", + ) + builder.schema_builder = Mock() + builder.schema_builder.sample_input = {"inputs": "Hello"} + builder._optimizing = False + builder.hf_model_config = MOCK_HF_MODEL_CONFIG + + builder._build_for_djl() + + # HF cache should be redirected to /tmp to avoid read-only /opt/ml/model + self.assertEqual(builder.env_vars.get("HF_HOME"), "/tmp") + self.assertEqual(builder.env_vars.get("HUGGINGFACE_HUB_CACHE"), "/tmp") + + @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') + @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') + @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') + @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') + @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') + @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') + @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') + @patch('sagemaker.serve.model_builder_servers._get_nb_instance') + def test_build_for_djl_local_mode_sets_hf_hub_offline( + self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, + mock_validate, mock_auto_detect, mock_prepare, mock_create + ): + """Verify HF_HUB_OFFLINE=1 is set in LOCAL_CONTAINER mode.""" + mock_nb.return_value = None + mock_is_js.return_value = False + mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG + mock_djl_config.return_value = ({}, 256) + mock_create.return_value = Mock(spec=Model) + + builder = ModelBuilder( + model="chromadb/context-1", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + model_path=self.temp_dir, + mode=Mode.LOCAL_CONTAINER, + image_uri=MOCK_IMAGE_URI, + model_server=ModelServer.DJL_SERVING, + ) + builder.schema_builder = Mock() + builder.schema_builder.sample_input = {"inputs": "Hello"} + builder._optimizing = False + builder.hf_model_config = MOCK_HF_MODEL_CONFIG + + builder._build_for_djl() + + self.assertEqual(builder.env_vars.get("HF_HUB_OFFLINE"), "1") + + +if __name__ == "__main__": + unittest.main() From bd478463d0f3588422f614172b91eb9be5b29c99 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:51:28 -0400 Subject: [PATCH 2/2] fix: address review comments (iteration #1) --- .../sagemaker/serve/model_builder_servers.py | 33 +- .../unit/servers/test_djl_hf_cache_env.py | 389 +++++------------- 2 files changed, 128 insertions(+), 294 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index d83fa54eda..2c3ac31943 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -319,45 +319,43 @@ def _build_for_djl(self) -> Model: logger.debug(f"Using detected notebook instance type: {nb_instance}") if isinstance(self.model, str) and not self._is_jumpstart_model_id(): - # Configure HuggingFace model for DJL + # Configure HuggingFace model for DJL (preserve user-provided HF_MODEL_ID) self.env_vars.setdefault("HF_MODEL_ID", self.model) - + # Get model configuration for DJL optimization self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + # Apply DJL-specific configurations default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations( self.model, self.hf_model_config, self.schema_builder ) self.env_vars.update(default_djl_configurations) - + # Configure schema builder for text generation if "parameters" not in self.schema_builder.sample_input: self.schema_builder.sample_input["parameters"] = {} self.schema_builder.sample_input["parameters"]["max_new_tokens"] = _default_max_new_tokens - - # Set DJL serving defaults + + # Set DJL serving defaults (only if not already set by user) djl_env_vars = { "OPTION_ENGINE": "Python", "SERVING_MIN_WORKERS": "1", - "SERVING_MAX_WORKERS": "1", + "SERVING_MAX_WORKERS": "1", "OPTION_MODEL_LOADING_TIMEOUT": "240", "OPTION_PREDICT_TIMEOUT": "60", - "TENSOR_PARALLEL_DEGREE": "1", # Default, will be overridden below - "HF_HOME": "/tmp", - "HUGGINGFACE_HUB_CACHE": "/tmp", + "TENSOR_PARALLEL_DEGREE": "1", } - + # Add HuggingFace authentication if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): djl_env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - + # Update with defaults only if not already set for key, value in djl_env_vars.items(): self.env_vars.setdefault(key, value) - + # DJL downloads models directly from HuggingFace Hub self.s3_upload_path = None @@ -369,12 +367,15 @@ def _build_for_djl(self) -> Model: else: self.s3_model_data_url, _ = self._prepare_for_mode() + # Set HF cache env vars to writable location (unconditionally, using setdefault + # to preserve user-provided values). This is needed because /opt/ml/model/ may be + # read-only when source_code artifacts are mounted there. + self.env_vars.setdefault("HF_HOME", "/tmp") + self.env_vars.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp") + # Cache management based on mode if self.mode in LOCAL_MODES: self.env_vars.update({"HF_HUB_OFFLINE": "1"}) - else: - self.env_vars["HF_HOME"] = "/tmp" - self.env_vars["HUGGINGFACE_HUB_CACHE"] = "/tmp" # GPU-based tensor parallel calculation for SAGEMAKER_ENDPOINT mode if self.mode == Mode.SAGEMAKER_ENDPOINT: diff --git a/sagemaker-serve/tests/unit/servers/test_djl_hf_cache_env.py b/sagemaker-serve/tests/unit/servers/test_djl_hf_cache_env.py index 3a7f5e5fe3..b6de95059e 100644 --- a/sagemaker-serve/tests/unit/servers/test_djl_hf_cache_env.py +++ b/sagemaker-serve/tests/unit/servers/test_djl_hf_cache_env.py @@ -3,15 +3,12 @@ Verifies that _build_for_djl() correctly: - Sets HF_HOME and HUGGINGFACE_HUB_CACHE to /tmp for writable cache - Preserves user-provided HF_MODEL_ID values (uses setdefault) -- Sets HF_MODEL_ID when not provided by user -- Sets HF_HUB_OFFLINE in local modes +- Sets HF_MODEL_ID from model param when not provided by user +- Preserves user-provided HF_HOME and HUGGINGFACE_HUB_CACHE values """ -import unittest -from unittest.mock import Mock, patch, MagicMock -import tempfile -import os -import shutil +import pytest +from unittest.mock import Mock, patch from sagemaker.serve.model_builder import ModelBuilder from sagemaker.serve.utils.types import ModelServer @@ -19,6 +16,26 @@ from sagemaker.core.resources import Model +MOCK_ROLE_ARN = "arn:aws:iam::000000000000:role/SageMakerRole" +MOCK_IMAGE_URI = "000000000000.dkr.ecr.us-east-1.amazonaws.com/djl-inference:latest" +MOCK_HF_MODEL_CONFIG = {"model_type": "gpt2", "architectures": ["GPT2LMHeadModel"]} + + +# Common patches needed for _build_for_djl +_DJL_PATCHES = [ + "sagemaker.serve.model_builder_servers._get_nb_instance", + "sagemaker.serve.model_builder_servers._get_default_djl_configurations", + "sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf", + "sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id", + "sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data", + "sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri", + "sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode", + "sagemaker.serve.model_builder.ModelBuilder._create_model", + "sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", + "sagemaker.serve.model_builder_servers._get_gpu_info", +] + + def _mock_sagemaker_session(): """Create a mock SageMaker session.""" session = Mock() @@ -29,290 +46,106 @@ def _mock_sagemaker_session(): return session -MOCK_ROLE_ARN = "arn:aws:iam::123456789012:role/SageMakerRole" -MOCK_IMAGE_URI = "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.36.0-lmi22.0.0-cu129" -MOCK_HF_MODEL_CONFIG = {"model_type": "gpt2", "architectures": ["GPT2LMHeadModel"]} - - -class TestDjlHfCacheEnv(unittest.TestCase): - """Test DJL builder HF cache environment variable handling.""" - - def setUp(self): - """Set up test fixtures.""" - self.mock_session = _mock_sagemaker_session() - self.temp_dir = tempfile.mkdtemp() - - def tearDown(self): - """Clean up temp directory.""" - if os.path.exists(self.temp_dir): - shutil.rmtree(self.temp_dir) - - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') - @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') - @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - def test_build_for_djl_sets_hf_home_to_tmp( - self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, - mock_validate, mock_auto_detect, mock_prepare, mock_create, - mock_tp_degree, mock_gpu_info - ): - """Verify HF_HOME=/tmp is set in SAGEMAKER_ENDPOINT mode.""" - mock_nb.return_value = None - mock_is_js.return_value = False - mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG - mock_djl_config.return_value = ({}, 256) - mock_create.return_value = Mock(spec=Model) - mock_prepare.return_value = ("s3://bucket/model", None) - mock_gpu_info.return_value = 4 - mock_tp_degree.return_value = 4 - - builder = ModelBuilder( - model="chromadb/context-1", - role_arn=MOCK_ROLE_ARN, - sagemaker_session=self.mock_session, - model_path=self.temp_dir, - mode=Mode.SAGEMAKER_ENDPOINT, - image_uri=MOCK_IMAGE_URI, - model_server=ModelServer.DJL_SERVING, - instance_type="ml.g6e.12xlarge", - ) - builder.schema_builder = Mock() - builder.schema_builder.sample_input = {"inputs": "Hello"} - builder._optimizing = False - builder.hf_model_config = MOCK_HF_MODEL_CONFIG - - builder._build_for_djl() - - self.assertEqual(builder.env_vars.get("HF_HOME"), "/tmp") - - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') - @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') - @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - def test_build_for_djl_sets_huggingface_hub_cache_to_tmp( - self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, - mock_validate, mock_auto_detect, mock_prepare, mock_create, - mock_tp_degree, mock_gpu_info - ): - """Verify HUGGINGFACE_HUB_CACHE=/tmp is set in SAGEMAKER_ENDPOINT mode.""" - mock_nb.return_value = None - mock_is_js.return_value = False - mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG - mock_djl_config.return_value = ({}, 256) - mock_create.return_value = Mock(spec=Model) - mock_prepare.return_value = ("s3://bucket/model", None) - mock_gpu_info.return_value = 4 - mock_tp_degree.return_value = 4 - - builder = ModelBuilder( - model="chromadb/context-1", - role_arn=MOCK_ROLE_ARN, - sagemaker_session=self.mock_session, - model_path=self.temp_dir, - mode=Mode.SAGEMAKER_ENDPOINT, - image_uri=MOCK_IMAGE_URI, - model_server=ModelServer.DJL_SERVING, - instance_type="ml.g6e.12xlarge", - ) - builder.schema_builder = Mock() - builder.schema_builder.sample_input = {"inputs": "Hello"} - builder._optimizing = False - builder.hf_model_config = MOCK_HF_MODEL_CONFIG - +def _create_djl_builder(tmp_path, env_vars=None, mode=Mode.SAGEMAKER_ENDPOINT): + """Create a ModelBuilder configured for DJL serving tests.""" + builder = ModelBuilder( + model="test-org/test-model", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=_mock_sagemaker_session(), + model_path=str(tmp_path), + mode=mode, + image_uri=MOCK_IMAGE_URI, + model_server=ModelServer.DJL_SERVING, + instance_type="ml.g6e.12xlarge", + env_vars=env_vars or {}, + ) + builder.schema_builder = Mock() + builder.schema_builder.sample_input = {"inputs": "Hello"} + builder._optimizing = False + builder.hf_model_config = MOCK_HF_MODEL_CONFIG + return builder + + +def _setup_mocks(mocks): + """Configure common mock return values for DJL build.""" + # mocks are in reverse order of _DJL_PATCHES + mock_gpu_info = mocks[-1] + mock_tp_degree = mocks[-2] + mock_create = mocks[-3] + mock_prepare = mocks[-4] + # mock_auto_detect = mocks[-5] # no setup needed + # mock_validate = mocks[-6] # no setup needed + mock_is_js = mocks[-7] + mock_hf_config = mocks[-8] + mock_djl_config = mocks[-9] + mock_nb = mocks[-10] + + mock_nb.return_value = None + mock_djl_config.return_value = ({}, 256) + mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG + mock_is_js.return_value = False + mock_prepare.return_value = ("s3://bucket/model", None) + mock_create.return_value = Mock(spec=Model) + mock_tp_degree.return_value = 4 + mock_gpu_info.return_value = 4 + + +class TestDjlHfCacheAndModelId: + """Tests for DJL builder HF cache env vars and HF_MODEL_ID handling.""" + + @pytest.fixture(autouse=True) + def _patch_djl(self): + """Apply all DJL-related patches for each test.""" + patchers = [patch(p) for p in _DJL_PATCHES] + self._mocks = [p.start() for p in patchers] + _setup_mocks(self._mocks) + yield + for p in patchers: + p.stop() + + def test_sets_hf_cache_env_vars_to_tmp(self, tmp_path): + """HF_HOME and HUGGINGFACE_HUB_CACHE should be /tmp in endpoint mode.""" + builder = _create_djl_builder(tmp_path) builder._build_for_djl() - self.assertEqual(builder.env_vars.get("HUGGINGFACE_HUB_CACHE"), "/tmp") + assert builder.env_vars["HF_HOME"] == "/tmp" + assert builder.env_vars["HUGGINGFACE_HUB_CACHE"] == "/tmp" - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') - @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') - @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - def test_build_for_djl_preserves_user_provided_hf_model_id( - self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, - mock_validate, mock_auto_detect, mock_prepare, mock_create, - mock_tp_degree, mock_gpu_info - ): - """Verify user-provided HF_MODEL_ID is NOT overridden.""" - mock_nb.return_value = None - mock_is_js.return_value = False - mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG - mock_djl_config.return_value = ({}, 256) - mock_create.return_value = Mock(spec=Model) - mock_prepare.return_value = ("s3://bucket/model", None) - mock_gpu_info.return_value = 4 - mock_tp_degree.return_value = 4 - - builder = ModelBuilder( - model="chromadb/context-1", - role_arn=MOCK_ROLE_ARN, - sagemaker_session=self.mock_session, - model_path=self.temp_dir, - mode=Mode.SAGEMAKER_ENDPOINT, - image_uri=MOCK_IMAGE_URI, - model_server=ModelServer.DJL_SERVING, - instance_type="ml.g6e.12xlarge", - env_vars={"HF_MODEL_ID": "/opt/ml/model"}, + def test_preserves_user_provided_hf_model_id(self, tmp_path): + """User-provided HF_MODEL_ID must NOT be overridden by model param.""" + builder = _create_djl_builder( + tmp_path, env_vars={"HF_MODEL_ID": "/opt/ml/model"} ) - builder.schema_builder = Mock() - builder.schema_builder.sample_input = {"inputs": "Hello"} - builder._optimizing = False - builder.hf_model_config = MOCK_HF_MODEL_CONFIG - builder._build_for_djl() - # User-provided value should be preserved, NOT overridden by model param - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "/opt/ml/model") - - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') - @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') - @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - def test_build_for_djl_sets_hf_model_id_when_not_provided( - self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, - mock_validate, mock_auto_detect, mock_prepare, mock_create, - mock_tp_degree, mock_gpu_info - ): - """Verify HF_MODEL_ID is set from model param when not user-provided.""" - mock_nb.return_value = None - mock_is_js.return_value = False - mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG - mock_djl_config.return_value = ({}, 256) - mock_create.return_value = Mock(spec=Model) - mock_prepare.return_value = ("s3://bucket/model", None) - mock_gpu_info.return_value = 4 - mock_tp_degree.return_value = 4 - - builder = ModelBuilder( - model="chromadb/context-1", - role_arn=MOCK_ROLE_ARN, - sagemaker_session=self.mock_session, - model_path=self.temp_dir, - mode=Mode.SAGEMAKER_ENDPOINT, - image_uri=MOCK_IMAGE_URI, - model_server=ModelServer.DJL_SERVING, - instance_type="ml.g6e.12xlarge", - ) - builder.schema_builder = Mock() - builder.schema_builder.sample_input = {"inputs": "Hello"} - builder._optimizing = False - builder.hf_model_config = MOCK_HF_MODEL_CONFIG + assert builder.env_vars["HF_MODEL_ID"] == "/opt/ml/model" + def test_sets_hf_model_id_from_model_param_when_not_provided(self, tmp_path): + """When no user-provided HF_MODEL_ID, it should come from model param.""" + builder = _create_djl_builder(tmp_path) builder._build_for_djl() - # When no user-provided HF_MODEL_ID, it should be set from model param - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "chromadb/context-1") - - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') - @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') - @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - def test_build_for_djl_with_source_code_and_hf_model_id( - self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, - mock_validate, mock_auto_detect, mock_prepare, mock_create, - mock_tp_degree, mock_gpu_info - ): - """Verify HF cache env vars are set to /tmp when source_code is provided. - - This is the key scenario from the bug: source_code makes /opt/ml/model - read-only, so HF cache must be redirected to /tmp. - """ - mock_nb.return_value = None - mock_is_js.return_value = False - mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG - mock_djl_config.return_value = ({}, 256) - mock_create.return_value = Mock(spec=Model) - mock_prepare.return_value = ("s3://bucket/model", None) - mock_gpu_info.return_value = 4 - mock_tp_degree.return_value = 4 + assert builder.env_vars["HF_MODEL_ID"] == "test-org/test-model" - builder = ModelBuilder( - model="chromadb/context-1", - role_arn=MOCK_ROLE_ARN, - sagemaker_session=self.mock_session, - model_path=self.temp_dir, - mode=Mode.SAGEMAKER_ENDPOINT, - image_uri=MOCK_IMAGE_URI, - model_server=ModelServer.DJL_SERVING, - instance_type="ml.g6e.12xlarge", + def test_preserves_user_provided_hf_cache_dirs(self, tmp_path): + """User-provided HF_HOME and HUGGINGFACE_HUB_CACHE should be preserved.""" + builder = _create_djl_builder( + tmp_path, + env_vars={ + "HF_HOME": "/my/custom/cache", + "HUGGINGFACE_HUB_CACHE": "/my/custom/hub", + }, ) - builder.schema_builder = Mock() - builder.schema_builder.sample_input = {"inputs": "Hello"} - builder._optimizing = False - builder.hf_model_config = MOCK_HF_MODEL_CONFIG - builder._build_for_djl() - # HF cache should be redirected to /tmp to avoid read-only /opt/ml/model - self.assertEqual(builder.env_vars.get("HF_HOME"), "/tmp") - self.assertEqual(builder.env_vars.get("HUGGINGFACE_HUB_CACHE"), "/tmp") - - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri') - @patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data') - @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - def test_build_for_djl_local_mode_sets_hf_hub_offline( - self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js, - mock_validate, mock_auto_detect, mock_prepare, mock_create - ): - """Verify HF_HUB_OFFLINE=1 is set in LOCAL_CONTAINER mode.""" - mock_nb.return_value = None - mock_is_js.return_value = False - mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG - mock_djl_config.return_value = ({}, 256) - mock_create.return_value = Mock(spec=Model) - - builder = ModelBuilder( - model="chromadb/context-1", - role_arn=MOCK_ROLE_ARN, - sagemaker_session=self.mock_session, - model_path=self.temp_dir, - mode=Mode.LOCAL_CONTAINER, - image_uri=MOCK_IMAGE_URI, - model_server=ModelServer.DJL_SERVING, - ) - builder.schema_builder = Mock() - builder.schema_builder.sample_input = {"inputs": "Hello"} - builder._optimizing = False - builder.hf_model_config = MOCK_HF_MODEL_CONFIG + assert builder.env_vars["HF_HOME"] == "/my/custom/cache" + assert builder.env_vars["HUGGINGFACE_HUB_CACHE"] == "/my/custom/hub" + def test_local_mode_sets_hf_hub_offline(self, tmp_path): + """HF_HUB_OFFLINE=1 should be set in LOCAL_CONTAINER mode.""" + builder = _create_djl_builder(tmp_path, mode=Mode.LOCAL_CONTAINER) + # Local mode doesn't need GPU info mocks for instance_type validation + builder.instance_type = None builder._build_for_djl() - self.assertEqual(builder.env_vars.get("HF_HUB_OFFLINE"), "1") - - -if __name__ == "__main__": - unittest.main() + assert builder.env_vars["HF_HUB_OFFLINE"] == "1"