diff --git a/sagemaker-core/src/sagemaker/core/image_uris.py b/sagemaker-core/src/sagemaker/core/image_uris.py index 2f3ee0add5..9cd72663ff 100644 --- a/sagemaker-core/src/sagemaker/core/image_uris.py +++ b/sagemaker-core/src/sagemaker/core/image_uris.py @@ -244,7 +244,7 @@ def retrieve( pt_or_tf_version = ( re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2) ) - _version = original_version + _version = _version_for_config(version, config) if repo in [ "huggingface-pytorch-trcomp-training", diff --git a/sagemaker-core/tests/unit/image_uris/test_retrieve.py b/sagemaker-core/tests/unit/image_uris/test_retrieve.py index 80cc31bd04..8fd0a05a54 100644 --- a/sagemaker-core/tests/unit/image_uris/test_retrieve.py +++ b/sagemaker-core/tests/unit/image_uris/test_retrieve.py @@ -670,7 +670,7 @@ def test_retrieve_huggingface(config_for_framework): ) assert ( "564829616587.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:" - "1.6-transformers4.2-gpu-py37-cu110-ubuntu18.04" == pt_uri_mv + "1.6-transformers4.2.1-gpu-py37-cu110-ubuntu18.04" == pt_uri_mv ) pt_uri = image_uris.retrieve( @@ -715,7 +715,7 @@ def test_retrieve_huggingface(config_for_framework): ) assert ( "564829616587.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:" - "1.6.0-transformers4.3.1-gpu-py37-cu110-ubuntu18.04" == pt_new_version + "1.6.0-transformers4.2.1-gpu-py37-cu110-ubuntu18.04" == pt_new_version ) @@ -787,6 +787,88 @@ def test_get_latest_version_function_with_no_framework(config_for_framework): assert "No framework config for framework" in str(e.exception) +def _get_huggingface_alias_test_cases(): + """Build parametrized test cases for every HuggingFace version alias.""" + config = image_uris.config_for_framework("huggingface") + cases = [] + for scope in ("training", "inference"): + section = config[scope] + aliases = section.get("version_aliases", {}) + for alias, resolved in aliases.items(): + ver_cfg = section["versions"][resolved] + base_fws = [k for k in ver_cfg if k != "version_aliases"] + base_fw = base_fws[0] + py_ver = ver_cfg[base_fw]["py_versions"][0] + inst = "ml.p3.2xlarge" if scope == "training" else "ml.c5.xlarge" + cases.append( + pytest.param(scope, alias, resolved, base_fw, py_ver, inst, + id=f"{scope}-{alias}->{resolved}") + ) + return cases + + +def _get_huggingface_full_version_test_cases(): + """Build parametrized test cases for every non-aliased HuggingFace version.""" + config = image_uris.config_for_framework("huggingface") + cases = [] + for scope in ("training", "inference"): + section = config[scope] + for full_ver, ver_cfg in section["versions"].items(): + base_fws = [k for k in ver_cfg if k != "version_aliases"] + base_fw = base_fws[0] + py_ver = ver_cfg[base_fw]["py_versions"][0] + inst = "ml.p3.2xlarge" if scope == "training" else "ml.c5.xlarge" + cases.append( + pytest.param(scope, full_ver, base_fw, py_ver, inst, + id=f"{scope}-{full_ver}") + ) + return cases + + +@pytest.mark.parametrize( + "scope,alias,resolved,base_fw,py_ver,instance_type", + _get_huggingface_alias_test_cases(), +) +def test_huggingface_version_alias_resolves_in_tag( + scope, alias, resolved, base_fw, py_ver, instance_type +): + """Version aliases must be resolved to full versions in image URI tags.""" + uri = image_uris.retrieve( + framework="huggingface", + region="us-east-1", + version=alias, + py_version=py_ver, + image_scope=scope, + base_framework_version=base_fw, + instance_type=instance_type, + ) + assert f"transformers{resolved}-" in uri, ( + f"Expected resolved version 'transformers{resolved}-' in URI, got: {uri}" + ) + + +@pytest.mark.parametrize( + "scope,full_version,base_fw,py_ver,instance_type", + _get_huggingface_full_version_test_cases(), +) +def test_huggingface_full_version_in_tag( + scope, full_version, base_fw, py_ver, instance_type +): + """Full (non-aliased) versions must appear unchanged in image URI tags.""" + uri = image_uris.retrieve( + framework="huggingface", + region="us-east-1", + version=full_version, + py_version=py_ver, + image_scope=scope, + base_framework_version=base_fw, + instance_type=instance_type, + ) + assert f"transformers{full_version}-" in uri, ( + f"Expected full version 'transformers{full_version}-' in URI, got: {uri}" + ) + + @pytest.mark.parametrize( "framework", [