Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 176 additions & 3 deletions sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,18 @@ def deploy(
client_request_token: Optional[str] = None,
imported_model_kms_key_id: Optional[str] = None,
deployment_name: Optional[str] = None,
provisioned_model_name: Optional[str] = None,
model_units: int = 1,
commitment_duration: Optional[str] = None,
provisioned_model_tags: Optional[list] = None,
) -> Dict[str, Any]:
"""Deploy the model to Bedrock.

Automatically detects if the model is a Nova model and uses the appropriate
Bedrock API (create_custom_model for Nova, create_model_import_job for others).
For Nova models, also creates a custom model deployment for inference.
For OSS models, creates a model import job, waits for completion, then creates
provisioned throughput and waits for it to become InService.

Args:
job_name: Name for the model import job (non-Nova models only).
Expand All @@ -137,14 +143,25 @@ def deploy(
imported_model_kms_key_id: KMS key ID for encryption (non-Nova models only).
deployment_name: Name for the deployment (Nova models only). If not provided,
defaults to custom_model_name suffixed with '-deployment'.
provisioned_model_name: Name for the provisioned throughput resource
(non-Nova models only). If not provided, defaults to
imported_model_name suffixed with '-throughput'.
model_units: Number of model units for provisioned throughput (non-Nova
models only). Defaults to 1.
commitment_duration: Commitment duration for provisioned throughput
(non-Nova models only). Valid values: 'OneMonth', 'SixMonths'.
If not provided, no commitment is set (on-demand).
provisioned_model_tags: Tags for the provisioned throughput resource
(non-Nova models only).

Returns:
Response from Bedrock API. For Nova models, returns the
create_custom_model_deployment response. For others, returns
the create_model_import_job response.
create_custom_model_deployment response. For OSS models, returns
the create_provisioned_model_throughput response.

Raises:
ValueError: If model_package is not set or required parameters are missing.
RuntimeError: If the import job or provisioned throughput fails or times out.
"""
if not self.model_package:
raise ValueError(
Expand Down Expand Up @@ -190,7 +207,29 @@ def deploy(
params = {k: v for k, v in params.items() if v is not None}

logger.info("Creating model import job for non-Nova deployment")
return self._get_bedrock_client().create_model_import_job(**params)
import_response = self._get_bedrock_client().create_model_import_job(**params)

job_arn = import_response.get("jobArn")
self._wait_for_import_job_complete(job_arn)

# Get the imported model name from the completed job.
# We use importedModelName (not importedModelArn) because
# CreateProvisionedModelThroughput accepts model names but not
# the imported-model ARN format.
job_details = self._get_bedrock_client().get_model_import_job(
jobIdentifier=job_arn
)
imported_model_id = job_details.get("importedModelName")

# Create provisioned throughput
pt_name = provisioned_model_name or f"{imported_model_name}-throughput"
return self.create_provisioned_throughput(
model_id=imported_model_id,
provisioned_model_name=pt_name,
model_units=model_units,
commitment_duration=commitment_duration,
tags=provisioned_model_tags,
)

def create_deployment(
self,
Expand Down Expand Up @@ -243,6 +282,140 @@ def create_deployment(

return response

def create_provisioned_throughput(
self,
model_id: str,
provisioned_model_name: str,
model_units: int = 1,
commitment_duration: Optional[str] = None,
tags: Optional[list] = None,
poll_interval: int = 60,
max_wait: int = 3600,
) -> Dict[str, Any]:
"""Create provisioned throughput for an imported model on Bedrock.

Calls CreateProvisionedModelThroughput and polls until the provisioned
throughput reaches InService status.

Args:
model_id: ARN or ID of the imported model.
provisioned_model_name: Name for the provisioned throughput resource.
model_units: Number of model units to provision. Defaults to 1.
commitment_duration: Commitment duration. Valid values: 'OneMonth',
'SixMonths'. If not provided, no commitment is set (on-demand).
tags: Tags for the provisioned throughput resource.
poll_interval: Seconds between status checks. Defaults to 60.
max_wait: Maximum seconds to wait. Defaults to 3600.

Returns:
Response from Bedrock create_provisioned_model_throughput API.

Raises:
RuntimeError: If the provisioned throughput fails or times out.
ValueError: If model_id or provisioned_model_name is not provided.
"""
if not model_id:
raise ValueError("model_id is required for create_provisioned_throughput.")
if not provisioned_model_name:
raise ValueError(
"provisioned_model_name is required for create_provisioned_throughput."
)

params = {
"modelId": model_id,
"provisionedModelName": provisioned_model_name,
"modelUnits": model_units,
}
if commitment_duration:
params["commitmentDuration"] = commitment_duration
if tags:
params["tags"] = tags

logger.info(
"Creating provisioned throughput '%s' for model %s with %d model units",
provisioned_model_name,
model_id,
model_units,
)
response = self._get_bedrock_client().create_provisioned_model_throughput(**params)

provisioned_model_arn = response.get("provisionedModelArn")
if provisioned_model_arn:
self._wait_for_provisioned_throughput_in_service(
provisioned_model_arn, poll_interval=poll_interval, max_wait=max_wait
)

return response

def _wait_for_import_job_complete(
self, job_arn: str, poll_interval: int = 60, max_wait: int = 3600
):
"""Poll Bedrock until the model import job reaches Completed status.

Args:
job_arn: ARN of the model import job.
poll_interval: Seconds between status checks. Defaults to 60.
max_wait: Maximum seconds to wait. Defaults to 3600.

Raises:
RuntimeError: If the import job fails or times out.
"""
elapsed = 0
status = None
while elapsed < max_wait:
resp = self._get_bedrock_client().get_model_import_job(jobIdentifier=job_arn)
status = resp.get("status")
logger.info("Import job status: %s (elapsed %ds)", status, elapsed)
if status == "Completed":
return
if status == "Failed":
failure_reason = resp.get("failureMessage", "Unknown")
raise RuntimeError(
f"Model import job {job_arn} failed. Reason: {failure_reason}"
)
time.sleep(poll_interval)
elapsed += poll_interval
raise RuntimeError(
f"Timed out after {max_wait}s waiting for import job {job_arn} to complete. "
f"Last status: {status}"
)

def _wait_for_provisioned_throughput_in_service(
self, provisioned_model_arn: str, poll_interval: int = 60, max_wait: int = 3600
):
"""Poll Bedrock until provisioned throughput reaches InService status.

Args:
provisioned_model_arn: ARN of the provisioned model throughput.
poll_interval: Seconds between status checks. Defaults to 60.
max_wait: Maximum seconds to wait. Defaults to 3600.

Raises:
RuntimeError: If the provisioned throughput fails or times out.
"""
elapsed = 0
status = None
while elapsed < max_wait:
resp = self._get_bedrock_client().get_provisioned_model_throughput(
provisionedModelId=provisioned_model_arn
)
status = resp.get("status")
logger.info("Provisioned throughput status: %s (elapsed %ds)", status, elapsed)
if status == "InService":
return
if status == "Failed":
failure_reason = resp.get("failureMessage", "Unknown")
raise RuntimeError(
f"Provisioned throughput {provisioned_model_arn} failed. "
f"Reason: {failure_reason}"
)
time.sleep(poll_interval)
elapsed += poll_interval
raise RuntimeError(
f"Timed out after {max_wait}s waiting for provisioned throughput "
f"{provisioned_model_arn} to become InService. Last status: {status}"
)

def _wait_for_model_active(
self, model_arn: str, poll_interval: int = 60, max_wait: int = 3600
):
Expand Down
Loading
Loading