diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index 38cbba09c2..535aab03b2 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -118,12 +118,18 @@ def deploy( client_request_token: Optional[str] = None, imported_model_kms_key_id: Optional[str] = None, deployment_name: Optional[str] = None, + provisioned_model_name: Optional[str] = None, + model_units: int = 1, + commitment_duration: Optional[str] = None, + provisioned_model_tags: Optional[list] = None, ) -> Dict[str, Any]: """Deploy the model to Bedrock. Automatically detects if the model is a Nova model and uses the appropriate Bedrock API (create_custom_model for Nova, create_model_import_job for others). For Nova models, also creates a custom model deployment for inference. + For OSS models, creates a model import job, waits for completion, then creates + provisioned throughput and waits for it to become InService. Args: job_name: Name for the model import job (non-Nova models only). @@ -137,14 +143,25 @@ def deploy( imported_model_kms_key_id: KMS key ID for encryption (non-Nova models only). deployment_name: Name for the deployment (Nova models only). If not provided, defaults to custom_model_name suffixed with '-deployment'. + provisioned_model_name: Name for the provisioned throughput resource + (non-Nova models only). If not provided, defaults to + imported_model_name suffixed with '-throughput'. + model_units: Number of model units for provisioned throughput (non-Nova + models only). Defaults to 1. + commitment_duration: Commitment duration for provisioned throughput + (non-Nova models only). Valid values: 'OneMonth', 'SixMonths'. + If not provided, no commitment is set (on-demand). + provisioned_model_tags: Tags for the provisioned throughput resource + (non-Nova models only). Returns: Response from Bedrock API. For Nova models, returns the - create_custom_model_deployment response. For others, returns - the create_model_import_job response. + create_custom_model_deployment response. For OSS models, returns + the create_provisioned_model_throughput response. Raises: ValueError: If model_package is not set or required parameters are missing. + RuntimeError: If the import job or provisioned throughput fails or times out. """ if not self.model_package: raise ValueError( @@ -190,7 +207,29 @@ def deploy( params = {k: v for k, v in params.items() if v is not None} logger.info("Creating model import job for non-Nova deployment") - return self._get_bedrock_client().create_model_import_job(**params) + import_response = self._get_bedrock_client().create_model_import_job(**params) + + job_arn = import_response.get("jobArn") + self._wait_for_import_job_complete(job_arn) + + # Get the imported model name from the completed job. + # We use importedModelName (not importedModelArn) because + # CreateProvisionedModelThroughput accepts model names but not + # the imported-model ARN format. + job_details = self._get_bedrock_client().get_model_import_job( + jobIdentifier=job_arn + ) + imported_model_id = job_details.get("importedModelName") + + # Create provisioned throughput + pt_name = provisioned_model_name or f"{imported_model_name}-throughput" + return self.create_provisioned_throughput( + model_id=imported_model_id, + provisioned_model_name=pt_name, + model_units=model_units, + commitment_duration=commitment_duration, + tags=provisioned_model_tags, + ) def create_deployment( self, @@ -243,6 +282,140 @@ def create_deployment( return response + def create_provisioned_throughput( + self, + model_id: str, + provisioned_model_name: str, + model_units: int = 1, + commitment_duration: Optional[str] = None, + tags: Optional[list] = None, + poll_interval: int = 60, + max_wait: int = 3600, + ) -> Dict[str, Any]: + """Create provisioned throughput for an imported model on Bedrock. + + Calls CreateProvisionedModelThroughput and polls until the provisioned + throughput reaches InService status. + + Args: + model_id: ARN or ID of the imported model. + provisioned_model_name: Name for the provisioned throughput resource. + model_units: Number of model units to provision. Defaults to 1. + commitment_duration: Commitment duration. Valid values: 'OneMonth', + 'SixMonths'. If not provided, no commitment is set (on-demand). + tags: Tags for the provisioned throughput resource. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Returns: + Response from Bedrock create_provisioned_model_throughput API. + + Raises: + RuntimeError: If the provisioned throughput fails or times out. + ValueError: If model_id or provisioned_model_name is not provided. + """ + if not model_id: + raise ValueError("model_id is required for create_provisioned_throughput.") + if not provisioned_model_name: + raise ValueError( + "provisioned_model_name is required for create_provisioned_throughput." + ) + + params = { + "modelId": model_id, + "provisionedModelName": provisioned_model_name, + "modelUnits": model_units, + } + if commitment_duration: + params["commitmentDuration"] = commitment_duration + if tags: + params["tags"] = tags + + logger.info( + "Creating provisioned throughput '%s' for model %s with %d model units", + provisioned_model_name, + model_id, + model_units, + ) + response = self._get_bedrock_client().create_provisioned_model_throughput(**params) + + provisioned_model_arn = response.get("provisionedModelArn") + if provisioned_model_arn: + self._wait_for_provisioned_throughput_in_service( + provisioned_model_arn, poll_interval=poll_interval, max_wait=max_wait + ) + + return response + + def _wait_for_import_job_complete( + self, job_arn: str, poll_interval: int = 60, max_wait: int = 3600 + ): + """Poll Bedrock until the model import job reaches Completed status. + + Args: + job_arn: ARN of the model import job. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Raises: + RuntimeError: If the import job fails or times out. + """ + elapsed = 0 + status = None + while elapsed < max_wait: + resp = self._get_bedrock_client().get_model_import_job(jobIdentifier=job_arn) + status = resp.get("status") + logger.info("Import job status: %s (elapsed %ds)", status, elapsed) + if status == "Completed": + return + if status == "Failed": + failure_reason = resp.get("failureMessage", "Unknown") + raise RuntimeError( + f"Model import job {job_arn} failed. Reason: {failure_reason}" + ) + time.sleep(poll_interval) + elapsed += poll_interval + raise RuntimeError( + f"Timed out after {max_wait}s waiting for import job {job_arn} to complete. " + f"Last status: {status}" + ) + + def _wait_for_provisioned_throughput_in_service( + self, provisioned_model_arn: str, poll_interval: int = 60, max_wait: int = 3600 + ): + """Poll Bedrock until provisioned throughput reaches InService status. + + Args: + provisioned_model_arn: ARN of the provisioned model throughput. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Raises: + RuntimeError: If the provisioned throughput fails or times out. + """ + elapsed = 0 + status = None + while elapsed < max_wait: + resp = self._get_bedrock_client().get_provisioned_model_throughput( + provisionedModelId=provisioned_model_arn + ) + status = resp.get("status") + logger.info("Provisioned throughput status: %s (elapsed %ds)", status, elapsed) + if status == "InService": + return + if status == "Failed": + failure_reason = resp.get("failureMessage", "Unknown") + raise RuntimeError( + f"Provisioned throughput {provisioned_model_arn} failed. " + f"Reason: {failure_reason}" + ) + time.sleep(poll_interval) + elapsed += poll_interval + raise RuntimeError( + f"Timed out after {max_wait}s waiting for provisioned throughput " + f"{provisioned_model_arn} to become InService. Last status: {status}" + ) + def _wait_for_model_active( self, model_arn: str, poll_interval: int = 60, max_wait: int = 3600 ): diff --git a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py index 57a5c7abc9..47a86aa087 100644 --- a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py +++ b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py @@ -469,15 +469,123 @@ def test_timeout_raises(self): class TestDeploy: - def test_non_nova(self): + def test_non_nova_full_chain(self): + """Non-Nova deploy: import job → wait → get model name → create PT → wait PT.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "my-imported-model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + result = b.deploy(job_name="j", imported_model_name="m", role_arn="r") + + b._bedrock_client.create_model_import_job.assert_called_once() + b._bedrock_client.get_model_import_job.assert_called() + b._bedrock_client.create_provisioned_model_throughput.assert_called_once() + b._bedrock_client.get_provisioned_model_throughput.assert_called_once() + # Verify model_id passed to create_provisioned_model_throughput is the model name + pt_call_kwargs = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert pt_call_kwargs["modelId"] == "my-imported-model" + assert result["provisionedModelArn"] == "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput" + + def test_non_nova_default_provisioned_model_name(self): + """Default provisioned model name is imported_model_name + '-throughput'.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "my-imported-model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy(job_name="j", imported_model_name="my-model", role_arn="r") + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["provisionedModelName"] == "my-model-throughput" + + def test_non_nova_custom_provisioned_model_name(self): + """User can override provisioned model name.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "my-imported-model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy( + job_name="j", + imported_model_name="m", + role_arn="r", + provisioned_model_name="custom-pt-name", + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["provisionedModelName"] == "custom-pt-name" + + def test_non_nova_with_model_units_and_commitment(self): + """User can specify model_units and commitment_duration.""" c = _make_container(s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) b.s3_model_artifacts = "s3://b/m.tar.gz" b._bedrock_client = Mock() b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} - result = b.deploy(job_name="j", imported_model_name="m", role_arn="r") - assert result == {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "my-imported-model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy( + job_name="j", + imported_model_name="m", + role_arn="r", + model_units=3, + commitment_duration="SixMonths", + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["modelUnits"] == 3 + assert kw["commitmentDuration"] == "SixMonths" def test_nova_full_chain(self): c = _make_container(recipe_name="nova-micro", hub_content_name="nova") @@ -579,7 +687,225 @@ def test_non_nova_strips_none_params(self): b.s3_model_artifacts = "s3://b/k" b._bedrock_client = Mock() b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn"} - b.deploy(job_name="j", imported_model_name="m", role_arn="r") + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "my-imported-model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy(job_name="j", imported_model_name="m", role_arn="r") + kw = b._bedrock_client.create_model_import_job.call_args[1] assert "importedModelKmsKeyId" not in kw assert "clientRequestToken" not in kw + + +# ── _wait_for_import_job_complete ─────────────────────────────────────────── + + +class TestWaitForImportJobComplete: + def test_immediate_completed(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "Completed"} + b._wait_for_import_job_complete("arn:job") + b._bedrock_client.get_model_import_job.assert_called_once_with( + jobIdentifier="arn:job" + ) + + def test_polls_then_completed(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.side_effect = [ + {"status": "InProgress"}, + {"status": "InProgress"}, + {"status": "Completed"}, + ] + with patch(f"{MODULE}.time.sleep"): + b._wait_for_import_job_complete("arn:job", poll_interval=1, max_wait=10) + assert b._bedrock_client.get_model_import_job.call_count == 3 + + def test_failed_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = { + "status": "Failed", + "failureMessage": "Invalid model format", + } + with pytest.raises(RuntimeError, match="Invalid model format"): + b._wait_for_import_job_complete("arn:job") + + def test_failed_unknown_reason(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "Failed"} + with pytest.raises(RuntimeError, match="Unknown"): + b._wait_for_import_job_complete("arn:job") + + def test_timeout_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "InProgress"} + with patch(f"{MODULE}.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + b._wait_for_import_job_complete("arn:job", poll_interval=1, max_wait=2) + + +# ── create_provisioned_throughput ─────────────────────────────────────────── + + +class TestCreateProvisionedThroughput: + def test_creates_and_polls(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + result = b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="my-pt" + ) + + b._bedrock_client.create_provisioned_model_throughput.assert_called_once_with( + modelId="arn:model", + provisionedModelName="my-pt", + modelUnits=1, + ) + b._bedrock_client.get_provisioned_model_throughput.assert_called_once() + assert result["provisionedModelArn"] == "arn:pt" + + def test_passes_commitment_duration(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + b.create_provisioned_throughput( + model_id="arn:model", + provisioned_model_name="pt", + model_units=5, + commitment_duration="OneMonth", + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["modelUnits"] == 5 + assert kw["commitmentDuration"] == "OneMonth" + + def test_passes_tags(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + tags = [{"Key": "team", "Value": "ml"}] + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="pt", tags=tags + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["tags"] == tags + + def test_skips_polling_when_no_arn_in_response(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = {} + + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="pt" + ) + b._bedrock_client.get_provisioned_model_throughput.assert_not_called() + + def test_empty_model_id_raises(self): + b = _builder() + with pytest.raises(ValueError, match="model_id is required"): + b.create_provisioned_throughput(model_id="", provisioned_model_name="pt") + + def test_none_model_id_raises(self): + b = _builder() + with pytest.raises(ValueError, match="model_id is required"): + b.create_provisioned_throughput(model_id=None, provisioned_model_name="pt") + + def test_empty_provisioned_model_name_raises(self): + b = _builder() + with pytest.raises(ValueError, match="provisioned_model_name is required"): + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="" + ) + + +# ── _wait_for_provisioned_throughput_in_service ───────────────────────────── + + +class TestWaitForProvisionedThroughputInService: + def test_immediate_in_service(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + b._wait_for_provisioned_throughput_in_service("arn:pt") + b._bedrock_client.get_provisioned_model_throughput.assert_called_once_with( + provisionedModelId="arn:pt" + ) + + def test_polls_then_in_service(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.side_effect = [ + {"status": "Creating"}, + {"status": "Creating"}, + {"status": "InService"}, + ] + with patch(f"{MODULE}.time.sleep"): + b._wait_for_provisioned_throughput_in_service( + "arn:pt", poll_interval=1, max_wait=10 + ) + assert b._bedrock_client.get_provisioned_model_throughput.call_count == 3 + + def test_failed_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Failed", + "failureMessage": "Insufficient capacity", + } + with pytest.raises(RuntimeError, match="Insufficient capacity"): + b._wait_for_provisioned_throughput_in_service("arn:pt") + + def test_failed_unknown_reason(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Failed" + } + with pytest.raises(RuntimeError, match="Unknown"): + b._wait_for_provisioned_throughput_in_service("arn:pt") + + def test_timeout_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Creating" + } + with patch(f"{MODULE}.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + b._wait_for_provisioned_throughput_in_service( + "arn:pt", poll_interval=1, max_wait=2 + )