From 97298255a0aa6d604b7080b5b5c217bfaff934c2 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Tue, 26 May 2026 00:57:18 -0700 Subject: [PATCH 1/2] feat: bedrock-oss-provisioned-throughput-polling --- .../sagemaker/serve/bedrock_model_builder.py | 176 +++++++++- .../tests/unit/test_bedrock_model_builder.py | 331 +++++++++++++++++- 2 files changed, 500 insertions(+), 7 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index 38cbba09c2..663ed41579 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -118,12 +118,18 @@ def deploy( client_request_token: Optional[str] = None, imported_model_kms_key_id: Optional[str] = None, deployment_name: Optional[str] = None, + provisioned_model_name: Optional[str] = None, + model_units: int = 1, + commitment_duration: Optional[str] = None, + provisioned_model_tags: Optional[list] = None, ) -> Dict[str, Any]: """Deploy the model to Bedrock. Automatically detects if the model is a Nova model and uses the appropriate Bedrock API (create_custom_model for Nova, create_model_import_job for others). For Nova models, also creates a custom model deployment for inference. + For OSS models, creates a model import job, waits for completion, then creates + provisioned throughput and waits for it to become InService. Args: job_name: Name for the model import job (non-Nova models only). @@ -137,14 +143,25 @@ def deploy( imported_model_kms_key_id: KMS key ID for encryption (non-Nova models only). deployment_name: Name for the deployment (Nova models only). If not provided, defaults to custom_model_name suffixed with '-deployment'. + provisioned_model_name: Name for the provisioned throughput resource + (non-Nova models only). If not provided, defaults to + imported_model_name suffixed with '-throughput'. + model_units: Number of model units for provisioned throughput (non-Nova + models only). Defaults to 1. + commitment_duration: Commitment duration for provisioned throughput + (non-Nova models only). Valid values: 'OneMonth', 'SixMonths'. + If not provided, no commitment is set (on-demand). + provisioned_model_tags: Tags for the provisioned throughput resource + (non-Nova models only). Returns: Response from Bedrock API. For Nova models, returns the - create_custom_model_deployment response. For others, returns - the create_model_import_job response. + create_custom_model_deployment response. For OSS models, returns + the create_provisioned_model_throughput response. Raises: ValueError: If model_package is not set or required parameters are missing. + RuntimeError: If the import job or provisioned throughput fails or times out. """ if not self.model_package: raise ValueError( @@ -190,7 +207,26 @@ def deploy( params = {k: v for k, v in params.items() if v is not None} logger.info("Creating model import job for non-Nova deployment") - return self._get_bedrock_client().create_model_import_job(**params) + import_response = self._get_bedrock_client().create_model_import_job(**params) + + job_arn = import_response.get("jobArn") + self._wait_for_import_job_complete(job_arn) + + # Get the imported model ARN from the completed job + job_details = self._get_bedrock_client().get_model_import_job( + jobIdentifier=job_arn + ) + imported_model_arn = job_details.get("importedModelArn") + + # Create provisioned throughput + pt_name = provisioned_model_name or f"{imported_model_name}-throughput" + return self.create_provisioned_throughput( + model_id=imported_model_arn, + provisioned_model_name=pt_name, + model_units=model_units, + commitment_duration=commitment_duration, + tags=provisioned_model_tags, + ) def create_deployment( self, @@ -243,6 +279,140 @@ def create_deployment( return response + def create_provisioned_throughput( + self, + model_id: str, + provisioned_model_name: str, + model_units: int = 1, + commitment_duration: Optional[str] = None, + tags: Optional[list] = None, + poll_interval: int = 60, + max_wait: int = 3600, + ) -> Dict[str, Any]: + """Create provisioned throughput for an imported model on Bedrock. + + Calls CreateProvisionedModelThroughput and polls until the provisioned + throughput reaches InService status. + + Args: + model_id: ARN or ID of the imported model. + provisioned_model_name: Name for the provisioned throughput resource. + model_units: Number of model units to provision. Defaults to 1. + commitment_duration: Commitment duration. Valid values: 'OneMonth', + 'SixMonths'. If not provided, no commitment is set (on-demand). + tags: Tags for the provisioned throughput resource. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Returns: + Response from Bedrock create_provisioned_model_throughput API. + + Raises: + RuntimeError: If the provisioned throughput fails or times out. + ValueError: If model_id or provisioned_model_name is not provided. + """ + if not model_id: + raise ValueError("model_id is required for create_provisioned_throughput.") + if not provisioned_model_name: + raise ValueError( + "provisioned_model_name is required for create_provisioned_throughput." + ) + + params = { + "modelId": model_id, + "provisionedModelName": provisioned_model_name, + "modelUnits": model_units, + } + if commitment_duration: + params["commitmentDuration"] = commitment_duration + if tags: + params["tags"] = tags + + logger.info( + "Creating provisioned throughput '%s' for model %s with %d model units", + provisioned_model_name, + model_id, + model_units, + ) + response = self._get_bedrock_client().create_provisioned_model_throughput(**params) + + provisioned_model_arn = response.get("provisionedModelArn") + if provisioned_model_arn: + self._wait_for_provisioned_throughput_in_service( + provisioned_model_arn, poll_interval=poll_interval, max_wait=max_wait + ) + + return response + + def _wait_for_import_job_complete( + self, job_arn: str, poll_interval: int = 60, max_wait: int = 3600 + ): + """Poll Bedrock until the model import job reaches Completed status. + + Args: + job_arn: ARN of the model import job. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Raises: + RuntimeError: If the import job fails or times out. + """ + elapsed = 0 + status = None + while elapsed < max_wait: + resp = self._get_bedrock_client().get_model_import_job(jobIdentifier=job_arn) + status = resp.get("status") + logger.info("Import job status: %s (elapsed %ds)", status, elapsed) + if status == "Completed": + return + if status == "Failed": + failure_reason = resp.get("failureMessage", "Unknown") + raise RuntimeError( + f"Model import job {job_arn} failed. Reason: {failure_reason}" + ) + time.sleep(poll_interval) + elapsed += poll_interval + raise RuntimeError( + f"Timed out after {max_wait}s waiting for import job {job_arn} to complete. " + f"Last status: {status}" + ) + + def _wait_for_provisioned_throughput_in_service( + self, provisioned_model_arn: str, poll_interval: int = 60, max_wait: int = 3600 + ): + """Poll Bedrock until provisioned throughput reaches InService status. + + Args: + provisioned_model_arn: ARN of the provisioned model throughput. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Raises: + RuntimeError: If the provisioned throughput fails or times out. + """ + elapsed = 0 + status = None + while elapsed < max_wait: + resp = self._get_bedrock_client().get_provisioned_model_throughput( + provisionedModelId=provisioned_model_arn + ) + status = resp.get("status") + logger.info("Provisioned throughput status: %s (elapsed %ds)", status, elapsed) + if status == "InService": + return + if status == "Failed": + failure_reason = resp.get("failureMessage", "Unknown") + raise RuntimeError( + f"Provisioned throughput {provisioned_model_arn} failed. " + f"Reason: {failure_reason}" + ) + time.sleep(poll_interval) + elapsed += poll_interval + raise RuntimeError( + f"Timed out after {max_wait}s waiting for provisioned throughput " + f"{provisioned_model_arn} to become InService. Last status: {status}" + ) + def _wait_for_model_active( self, model_arn: str, poll_interval: int = 60, max_wait: int = 3600 ): diff --git a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py index 57a5c7abc9..5679b624d9 100644 --- a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py +++ b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py @@ -469,15 +469,120 @@ def test_timeout_raises(self): class TestDeploy: - def test_non_nova(self): + def test_non_nova_full_chain(self): + """Non-Nova deploy: import job → wait → get model ARN → create PT → wait PT.""" c = _make_container(s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) b.s3_model_artifacts = "s3://b/m.tar.gz" b._bedrock_client = Mock() b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} - result = b.deploy(job_name="j", imported_model_name="m", role_arn="r") - assert result == {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelArn": "arn:aws:bedrock:us-west-2:123:imported-model/m", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + result = b.deploy(job_name="j", imported_model_name="m", role_arn="r") + + b._bedrock_client.create_model_import_job.assert_called_once() + b._bedrock_client.get_model_import_job.assert_called() + b._bedrock_client.create_provisioned_model_throughput.assert_called_once() + b._bedrock_client.get_provisioned_model_throughput.assert_called_once() + assert result["provisionedModelArn"] == "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput" + + def test_non_nova_default_provisioned_model_name(self): + """Default provisioned model name is imported_model_name + '-throughput'.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelArn": "arn:model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy(job_name="j", imported_model_name="my-model", role_arn="r") + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["provisionedModelName"] == "my-model-throughput" + + def test_non_nova_custom_provisioned_model_name(self): + """User can override provisioned model name.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelArn": "arn:model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy( + job_name="j", + imported_model_name="m", + role_arn="r", + provisioned_model_name="custom-pt-name", + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["provisionedModelName"] == "custom-pt-name" + + def test_non_nova_with_model_units_and_commitment(self): + """User can specify model_units and commitment_duration.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelArn": "arn:model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy( + job_name="j", + imported_model_name="m", + role_arn="r", + model_units=3, + commitment_duration="SixMonths", + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["modelUnits"] == 3 + assert kw["commitmentDuration"] == "SixMonths" def test_nova_full_chain(self): c = _make_container(recipe_name="nova-micro", hub_content_name="nova") @@ -579,7 +684,225 @@ def test_non_nova_strips_none_params(self): b.s3_model_artifacts = "s3://b/k" b._bedrock_client = Mock() b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn"} - b.deploy(job_name="j", imported_model_name="m", role_arn="r") + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelArn": "arn:model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy(job_name="j", imported_model_name="m", role_arn="r") + kw = b._bedrock_client.create_model_import_job.call_args[1] assert "importedModelKmsKeyId" not in kw assert "clientRequestToken" not in kw + + +# ── _wait_for_import_job_complete ─────────────────────────────────────────── + + +class TestWaitForImportJobComplete: + def test_immediate_completed(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "Completed"} + b._wait_for_import_job_complete("arn:job") + b._bedrock_client.get_model_import_job.assert_called_once_with( + jobIdentifier="arn:job" + ) + + def test_polls_then_completed(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.side_effect = [ + {"status": "InProgress"}, + {"status": "InProgress"}, + {"status": "Completed"}, + ] + with patch(f"{MODULE}.time.sleep"): + b._wait_for_import_job_complete("arn:job", poll_interval=1, max_wait=10) + assert b._bedrock_client.get_model_import_job.call_count == 3 + + def test_failed_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = { + "status": "Failed", + "failureMessage": "Invalid model format", + } + with pytest.raises(RuntimeError, match="Invalid model format"): + b._wait_for_import_job_complete("arn:job") + + def test_failed_unknown_reason(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "Failed"} + with pytest.raises(RuntimeError, match="Unknown"): + b._wait_for_import_job_complete("arn:job") + + def test_timeout_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "InProgress"} + with patch(f"{MODULE}.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + b._wait_for_import_job_complete("arn:job", poll_interval=1, max_wait=2) + + +# ── create_provisioned_throughput ─────────────────────────────────────────── + + +class TestCreateProvisionedThroughput: + def test_creates_and_polls(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + result = b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="my-pt" + ) + + b._bedrock_client.create_provisioned_model_throughput.assert_called_once_with( + modelId="arn:model", + provisionedModelName="my-pt", + modelUnits=1, + ) + b._bedrock_client.get_provisioned_model_throughput.assert_called_once() + assert result["provisionedModelArn"] == "arn:pt" + + def test_passes_commitment_duration(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + b.create_provisioned_throughput( + model_id="arn:model", + provisioned_model_name="pt", + model_units=5, + commitment_duration="OneMonth", + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["modelUnits"] == 5 + assert kw["commitmentDuration"] == "OneMonth" + + def test_passes_tags(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + tags = [{"Key": "team", "Value": "ml"}] + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="pt", tags=tags + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["tags"] == tags + + def test_skips_polling_when_no_arn_in_response(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = {} + + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="pt" + ) + b._bedrock_client.get_provisioned_model_throughput.assert_not_called() + + def test_empty_model_id_raises(self): + b = _builder() + with pytest.raises(ValueError, match="model_id is required"): + b.create_provisioned_throughput(model_id="", provisioned_model_name="pt") + + def test_none_model_id_raises(self): + b = _builder() + with pytest.raises(ValueError, match="model_id is required"): + b.create_provisioned_throughput(model_id=None, provisioned_model_name="pt") + + def test_empty_provisioned_model_name_raises(self): + b = _builder() + with pytest.raises(ValueError, match="provisioned_model_name is required"): + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="" + ) + + +# ── _wait_for_provisioned_throughput_in_service ───────────────────────────── + + +class TestWaitForProvisionedThroughputInService: + def test_immediate_in_service(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + b._wait_for_provisioned_throughput_in_service("arn:pt") + b._bedrock_client.get_provisioned_model_throughput.assert_called_once_with( + provisionedModelId="arn:pt" + ) + + def test_polls_then_in_service(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.side_effect = [ + {"status": "Creating"}, + {"status": "Creating"}, + {"status": "InService"}, + ] + with patch(f"{MODULE}.time.sleep"): + b._wait_for_provisioned_throughput_in_service( + "arn:pt", poll_interval=1, max_wait=10 + ) + assert b._bedrock_client.get_provisioned_model_throughput.call_count == 3 + + def test_failed_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Failed", + "failureMessage": "Insufficient capacity", + } + with pytest.raises(RuntimeError, match="Insufficient capacity"): + b._wait_for_provisioned_throughput_in_service("arn:pt") + + def test_failed_unknown_reason(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Failed" + } + with pytest.raises(RuntimeError, match="Unknown"): + b._wait_for_provisioned_throughput_in_service("arn:pt") + + def test_timeout_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Creating" + } + with patch(f"{MODULE}.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + b._wait_for_provisioned_throughput_in_service( + "arn:pt", poll_interval=1, max_wait=2 + ) From 07f644340074d6260cc550cbe59942101e9d2208 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Tue, 26 May 2026 02:04:57 -0700 Subject: [PATCH 2/2] test: add integ test --- .../sagemaker/serve/bedrock_model_builder.py | 9 +- .../test_bedrock_provisioned_throughput.py | 268 ++++++++++++++++++ .../tests/unit/test_bedrock_model_builder.py | 15 +- 3 files changed, 283 insertions(+), 9 deletions(-) create mode 100644 sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index 663ed41579..535aab03b2 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -212,16 +212,19 @@ def deploy( job_arn = import_response.get("jobArn") self._wait_for_import_job_complete(job_arn) - # Get the imported model ARN from the completed job + # Get the imported model name from the completed job. + # We use importedModelName (not importedModelArn) because + # CreateProvisionedModelThroughput accepts model names but not + # the imported-model ARN format. job_details = self._get_bedrock_client().get_model_import_job( jobIdentifier=job_arn ) - imported_model_arn = job_details.get("importedModelArn") + imported_model_id = job_details.get("importedModelName") # Create provisioned throughput pt_name = provisioned_model_name or f"{imported_model_name}-throughput" return self.create_provisioned_throughput( - model_id=imported_model_arn, + model_id=imported_model_id, provisioned_model_name=pt_name, model_units=model_units, commitment_duration=commitment_duration, diff --git a/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py b/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py new file mode 100644 index 0000000000..74873f5a4e --- /dev/null +++ b/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py @@ -0,0 +1,268 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Integration tests for BedrockModelBuilder provisioned throughput polling.""" +from __future__ import absolute_import + +import json +import time +import random +import logging +from urllib.parse import urlparse + +import boto3 +import pytest + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.core.resources import TrainingJob +from sagemaker.serve.bedrock_model_builder import BedrockModelBuilder + +logger = logging.getLogger(__name__) + +AWS_REGION = "us-west-2" + + +@pytest.fixture(scope="module") +def training_job_name(): + """Training job name for testing (non-Nova, OSS model).""" + return "meta-textgeneration-llama-3-2-1b-instruct-sft-20251201172445" + + +@pytest.fixture(scope="module") +def role_arn(): + """IAM role ARN with Bedrock permissions.""" + return get_execution_role() + + +@pytest.fixture(scope="module") +def bedrock_client(): + """Create Bedrock client.""" + return boto3.client("bedrock", region_name=AWS_REGION) + + +@pytest.fixture(scope="module") +def s3_client(): + """Create S3 client.""" + return boto3.client("s3", region_name=AWS_REGION) + + +@pytest.fixture(scope="module") +def training_job(training_job_name): + """Get the training job.""" + return TrainingJob.get( + training_job_name=training_job_name, region=AWS_REGION + ) + + +def _setup_model_files(s3_artifacts_uri, s3_client): + """Setup required model files for Bedrock deployment. + + Bedrock model import requires HuggingFace-format files (config.json, + tokenizer.json, etc.) at the root of the S3 model artifacts path. + Training jobs often store these under checkpoints/hf_merged/, so we + copy them to the expected location. + + Args: + s3_artifacts_uri: The S3 URI that BedrockModelBuilder will use for import. + s3_client: boto3 S3 client. + """ + parsed = urlparse(s3_artifacts_uri) + bucket = parsed.netloc + base_prefix = parsed.path.lstrip("/").rstrip("/") + + # Copy files from checkpoints/hf_merged/ to root if they don't exist + hf_merged_prefix = f"{base_prefix}/checkpoints/hf_merged/" + root_prefix = f"{base_prefix}/" + + files_to_copy = [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "model.safetensors", + ] + + for file in files_to_copy: + try: + s3_client.head_object(Bucket=bucket, Key=root_prefix + file) + logger.info("File already exists: s3://%s/%s%s", bucket, root_prefix, file) + except Exception: + try: + s3_client.copy_object( + Bucket=bucket, + CopySource={"Bucket": bucket, "Key": hf_merged_prefix + file}, + Key=root_prefix + file, + ) + logger.info("Copied %s to root", file) + except Exception as e: + logger.warning("Could not copy %s: %s", file, e) + + # Create added_tokens.json if missing + try: + s3_client.head_object(Bucket=bucket, Key=root_prefix + "added_tokens.json") + except Exception: + try: + s3_client.put_object( + Bucket=bucket, + Key=root_prefix + "added_tokens.json", + Body=json.dumps({}), + ContentType="application/json", + ) + logger.info("Created added_tokens.json") + except Exception as e: + logger.warning("Could not create added_tokens.json: %s", e) + + +class TestBedrockProvisionedThroughputPolling: + """Test provisioned throughput creation and polling for OSS models.""" + + @pytest.fixture(autouse=True) + def _setup(self, bedrock_client): + """Store bedrock client and track resources for cleanup.""" + self._bedrock_client = bedrock_client + self._provisioned_model_arn = None + self._imported_model_arn = None + self._job_name = None + yield + # Cleanup after each test + self._cleanup() + + def _cleanup(self): + """Clean up Bedrock resources created during the test.""" + if self._provisioned_model_arn: + try: + logger.info("Deleting provisioned throughput: %s", self._provisioned_model_arn) + self._bedrock_client.delete_provisioned_model_throughput( + provisionedModelId=self._provisioned_model_arn + ) + except Exception as e: + logger.warning("Failed to delete provisioned throughput: %s", e) + + if self._imported_model_arn: + # Wait for PT deletion to propagate + time.sleep(5) + try: + logger.info("Deleting imported model: %s", self._imported_model_arn) + self._bedrock_client.delete_imported_model( + modelIdentifier=self._imported_model_arn + ) + except Exception as e: + logger.warning("Failed to delete imported model: %s", e) + + @pytest.mark.slow + def test_deploy_oss_model_with_provisioned_throughput( + self, training_job, role_arn, bedrock_client, s3_client + ): + """Test full deploy flow: import job → wait → create PT → wait → InService. + + This test verifies that BedrockModelBuilder.deploy() for non-Nova models: + 1. Creates a model import job + 2. Waits for the import job to complete + 3. Creates provisioned model throughput + 4. Waits for provisioned throughput to become InService + 5. Returns the provisioned throughput response + """ + # Ensure model files are in the expected HuggingFace format + builder = BedrockModelBuilder(model=training_job) + assert builder.s3_model_artifacts is not None, ( + "BedrockModelBuilder could not resolve S3 model artifacts" + ) + _setup_model_files(builder.s3_model_artifacts, s3_client) + + suffix = f"{int(time.time())}-{random.randint(1000, 9999)}" + job_name = f"test-pt-poll-{suffix}" + imported_model_name = f"test-pt-model-{suffix}" + provisioned_model_name = f"test-pt-{suffix}" + + self._job_name = job_name + + # Deploy with provisioned throughput + result = builder.deploy( + job_name=job_name, + imported_model_name=imported_model_name, + role_arn=role_arn, + provisioned_model_name=provisioned_model_name, + model_units=1, + ) + + # Verify result contains provisioned model ARN + assert "provisionedModelArn" in result, ( + f"Expected 'provisionedModelArn' in result, got keys: {list(result.keys())}" + ) + self._provisioned_model_arn = result["provisionedModelArn"] + + # Verify provisioned throughput is InService + pt_response = bedrock_client.get_provisioned_model_throughput( + provisionedModelId=self._provisioned_model_arn + ) + assert pt_response["status"] == "InService", ( + f"Expected InService, got {pt_response['status']}" + ) + + # Get imported model ARN for cleanup + try: + job_resp = bedrock_client.get_model_import_job(jobIdentifier=job_name) + self._imported_model_arn = job_resp.get("importedModelArn") + except Exception: + pass + + @pytest.mark.slow + def test_create_provisioned_throughput_standalone( + self, training_job, role_arn, bedrock_client, s3_client + ): + """Test create_provisioned_throughput as a standalone method. + + This tests the public create_provisioned_throughput() method directly, + using a model that has already been imported. + """ + # Ensure model files are in the expected HuggingFace format + builder = BedrockModelBuilder(model=training_job) + _setup_model_files(builder.s3_model_artifacts, s3_client) + + suffix = f"{int(time.time())}-{random.randint(1000, 9999)}" + job_name = f"test-pt-standalone-{suffix}" + imported_model_name = f"test-pt-standalone-model-{suffix}" + + builder = BedrockModelBuilder(model=training_job) + + # Step 1: Create import job manually + import_response = bedrock_client.create_model_import_job( + jobName=job_name, + importedModelName=imported_model_name, + roleArn=role_arn, + modelDataSource={"s3DataSource": {"s3Uri": builder.s3_model_artifacts}}, + ) + job_arn = import_response["jobArn"] + + # Step 2: Wait for import job (using our method) + builder._wait_for_import_job_complete(job_arn) + + # Step 3: Get imported model ARN + job_details = bedrock_client.get_model_import_job(jobIdentifier=job_arn) + self._imported_model_arn = job_details["importedModelArn"] + assert job_details["status"] == "Completed" + + # Step 4: Create provisioned throughput (using our method) + pt_name = f"test-pt-standalone-{suffix}" + result = builder.create_provisioned_throughput( + model_id=self._imported_model_arn, + provisioned_model_name=pt_name, + model_units=1, + ) + + assert "provisionedModelArn" in result + self._provisioned_model_arn = result["provisionedModelArn"] + + # Verify InService + pt_response = bedrock_client.get_provisioned_model_throughput( + provisionedModelId=self._provisioned_model_arn + ) + assert pt_response["status"] == "InService" diff --git a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py index 5679b624d9..47a86aa087 100644 --- a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py +++ b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py @@ -470,7 +470,7 @@ def test_timeout_raises(self): class TestDeploy: def test_non_nova_full_chain(self): - """Non-Nova deploy: import job → wait → get model ARN → create PT → wait PT.""" + """Non-Nova deploy: import job → wait → get model name → create PT → wait PT.""" c = _make_container(s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) @@ -479,7 +479,7 @@ def test_non_nova_full_chain(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelArn": "arn:aws:bedrock:us-west-2:123:imported-model/m", + "importedModelName": "my-imported-model", } b._bedrock_client.create_provisioned_model_throughput.return_value = { "provisionedModelArn": "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput", @@ -495,6 +495,9 @@ def test_non_nova_full_chain(self): b._bedrock_client.get_model_import_job.assert_called() b._bedrock_client.create_provisioned_model_throughput.assert_called_once() b._bedrock_client.get_provisioned_model_throughput.assert_called_once() + # Verify model_id passed to create_provisioned_model_throughput is the model name + pt_call_kwargs = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert pt_call_kwargs["modelId"] == "my-imported-model" assert result["provisionedModelArn"] == "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput" def test_non_nova_default_provisioned_model_name(self): @@ -507,7 +510,7 @@ def test_non_nova_default_provisioned_model_name(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelArn": "arn:model", + "importedModelName": "my-imported-model", } b._bedrock_client.create_provisioned_model_throughput.return_value = { "provisionedModelArn": "arn:pt", @@ -532,7 +535,7 @@ def test_non_nova_custom_provisioned_model_name(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelArn": "arn:model", + "importedModelName": "my-imported-model", } b._bedrock_client.create_provisioned_model_throughput.return_value = { "provisionedModelArn": "arn:pt", @@ -562,7 +565,7 @@ def test_non_nova_with_model_units_and_commitment(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelArn": "arn:model", + "importedModelName": "my-imported-model", } b._bedrock_client.create_provisioned_model_throughput.return_value = { "provisionedModelArn": "arn:pt", @@ -686,7 +689,7 @@ def test_non_nova_strips_none_params(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelArn": "arn:model", + "importedModelName": "my-imported-model", } b._bedrock_client.create_provisioned_model_throughput.return_value = { "provisionedModelArn": "arn:pt",