forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhyperparameters.py
More file actions
58 lines (49 loc) · 2.5 KB
/
hyperparameters.py
File metadata and controls
58 lines (49 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve hyperparameters for training jobs."""
from __future__ import absolute_import
import logging
from typing import Dict
from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
# Module-level logger named after this module; it is not referenced by ``retrieve_default``.
logger = logging.getLogger(__name__)
def retrieve_default(
    region=None,
    model_id=None,
    model_version=None,
    include_container_hyperparameters=False,
) -> Dict[str, str]:
    """Retrieves the default training hyperparameters for the model matching the given arguments.

    Args:
        region (str): Region for which to retrieve default hyperparameters. (Default: None).
        model_id (str): Model ID of the model for which to
            retrieve the default hyperparameters. (Default: None).
        model_version (str): Version of the model for which to retrieve the
            default hyperparameters. (Default: None).
        include_container_hyperparameters (bool): True if container hyperparameters
            should be returned as well. Container hyperparameters are not used to tune
            the specific algorithm, but rather by SageMaker Training to setup
            the training container environment. For example, there is a container hyperparameter
            that indicates the entrypoint script to use. These hyperparameters may be required
            when creating a training job with boto3, however the ``Estimator`` classes
            should take care of adding container hyperparameters to the job. (Default: False).

    Returns:
        dict: the hyperparameters to use for the model.

    Raises:
        ValueError: If ``jumpstart_utils.is_jumpstart_model_input`` rejects the given
            ``model_id`` / ``model_version`` combination (both must be specified).
    """
    # Validate up front so callers get a clear error instead of a failure deep
    # inside the artifact lookup. The message names the operation this function
    # performs (it previously referred to "script URIs", a copy-paste leftover).
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        raise ValueError(
            "Must specify `model_id` and `model_version` when retrieving hyperparameters."
        )

    # Note the argument order expected by the artifacts helper: model_id and
    # model_version come before region.
    return artifacts._retrieve_default_hyperparameters(
        model_id, model_version, region, include_container_hyperparameters
    )