diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index f9efe42a18..20475ad678 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -54,7 +54,7 @@ def build(self): # SageMaker core imports from sagemaker.core.helper.session_helper import Session -from sagemaker.core.utils.utils import logger +from sagemaker.core.utils.utils import logger, Unassigned from sagemaker.train import ModelTrainer @@ -137,6 +137,98 @@ def build(self): from sagemaker.serve.model_server.triton.config_template import CONFIG_TEMPLATE SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources" + + +def resolve_base_model_fields(base_model): + """Resolve missing BaseModel fields (hub_content_version, recipe_name). + + When a ModelPackage's BaseModel has hub_content_name set but is missing + hub_content_version and/or recipe_name (returned as Unassigned from the + DescribeModelPackage API), this function attempts to resolve them + automatically by querying SageMakerPublicHub. + + Args: + base_model: A BaseModel object with hub_content_name, hub_content_version, + and recipe_name attributes. + + Returns: + The mutated base_model with resolved fields where possible. + """ + if base_model is None: + return base_model + + # Check if hub_content_name is present and valid + hub_content_name = getattr(base_model, "hub_content_name", None) + if hub_content_name is None or isinstance(hub_content_name, Unassigned): + return base_model + + if not hub_content_name or not str(hub_content_name).strip(): + return base_model + + hub_content_version = getattr(base_model, "hub_content_version", None) + recipe_name = getattr(base_model, "recipe_name", None) + + version_missing = ( + hub_content_version is None + or isinstance(hub_content_version, Unassigned) + or not str(hub_content_version).strip() + ) + recipe_missing = ( + recipe_name is None + or isinstance(recipe_name, Unassigned) + or not str(recipe_name).strip() + ) + + if not version_missing and not recipe_missing: + return base_model + + # Attempt to resolve from SageMakerPublicHub + if version_missing: + try: + from sagemaker.core.resources import HubContent + + logger.info( + "Resolving missing hub_content_version for hub_content_name='%s' " + "from SageMakerPublicHub...", + hub_content_name, + ) + hc = HubContent.get( + hub_content_type="Model", + hub_name="SageMakerPublicHub", + hub_content_name=str(hub_content_name), + ) + if hasattr(hc, "hub_content_version") and not isinstance( + hc.hub_content_version, Unassigned + ): + base_model.hub_content_version = hc.hub_content_version + logger.info( + "Resolved hub_content_version='%s' for hub_content_name='%s'", + base_model.hub_content_version, + hub_content_name, + ) + else: + logger.warning( + "Could not resolve hub_content_version for hub_content_name='%s'. " + "The HubContent response did not contain a valid version.", + hub_content_name, + ) + except Exception as e: + logger.warning( + "Failed to resolve hub_content_version for hub_content_name='%s' " + "from SageMakerPublicHub. You may need to set it manually. Error: %s", + hub_content_name, + e, + ) + + if recipe_missing: + logger.warning( + "recipe_name is missing (Unassigned) for hub_content_name='%s'. " + "ModelBuilder will proceed without it. If a recipe is required, " + "please set base_model.recipe_name manually before calling build().", + hub_content_name, + ) + + return base_model _DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py" _NO_JS_MODEL_EX = "HuggingFace JumpStart Model ID not detected. Building for HuggingFace Model ID." _JS_SCOPE = "inference" diff --git a/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py b/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py new file mode 100644 index 0000000000..c3307272ca --- /dev/null +++ b/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py @@ -0,0 +1,208 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for resolve_base_model_fields utility function.""" +from __future__ import absolute_import + +import pytest +from unittest.mock import patch, MagicMock + +from sagemaker.core.utils.utils import Unassigned +from sagemaker.serve.model_builder_utils import resolve_base_model_fields + + +class FakeBaseModel: + """Fake BaseModel for testing.""" + + def __init__(self, hub_content_name=None, hub_content_version=None, recipe_name=None): + self.hub_content_name = hub_content_name + self.hub_content_version = hub_content_version + self.recipe_name = recipe_name + + +class FakeHubContent: + """Fake HubContent response.""" + + def __init__(self, hub_content_version=None): + self.hub_content_version = hub_content_version + + +class TestResolveBaseModelFields: + """Tests for resolve_base_model_fields.""" + + def test_resolve_with_none_base_model(self): + """Test that None base_model is returned unchanged.""" + result = resolve_base_model_fields(None) + assert result is None + + def test_resolve_with_no_hub_content_name_returns_unchanged(self): + """Test that base_model without hub_content_name is returned unchanged.""" + base_model = FakeBaseModel( + hub_content_name=Unassigned(), + hub_content_version=Unassigned(), + recipe_name=Unassigned(), + ) + result = resolve_base_model_fields(base_model) + assert isinstance(result.hub_content_version, Unassigned) + assert isinstance(result.recipe_name, Unassigned) + + def test_resolve_with_none_hub_content_name_returns_unchanged(self): + """Test that base_model with None hub_content_name is returned unchanged.""" + base_model = FakeBaseModel( + hub_content_name=None, + hub_content_version=Unassigned(), + recipe_name=Unassigned(), + ) + result = resolve_base_model_fields(base_model) + assert isinstance(result.hub_content_version, Unassigned) + + def test_resolve_with_empty_hub_content_name_returns_unchanged(self): + """Test that base_model with empty hub_content_name is returned unchanged.""" + base_model = FakeBaseModel( + hub_content_name="", + hub_content_version=Unassigned(), + recipe_name=Unassigned(), + ) + result = resolve_base_model_fields(base_model) + assert isinstance(result.hub_content_version, Unassigned) + + def test_resolve_with_all_fields_present_no_api_call(self): + """Test that no API call is made when all fields are already present.""" + base_model = FakeBaseModel( + hub_content_name="huggingface-model-abc", + hub_content_version="1.0.0", + recipe_name="my-recipe", + ) + with patch("sagemaker.serve.model_builder_utils.HubContent", autospec=True) as mock_hc: + # HubContent should NOT be imported/called + result = resolve_base_model_fields(base_model) + assert result.hub_content_version == "1.0.0" + assert result.recipe_name == "my-recipe" + + @patch("sagemaker.core.resources.HubContent") + def test_resolve_missing_hub_content_version_resolves_from_hub(self, mock_hub_content_cls): + """Test that missing hub_content_version is resolved from SageMakerPublicHub.""" + fake_hc = FakeHubContent(hub_content_version="2.5.0") + mock_hub_content_cls.get.return_value = fake_hc + + base_model = FakeBaseModel( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=Unassigned(), + recipe_name="some-recipe", + ) + + with patch( + "sagemaker.serve.model_builder_utils.HubContent", mock_hub_content_cls + ): + result = resolve_base_model_fields(base_model) + + assert result.hub_content_version == "2.5.0" + mock_hub_content_cls.get.assert_called_once_with( + hub_content_type="Model", + hub_name="SageMakerPublicHub", + hub_content_name="huggingface-reasoning-qwen3-32b", + ) + + @patch("sagemaker.core.resources.HubContent") + def test_resolve_missing_recipe_name_logs_warning(self, mock_hub_content_cls): + """Test that missing recipe_name logs a warning but does not crash.""" + base_model = FakeBaseModel( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version="1.0.0", + recipe_name=Unassigned(), + ) + + result = resolve_base_model_fields(base_model) + # recipe_name should still be Unassigned (not resolved automatically) + assert isinstance(result.recipe_name, Unassigned) + # But the function should not crash + assert result.hub_content_version == "1.0.0" + + @patch("sagemaker.core.resources.HubContent") + def test_resolve_hub_content_not_found_does_not_crash(self, mock_hub_content_cls): + """Test that HubContent.get() failure is handled gracefully.""" + mock_hub_content_cls.get.side_effect = Exception("HubContent not found") + + base_model = FakeBaseModel( + hub_content_name="nonexistent-model", + hub_content_version=Unassigned(), + recipe_name="some-recipe", + ) + + with patch( + "sagemaker.serve.model_builder_utils.HubContent", mock_hub_content_cls + ): + # Should not raise, just log a warning + result = resolve_base_model_fields(base_model) + + # hub_content_version should still be Unassigned since resolution failed + assert isinstance(result.hub_content_version, Unassigned) + + @patch("sagemaker.core.resources.HubContent") + def test_resolve_both_version_and_recipe_missing(self, mock_hub_content_cls): + """Test resolution when both hub_content_version and recipe_name are missing.""" + fake_hc = FakeHubContent(hub_content_version="3.0.0") + mock_hub_content_cls.get.return_value = fake_hc + + base_model = FakeBaseModel( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=Unassigned(), + recipe_name=Unassigned(), + ) + + with patch( + "sagemaker.serve.model_builder_utils.HubContent", mock_hub_content_cls + ): + result = resolve_base_model_fields(base_model) + + # Version should be resolved + assert result.hub_content_version == "3.0.0" + # Recipe should still be Unassigned (with warning logged) + assert isinstance(result.recipe_name, Unassigned) + + @patch("sagemaker.core.resources.HubContent") + def test_resolve_with_none_version_resolves(self, mock_hub_content_cls): + """Test that None hub_content_version (not just Unassigned) is also resolved.""" + fake_hc = FakeHubContent(hub_content_version="1.2.3") + mock_hub_content_cls.get.return_value = fake_hc + + base_model = FakeBaseModel( + hub_content_name="huggingface-model-xyz", + hub_content_version=None, + recipe_name="my-recipe", + ) + + with patch( + "sagemaker.serve.model_builder_utils.HubContent", mock_hub_content_cls + ): + result = resolve_base_model_fields(base_model) + + assert result.hub_content_version == "1.2.3" + + @patch("sagemaker.core.resources.HubContent") + def test_resolve_with_empty_string_version_resolves(self, mock_hub_content_cls): + """Test that empty string hub_content_version is also resolved.""" + fake_hc = FakeHubContent(hub_content_version="4.0.0") + mock_hub_content_cls.get.return_value = fake_hc + + base_model = FakeBaseModel( + hub_content_name="huggingface-model-xyz", + hub_content_version="", + recipe_name="my-recipe", + ) + + with patch( + "sagemaker.serve.model_builder_utils.HubContent", mock_hub_content_cls + ): + result = resolve_base_model_fields(base_model) + + assert result.hub_content_version == "4.0.0"