1212# language governing permissions and limitations under the License.
1313"""This module contains utilities related to SageMaker JumpStart."""
1414from __future__ import absolute_import
15+ from functools import reduce
1516from typing import Dict , List , Optional
16- from packaging .version import Version
1717from urllib .parse import urlparse
18+ from packaging .version import Version
1819import sagemaker
1920from sagemaker .jumpstart import constants
2021from sagemaker .jumpstart import accessors
@@ -154,13 +155,26 @@ def is_jumpstart_model_uri(uri: Optional[str]) -> bool:
154155 return bucket in constants .JUMPSTART_BUCKET_NAME_SET
155156
156157
158+ def tag_key_in_array (tag_key : str , tag_array : List [Dict [str , str ]]) -> bool :
159+ """Returns True if ``tag_key`` is in the ``tag_array``.
160+
161+ Args:
162+ tag_key (str): the tag key to check if it's already in the ``tag_array``.
163+ tag_array (List[Dict[str, str]]): array of tags to check for ``tag_key``.
164+ """
165+ if len (tag_array ) == 0 :
166+ return False
167+ return tag_key in reduce (lambda a , b : set (a .keys ()).union (set (b .keys ())), tag_array )
168+
169+
157170def add_jumpstart_tags (
158171 tags : Optional [List [Dict [str , str ]]],
159172 inference_model_uri : Optional [str ],
160173 inference_script_uri : Optional [str ],
161174) -> List [Dict [str , str ]]:
162- """Adds tags for JumpStart models. Returns original tags for non-JumpStart
163- models.
175+ """Add custom tags to JumpStart models, return the updated tags.
176+
177+ No-op if this is not a JumpStart model related resource.
164178
165179 Args:
166180 tags (Optional[List[Dict[str,str]]): Current tags for JumpStart inference
@@ -172,19 +186,21 @@ def add_jumpstart_tags(
172186 if is_jumpstart_model_uri (inference_model_uri ):
173187 if tags is None :
174188 tags = []
175- tags .append (
176- {
177- constants .JumpStartTag .INFERENCE_MODEL_URI .value : inference_model_uri ,
178- }
179- )
189+ if not tag_key_in_array (constants .JumpStartTag .INFERENCE_MODEL_URI .value , tags ):
190+ tags .append (
191+ {
192+ constants .JumpStartTag .INFERENCE_MODEL_URI .value : inference_model_uri ,
193+ }
194+ )
180195
181196 if is_jumpstart_model_uri (inference_script_uri ):
182197 if tags is None :
183198 tags = []
184- tags .append (
185- {
186- constants .JumpStartTag .INFERENCE_SCRIPT_URI .value : inference_script_uri ,
187- }
188- )
199+ if not tag_key_in_array (constants .JumpStartTag .INFERENCE_SCRIPT_URI .value , tags ):
200+ tags .append (
201+ {
202+ constants .JumpStartTag .INFERENCE_SCRIPT_URI .value : inference_script_uri ,
203+ }
204+ )
189205
190206 return tags
0 commit comments